Browse Source

Optimize update_working_tree and prevent overwriting uncommitted changes

- Replace iter_tree_contents with diff_tree for massive performance improvement
- Only process files that actually changed between trees
- Add safety check to prevent overwriting uncommitted changes by default
Jelmer Vernooij 1 month ago
parent
commit
257174490a
6 changed files with 613 additions and 207 deletions
  1. 4 0
      dulwich/errors.py
  2. 239 134
      dulwich/index.py
  3. 66 19
      dulwich/porcelain.py
  4. 4 0
      dulwich/stash.py
  5. 87 41
      tests/test_index.py
  6. 213 13
      tests/test_porcelain.py

+ 4 - 0
dulwich/errors.py

@@ -180,3 +180,7 @@ class RefFormatError(Exception):
 
 
 class HookError(Exception):
 class HookError(Exception):
     """An error occurred while executing a hook."""
     """An error occurred while executing a hook."""
+
+
+class WorkingTreeModifiedError(Exception):
+    """Indicates that the working tree has modifications that would be overwritten."""

+ 239 - 134
dulwich/index.py

@@ -42,6 +42,7 @@ from typing import (
 )
 )
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
+    from .diff_tree import TreeChange
     from .file import _GitFile
     from .file import _GitFile
     from .line_ending import BlobNormalizer
     from .line_ending import BlobNormalizer
     from .repo import Repo
     from .repo import Repo
@@ -1613,7 +1614,7 @@ def _check_symlink_matches(
 ) -> bool:
 ) -> bool:
     """Check if symlink target matches expected target.
     """Check if symlink target matches expected target.
 
 
-    Returns True if symlink needs to be written, False if it matches.
+    Returns True if symlink matches, False if it doesn't match.
     """
     """
     try:
     try:
         current_target = os.readlink(full_path)
         current_target = os.readlink(full_path)
@@ -1621,14 +1622,14 @@ def _check_symlink_matches(
         expected_target = blob_obj.as_raw_string()
         expected_target = blob_obj.as_raw_string()
         if isinstance(current_target, str):
         if isinstance(current_target, str):
             current_target = current_target.encode()
             current_target = current_target.encode()
-        return current_target != expected_target
+        return current_target == expected_target
     except FileNotFoundError:
     except FileNotFoundError:
         # Symlink doesn't exist
         # Symlink doesn't exist
-        return True
+        return False
     except OSError as e:
     except OSError as e:
         if e.errno == errno.EINVAL:
         if e.errno == errno.EINVAL:
             # Not a symlink
             # Not a symlink
-            return True
+            return False
         raise
         raise
 
 
 
 
@@ -1644,19 +1645,43 @@ def _check_file_matches(
 ) -> bool:
 ) -> bool:
     """Check if a file on disk matches the expected git object.
     """Check if a file on disk matches the expected git object.
 
 
-    Returns True if file needs to be written, False if it matches.
+    Returns True if file matches, False if it doesn't match.
     """
     """
     # Check mode first (if honor_filemode is True)
     # Check mode first (if honor_filemode is True)
     if honor_filemode:
     if honor_filemode:
         current_mode = stat.S_IMODE(current_stat.st_mode)
         current_mode = stat.S_IMODE(current_stat.st_mode)
         expected_mode = stat.S_IMODE(entry_mode)
         expected_mode = stat.S_IMODE(entry_mode)
-        if current_mode != expected_mode:
-            return True
+
+        # For regular files, only check the user executable bit, not group/other permissions
+        # This matches Git's behavior where umask differences don't count as modifications
+        if stat.S_ISREG(current_stat.st_mode):
+            # Normalize regular file modes to ignore group/other write permissions
+            current_mode_normalized = (
+                current_mode & 0o755
+            )  # Keep only user rwx and all read+execute
+            expected_mode_normalized = expected_mode & 0o755
+
+            # For Git compatibility, regular files should be either 644 or 755
+            if expected_mode_normalized not in (0o644, 0o755):
+                expected_mode_normalized = 0o644  # Default for regular files
+            if current_mode_normalized not in (0o644, 0o755):
+                # Determine if it should be executable based on user execute bit
+                if current_mode & 0o100:  # User execute bit is set
+                    current_mode_normalized = 0o755
+                else:
+                    current_mode_normalized = 0o644
+
+            if current_mode_normalized != expected_mode_normalized:
+                return False
+        else:
+            # For non-regular files (symlinks, etc.), check mode exactly
+            if current_mode != expected_mode:
+                return False
 
 
     # If mode matches (or we don't care), check content via size first
     # If mode matches (or we don't care), check content via size first
     blob_obj = repo_object_store[entry_sha]
     blob_obj = repo_object_store[entry_sha]
     if current_stat.st_size != blob_obj.raw_length():
     if current_stat.st_size != blob_obj.raw_length():
-        return True
+        return False
 
 
     # Size matches, check actual content
     # Size matches, check actual content
     try:
     try:
@@ -1668,9 +1693,9 @@ def _check_file_matches(
                     blob_obj, tree_path
                     blob_obj, tree_path
                 )
                 )
                 expected_content = normalized_blob.as_raw_string()
                 expected_content = normalized_blob.as_raw_string()
-            return current_content != expected_content
+            return current_content == expected_content
     except (FileNotFoundError, PermissionError, IsADirectoryError):
     except (FileNotFoundError, PermissionError, IsADirectoryError):
-        return True
+        return False
 
 
 
 
 def _transition_to_submodule(repo, path, full_path, current_stat, entry, index):
 def _transition_to_submodule(repo, path, full_path, current_stat, entry, index):
@@ -1710,7 +1735,7 @@ def _transition_to_file(
         and not stat.S_ISLNK(entry.mode)
         and not stat.S_ISLNK(entry.mode)
     ):
     ):
         # File to file - check if update needed
         # File to file - check if update needed
-        needs_update = _check_file_matches(
+        file_matches = _check_file_matches(
             object_store,
             object_store,
             full_path,
             full_path,
             entry.sha,
             entry.sha,
@@ -1720,13 +1745,15 @@ def _transition_to_file(
             blob_normalizer,
             blob_normalizer,
             path,
             path,
         )
         )
+        needs_update = not file_matches
     elif (
     elif (
         current_stat is not None
         current_stat is not None
         and stat.S_ISLNK(current_stat.st_mode)
         and stat.S_ISLNK(current_stat.st_mode)
         and stat.S_ISLNK(entry.mode)
         and stat.S_ISLNK(entry.mode)
     ):
     ):
         # Symlink to symlink - check if update needed
         # Symlink to symlink - check if update needed
-        needs_update = _check_symlink_matches(full_path, object_store, entry.sha)
+        symlink_matches = _check_symlink_matches(full_path, object_store, entry.sha)
+        needs_update = not symlink_matches
     else:
     else:
         needs_update = True
         needs_update = True
 
 
@@ -1814,12 +1841,14 @@ def update_working_tree(
     repo: "Repo",
     repo: "Repo",
     old_tree_id: Optional[bytes],
     old_tree_id: Optional[bytes],
     new_tree_id: bytes,
     new_tree_id: bytes,
+    change_iterator: Iterator["TreeChange"],
     honor_filemode: bool = True,
     honor_filemode: bool = True,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
     symlink_fn: Optional[Callable] = None,
     symlink_fn: Optional[Callable] = None,
     force_remove_untracked: bool = False,
     force_remove_untracked: bool = False,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
     tree_encoding: str = "utf-8",
+    allow_overwrite_modified: bool = False,
 ) -> None:
 ) -> None:
     """Update the working tree and index to match a new tree.
     """Update the working tree and index to match a new tree.
 
 
@@ -1833,6 +1862,7 @@ def update_working_tree(
       repo: Repository object
       repo: Repository object
       old_tree_id: SHA of the tree before the update
       old_tree_id: SHA of the tree before the update
       new_tree_id: SHA of the tree to update to
       new_tree_id: SHA of the tree to update to
+      change_iterator: Iterator of TreeChange objects to apply
       honor_filemode: An optional flag to honor core.filemode setting
       honor_filemode: An optional flag to honor core.filemode setting
       validate_path_element: Function to validate path elements to check out
       validate_path_element: Function to validate path elements to check out
       symlink_fn: Function to use for creating symlinks
       symlink_fn: Function to use for creating symlinks
@@ -1841,168 +1871,243 @@ def update_working_tree(
       blob_normalizer: An optional BlobNormalizer to use for converting line
       blob_normalizer: An optional BlobNormalizer to use for converting line
         endings when writing blobs to the working directory.
         endings when writing blobs to the working directory.
       tree_encoding: Encoding used for tree paths (default: utf-8)
       tree_encoding: Encoding used for tree paths (default: utf-8)
+      allow_overwrite_modified: If False, raise an error when attempting to
+        overwrite files that have been modified compared to old_tree_id
     """
     """
     if validate_path_element is None:
     if validate_path_element is None:
         validate_path_element = validate_path_element_default
         validate_path_element = validate_path_element_default
 
 
+    from .diff_tree import (
+        CHANGE_ADD,
+        CHANGE_COPY,
+        CHANGE_DELETE,
+        CHANGE_MODIFY,
+        CHANGE_RENAME,
+        CHANGE_UNCHANGED,
+    )
+
     repo_path = repo.path if isinstance(repo.path, bytes) else repo.path.encode()
     repo_path = repo.path if isinstance(repo.path, bytes) else repo.path.encode()
     index = repo.open_index()
     index = repo.open_index()
 
 
-    # Build sets of paths for efficient lookup
-    new_paths = {}
-    for entry in iter_tree_contents(repo.object_store, new_tree_id):
-        if entry.path.startswith(b".git") or not validate_path(
-            entry.path, validate_path_element
-        ):
-            continue
-        new_paths[entry.path] = entry
-
-    old_paths = {}
-    if old_tree_id:
-        for entry in iter_tree_contents(repo.object_store, old_tree_id):
-            if not entry.path.startswith(b".git"):
-                old_paths[entry.path] = entry
-
-    # Process all paths
-    all_paths = set(new_paths.keys()) | set(old_paths.keys())
-
-    # Check for paths that need to become directories
-    paths_needing_dir = set()
-    for path in new_paths:
-        parts = path.split(b"/")
-        for i in range(1, len(parts)):
-            parent = b"/".join(parts[:i])
-            if parent in old_paths and parent not in new_paths:
-                paths_needing_dir.add(parent)
+    # Convert iterator to list since we need multiple passes
+    changes = list(change_iterator)
+
+    # Check for path conflicts where files need to become directories
+    paths_becoming_dirs = set()
+    for change in changes:
+        if change.type in (CHANGE_ADD, CHANGE_MODIFY, CHANGE_RENAME, CHANGE_COPY):
+            path = change.new.path
+            if b"/" in path:  # This is a file inside a directory
+                # Check if any parent path exists as a file in the old tree or changes
+                parts = path.split(b"/")
+                for i in range(1, len(parts)):
+                    parent = b"/".join(parts[:i])
+                    # See if this parent path is being deleted (was a file, becoming a dir)
+                    for other_change in changes:
+                        if (
+                            other_change.type == CHANGE_DELETE
+                            and other_change.old
+                            and other_change.old.path == parent
+                        ):
+                            paths_becoming_dirs.add(parent)
 
 
     # Check if any path that needs to become a directory has been modified
     # Check if any path that needs to become a directory has been modified
-    current_stat: Optional[os.stat_result]
-    stat_cache: dict[bytes, Optional[os.stat_result]] = {}
-    for path in paths_needing_dir:
+    for path in paths_becoming_dirs:
         full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
         full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
         try:
         try:
             current_stat = os.lstat(full_path)
             current_stat = os.lstat(full_path)
         except FileNotFoundError:
         except FileNotFoundError:
-            # File doesn't exist, proceed
-            stat_cache[full_path] = None
-        except PermissionError:
-            # Can't read file, proceed
-            pass
-        else:
-            stat_cache[full_path] = current_stat
-            if stat.S_ISREG(current_stat.st_mode):
+            continue  # File doesn't exist, nothing to check
+        except OSError as e:
+            raise OSError(
+                f"Cannot access {path.decode('utf-8', errors='replace')}: {e}"
+            ) from e
+
+        if stat.S_ISREG(current_stat.st_mode):
+            # Find the old entry for this path
+            old_change = None
+            for change in changes:
+                if (
+                    change.type == CHANGE_DELETE
+                    and change.old
+                    and change.old.path == path
+                ):
+                    old_change = change
+                    break
+
+            if old_change:
                 # Check if file has been modified
                 # Check if file has been modified
-                old_entry = old_paths[path]
-                if _check_file_matches(
+                file_matches = _check_file_matches(
                     repo.object_store,
                     repo.object_store,
                     full_path,
                     full_path,
-                    old_entry.sha,
-                    old_entry.mode,
+                    old_change.old.sha,
+                    old_change.old.mode,
                     current_stat,
                     current_stat,
                     honor_filemode,
                     honor_filemode,
                     blob_normalizer,
                     blob_normalizer,
                     path,
                     path,
-                ):
-                    # File has been modified, can't replace with directory
+                )
+                if not file_matches:
                     raise OSError(
                     raise OSError(
                         f"Cannot replace modified file with directory: {path!r}"
                         f"Cannot replace modified file with directory: {path!r}"
                     )
                     )
 
 
-    # Process in two passes: deletions first, then additions/updates
-    # This handles case-only renames on case-insensitive filesystems correctly
-    paths_to_remove = []
-    paths_to_update = []
-
-    for path in sorted(all_paths):
-        if path in new_paths:
-            paths_to_update.append(path)
-        else:
-            paths_to_remove.append(path)
+    # Check for uncommitted modifications before making any changes
+    if not allow_overwrite_modified and old_tree_id:
+        for change in changes:
+            # Only check files that are being modified or deleted
+            if change.type in (CHANGE_MODIFY, CHANGE_DELETE) and change.old:
+                path = change.old.path
+                if path.startswith(b".git") or not validate_path(
+                    path, validate_path_element
+                ):
+                    continue
 
 
-    # First process removals
-    for path in paths_to_remove:
-        full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
+                full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
+                try:
+                    current_stat = os.lstat(full_path)
+                except FileNotFoundError:
+                    continue  # File doesn't exist, nothing to check
+                except OSError as e:
+                    raise OSError(
+                        f"Cannot access {path.decode('utf-8', errors='replace')}: {e}"
+                    ) from e
+
+                if stat.S_ISREG(current_stat.st_mode):
+                    # Check if working tree file differs from old tree
+                    file_matches = _check_file_matches(
+                        repo.object_store,
+                        full_path,
+                        change.old.sha,
+                        change.old.mode,
+                        current_stat,
+                        honor_filemode,
+                        blob_normalizer,
+                        path,
+                    )
+                    if not file_matches:
+                        from .errors import WorkingTreeModifiedError
+
+                        raise WorkingTreeModifiedError(
+                            f"Your local changes to '{path.decode('utf-8', errors='replace')}' "
+                            f"would be overwritten by checkout. "
+                            f"Please commit your changes or stash them before you switch branches."
+                        )
+
+    # Apply the changes
+    for change in changes:
+        if change.type == CHANGE_DELETE:
+            # Remove file/directory
+            path = change.old.path
+            if path.startswith(b".git") or not validate_path(
+                path, validate_path_element
+            ):
+                continue
 
 
-        # Determine current state - use cache if available
-        try:
-            current_stat = stat_cache[full_path]
-        except KeyError:
+            full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
             try:
             try:
-                current_stat = os.lstat(full_path)
+                delete_stat: Optional[os.stat_result] = os.lstat(full_path)
             except FileNotFoundError:
             except FileNotFoundError:
-                current_stat = None
+                delete_stat = None
+            except OSError as e:
+                raise OSError(
+                    f"Cannot access {path.decode('utf-8', errors='replace')}: {e}"
+                ) from e
 
 
-        _transition_to_absent(repo, path, full_path, current_stat, index)
+            _transition_to_absent(repo, path, full_path, delete_stat, index)
 
 
-    # Then process additions/updates
-    for path in paths_to_update:
-        full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
+        elif change.type in (CHANGE_ADD, CHANGE_MODIFY, CHANGE_UNCHANGED):
+            # Add or modify file
+            path = change.new.path
+            if path.startswith(b".git") or not validate_path(
+                path, validate_path_element
+            ):
+                continue
 
 
-        # Determine current state - use cache if available
-        try:
-            current_stat = stat_cache[full_path]
-        except KeyError:
+            full_path = _tree_to_fs_path(repo_path, path, tree_encoding)
             try:
             try:
-                current_stat = os.lstat(full_path)
+                modify_stat: Optional[os.stat_result] = os.lstat(full_path)
             except FileNotFoundError:
             except FileNotFoundError:
-                current_stat = None
+                modify_stat = None
+            except OSError as e:
+                raise OSError(
+                    f"Cannot access {path.decode('utf-8', errors='replace')}: {e}"
+                ) from e
+
+            if S_ISGITLINK(change.new.mode):
+                _transition_to_submodule(
+                    repo, path, full_path, modify_stat, change.new, index
+                )
+            else:
+                _transition_to_file(
+                    repo.object_store,
+                    path,
+                    full_path,
+                    modify_stat,
+                    change.new,
+                    index,
+                    honor_filemode,
+                    symlink_fn,
+                    blob_normalizer,
+                    tree_encoding,
+                )
 
 
-        new_entry = new_paths[path]
+        elif change.type in (CHANGE_RENAME, CHANGE_COPY):
+            # Handle rename/copy: remove old, add new
+            old_path = change.old.path
+            new_path = change.new.path
 
 
-        # Path should exist
-        if S_ISGITLINK(new_entry.mode):
-            _transition_to_submodule(
-                repo, path, full_path, current_stat, new_entry, index
-            )
-        else:
-            _transition_to_file(
-                repo.object_store,
-                path,
-                full_path,
-                current_stat,
-                new_entry,
-                index,
-                honor_filemode,
-                symlink_fn,
-                blob_normalizer,
-                tree_encoding,
-            )
+            if not old_path.startswith(b".git") and validate_path(
+                old_path, validate_path_element
+            ):
+                old_full_path = _tree_to_fs_path(repo_path, old_path, tree_encoding)
+                try:
+                    old_current_stat = os.lstat(old_full_path)
+                except FileNotFoundError:
+                    old_current_stat = None
+                except OSError as e:
+                    raise OSError(
+                        f"Cannot access {old_path.decode('utf-8', errors='replace')}: {e}"
+                    ) from e
+                _transition_to_absent(
+                    repo, old_path, old_full_path, old_current_stat, index
+                )
 
 
-    # Handle force_remove_untracked
-    if force_remove_untracked:
-        for root, dirs, files in os.walk(repo_path):
-            if b".git" in os.fsencode(root):
-                continue
-            root_bytes = os.fsencode(root)
-            for file in files:
-                full_path = os.path.join(root_bytes, os.fsencode(file))
-                tree_path = os.path.relpath(full_path, repo_path)
-                if os.sep != "/":
-                    tree_path = tree_path.replace(os.sep.encode(), b"/")
-
-                if tree_path not in new_paths:
-                    _remove_file_with_readonly_handling(full_path)
-                    if tree_path in index:
-                        del index[tree_path]
-
-        # Clean up empty directories
-        for root, dirs, files in os.walk(repo_path, topdown=False):
-            root_bytes = os.fsencode(root)
-            if (
-                b".git" not in root_bytes
-                and root_bytes != repo_path
-                and not files
-                and not dirs
+            if not new_path.startswith(b".git") and validate_path(
+                new_path, validate_path_element
             ):
             ):
+                new_full_path = _tree_to_fs_path(repo_path, new_path, tree_encoding)
                 try:
                 try:
-                    os.rmdir(root)
+                    new_current_stat = os.lstat(new_full_path)
                 except FileNotFoundError:
                 except FileNotFoundError:
-                    # Directory was already removed
-                    pass
+                    new_current_stat = None
                 except OSError as e:
                 except OSError as e:
-                    if e.errno != errno.ENOTEMPTY:
-                        # Only ignore "directory not empty" errors
-                        raise
+                    raise OSError(
+                        f"Cannot access {new_path.decode('utf-8', errors='replace')}: {e}"
+                    ) from e
+
+                if S_ISGITLINK(change.new.mode):
+                    _transition_to_submodule(
+                        repo,
+                        new_path,
+                        new_full_path,
+                        new_current_stat,
+                        change.new,
+                        index,
+                    )
+                else:
+                    _transition_to_file(
+                        repo.object_store,
+                        new_path,
+                        new_full_path,
+                        new_current_stat,
+                        change.new,
+                        index,
+                        honor_filemode,
+                        symlink_fn,
+                        blob_normalizer,
+                        tree_encoding,
+                    )
 
 
     index.write()
     index.write()
 
 

+ 66 - 19
dulwich/porcelain.py

@@ -85,6 +85,7 @@ import stat
 import sys
 import sys
 import time
 import time
 from collections import namedtuple
 from collections import namedtuple
+from collections.abc import Iterator
 from contextlib import closing, contextmanager
 from contextlib import closing, contextmanager
 from dataclasses import dataclass
 from dataclasses import dataclass
 from io import BytesIO, RawIOBase
 from io import BytesIO, RawIOBase
@@ -103,6 +104,8 @@ from .diff_tree import (
     CHANGE_MODIFY,
     CHANGE_MODIFY,
     CHANGE_RENAME,
     CHANGE_RENAME,
     RENAME_CHANGE_TYPES,
     RENAME_CHANGE_TYPES,
+    TreeChange,
+    tree_changes,
 )
 )
 from .errors import SendPackError
 from .errors import SendPackError
 from .graph import can_fast_forward
 from .graph import can_fast_forward
@@ -1896,13 +1899,6 @@ def reset(repo, mode, treeish: Union[str, bytes, Commit, Tree, Tag] = "HEAD") ->
 
 
         elif mode == "hard":
         elif mode == "hard":
             # Hard reset: update HEAD, index, and working tree
             # Hard reset: update HEAD, index, and working tree
-            # Get current HEAD tree for comparison
-            try:
-                current_head = r.refs[b"HEAD"]
-                current_tree = r[current_head].tree
-            except KeyError:
-                current_tree = None
-
             # Get configuration for working directory update
             # Get configuration for working directory update
             config = r.get_config()
             config = r.get_config()
             honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
             honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
@@ -1929,15 +1925,28 @@ def reset(repo, mode, treeish: Union[str, bytes, Commit, Tree, Tag] = "HEAD") ->
 
 
             # Update working tree and index
             # Update working tree and index
             blob_normalizer = r.get_blob_normalizer()
             blob_normalizer = r.get_blob_normalizer()
+            # For reset --hard, use current index tree as old tree to get proper deletions
+            index = r.open_index()
+            if len(index) > 0:
+                index_tree_id = index.commit(r.object_store)
+            else:
+                # Empty index
+                index_tree_id = None
+
+            changes = tree_changes(
+                r.object_store, index_tree_id, tree.id, want_unchanged=True
+            )
             update_working_tree(
             update_working_tree(
                 r,
                 r,
-                current_tree,
+                index_tree_id,
                 tree.id,
                 tree.id,
+                change_iterator=changes,
                 honor_filemode=honor_filemode,
                 honor_filemode=honor_filemode,
                 validate_path_element=validate_path_element,
                 validate_path_element=validate_path_element,
                 symlink_fn=symlink_fn,
                 symlink_fn=symlink_fn,
                 force_remove_untracked=True,
                 force_remove_untracked=True,
                 blob_normalizer=blob_normalizer,
                 blob_normalizer=blob_normalizer,
+                allow_overwrite_modified=True,  # Allow overwriting modified files
             )
             )
         else:
         else:
             raise Error(f"Invalid reset mode: {mode}")
             raise Error(f"Invalid reset mode: {mode}")
@@ -2106,6 +2115,8 @@ def pull(
       fast_forward: If True, raise an exception when fast-forward is not possible
       fast_forward: If True, raise an exception when fast-forward is not possible
       ff_only: If True, only allow fast-forward merges. Raises DivergedBranches
       ff_only: If True, only allow fast-forward merges. Raises DivergedBranches
         when branches have diverged rather than performing a merge.
         when branches have diverged rather than performing a merge.
+      force: If True, allow overwriting local changes in the working tree.
+        If False, pull will abort if it would overwrite uncommitted changes.
       filter_spec: A git-rev-list-style object filter spec, as an ASCII string.
       filter_spec: A git-rev-list-style object filter spec, as an ASCII string.
         Only used if the server supports the Git protocol-v2 'filter'
         Only used if the server supports the Git protocol-v2 'filter'
         feature, and ignored otherwise.
         feature, and ignored otherwise.
@@ -2181,8 +2192,14 @@ def pull(
         if not merged and old_tree_id is not None:
         if not merged and old_tree_id is not None:
             new_tree_id = r[b"HEAD"].tree
             new_tree_id = r[b"HEAD"].tree
             blob_normalizer = r.get_blob_normalizer()
             blob_normalizer = r.get_blob_normalizer()
+            changes = tree_changes(r.object_store, old_tree_id, new_tree_id)
             update_working_tree(
             update_working_tree(
-                r, old_tree_id, new_tree_id, blob_normalizer=blob_normalizer
+                r,
+                old_tree_id,
+                new_tree_id,
+                change_iterator=changes,
+                blob_normalizer=blob_normalizer,
+                allow_overwrite_modified=force,
             )
             )
         if remote_name is not None:
         if remote_name is not None:
             _import_remote_refs(r.refs, remote_name, fetch_result.refs)
             _import_remote_refs(r.refs, remote_name, fetch_result.refs)
@@ -3319,15 +3336,20 @@ def checkout(
         blob_normalizer = r.get_blob_normalizer()
         blob_normalizer = r.get_blob_normalizer()
 
 
         # Update working tree
         # Update working tree
+        tree_change_iterator: Iterator[TreeChange] = tree_changes(
+            r.object_store, current_tree_id, target_tree_id
+        )
         update_working_tree(
         update_working_tree(
             r,
             r,
             current_tree_id,
             current_tree_id,
             target_tree_id,
             target_tree_id,
+            change_iterator=tree_change_iterator,
             honor_filemode=honor_filemode,
             honor_filemode=honor_filemode,
             validate_path_element=validate_path_element,
             validate_path_element=validate_path_element,
             symlink_fn=symlink_fn,
             symlink_fn=symlink_fn,
             force_remove_untracked=force,
             force_remove_untracked=force,
             blob_normalizer=blob_normalizer,
             blob_normalizer=blob_normalizer,
+            allow_overwrite_modified=force,
         )
         )
 
 
         # Update HEAD
         # Update HEAD
@@ -3829,7 +3851,10 @@ def _do_merge(
         # Fast-forward merge
         # Fast-forward merge
         r.refs[b"HEAD"] = merge_commit_id
         r.refs[b"HEAD"] = merge_commit_id
         # Update the working directory
         # Update the working directory
-        update_working_tree(r, head_commit.tree, merge_commit.tree)
+        changes = tree_changes(r.object_store, head_commit.tree, merge_commit.tree)
+        update_working_tree(
+            r, head_commit.tree, merge_commit.tree, change_iterator=changes
+        )
         return (merge_commit_id, [])
         return (merge_commit_id, [])
 
 
     if base_commit_id == merge_commit_id:
     if base_commit_id == merge_commit_id:
@@ -3848,7 +3873,8 @@ def _do_merge(
     r.object_store.add_object(merged_tree)
     r.object_store.add_object(merged_tree)
 
 
     # Update index and working directory
     # Update index and working directory
-    update_working_tree(r, head_commit.tree, merged_tree.id)
+    changes = tree_changes(r.object_store, head_commit.tree, merged_tree.id)
+    update_working_tree(r, head_commit.tree, merged_tree.id, change_iterator=changes)
 
 
     if conflicts or no_commit:
     if conflicts or no_commit:
         # Don't create a commit if there are conflicts or no_commit is True
         # Don't create a commit if there are conflicts or no_commit is True
@@ -4134,7 +4160,15 @@ def cherry_pick(
         r.reset_index(merged_tree.id)
         r.reset_index(merged_tree.id)
 
 
         # Update working tree from the new index
         # Update working tree from the new index
-        update_working_tree(r, head_commit.tree, merged_tree.id)
+        # Allow overwriting because we're applying the merge result
+        changes = tree_changes(r.object_store, head_commit.tree, merged_tree.id)
+        update_working_tree(
+            r,
+            head_commit.tree,
+            merged_tree.id,
+            change_iterator=changes,
+            allow_overwrite_modified=True,
+        )
 
 
         if conflicts:
         if conflicts:
             # Save state for later continuation
             # Save state for later continuation
@@ -4248,7 +4282,10 @@ def revert(
 
 
             if conflicts:
             if conflicts:
                 # Update working tree with conflicts
                 # Update working tree with conflicts
-                update_working_tree(r, current_tree, merged_tree.id)
+                changes = tree_changes(r.object_store, current_tree, merged_tree.id)
+                update_working_tree(
+                    r, current_tree, merged_tree.id, change_iterator=changes
+                )
                 conflicted_paths = [c.decode("utf-8", "replace") for c in conflicts]
                 conflicted_paths = [c.decode("utf-8", "replace") for c in conflicts]
                 raise Error(f"Conflicts while reverting: {', '.join(conflicted_paths)}")
                 raise Error(f"Conflicts while reverting: {', '.join(conflicted_paths)}")
 
 
@@ -4256,7 +4293,10 @@ def revert(
             r.object_store.add_object(merged_tree)
             r.object_store.add_object(merged_tree)
 
 
             # Update working tree
             # Update working tree
-            update_working_tree(r, current_tree, merged_tree.id)
+            changes = tree_changes(r.object_store, current_tree, merged_tree.id)
+            update_working_tree(
+                r, current_tree, merged_tree.id, change_iterator=changes
+            )
             current_tree = merged_tree.id
             current_tree = merged_tree.id
 
 
             if not no_commit:
             if not no_commit:
@@ -4831,7 +4871,8 @@ def bisect_start(
                 old_tree = r[r.head()].tree if r.head() else None
                 old_tree = r[r.head()].tree if r.head() else None
                 r.refs[b"HEAD"] = next_sha
                 r.refs[b"HEAD"] = next_sha
                 commit = r[next_sha]
                 commit = r[next_sha]
-                update_working_tree(r, old_tree, commit.tree)
+                changes = tree_changes(r.object_store, old_tree, commit.tree)
+                update_working_tree(r, old_tree, commit.tree, change_iterator=changes)
             return next_sha
             return next_sha
 
 
 
 
@@ -4855,7 +4896,8 @@ def bisect_bad(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
             old_tree = r[r.head()].tree if r.head() else None
             old_tree = r[r.head()].tree if r.head() else None
             r.refs[b"HEAD"] = next_sha
             r.refs[b"HEAD"] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
-            update_working_tree(r, old_tree, commit.tree)
+            changes = tree_changes(r.object_store, old_tree, commit.tree)
+            update_working_tree(r, old_tree, commit.tree, change_iterator=changes)
 
 
         return next_sha
         return next_sha
 
 
@@ -4880,7 +4922,8 @@ def bisect_good(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
             old_tree = r[r.head()].tree if r.head() else None
             old_tree = r[r.head()].tree if r.head() else None
             r.refs[b"HEAD"] = next_sha
             r.refs[b"HEAD"] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
-            update_working_tree(r, old_tree, commit.tree)
+            changes = tree_changes(r.object_store, old_tree, commit.tree)
+            update_working_tree(r, old_tree, commit.tree, change_iterator=changes)
 
 
         return next_sha
         return next_sha
 
 
@@ -4918,7 +4961,8 @@ def bisect_skip(
             old_tree = r[r.head()].tree if r.head() else None
             old_tree = r[r.head()].tree if r.head() else None
             r.refs[b"HEAD"] = next_sha
             r.refs[b"HEAD"] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
-            update_working_tree(r, old_tree, commit.tree)
+            changes = tree_changes(r.object_store, old_tree, commit.tree)
+            update_working_tree(r, old_tree, commit.tree, change_iterator=changes)
 
 
         return next_sha
         return next_sha
 
 
@@ -4946,7 +4990,10 @@ def bisect_reset(repo=".", commit: Optional[Union[str, bytes, Commit, Tag]] = No
             new_head = r.head()
             new_head = r.head()
             if new_head:
             if new_head:
                 new_commit = r[new_head]
                 new_commit = r[new_head]
-                update_working_tree(r, old_tree, new_commit.tree)
+                changes = tree_changes(r.object_store, old_tree, new_commit.tree)
+                update_working_tree(
+                    r, old_tree, new_commit.tree, change_iterator=changes
+                )
         except KeyError:
         except KeyError:
             # No HEAD after reset
             # No HEAD after reset
             pass
             pass

+ 4 - 0
dulwich/stash.py

@@ -25,6 +25,7 @@ import os
 import sys
 import sys
 from typing import TYPE_CHECKING, Optional, TypedDict
 from typing import TYPE_CHECKING, Optional, TypedDict
 
 
+from .diff_tree import tree_changes
 from .file import GitFile
 from .file import GitFile
 from .index import (
 from .index import (
     IndexEntry,
     IndexEntry,
@@ -317,10 +318,13 @@ class Stash:
         # Update from stash tree to HEAD tree
         # Update from stash tree to HEAD tree
         # This will remove files that were in stash but not in HEAD,
         # This will remove files that were in stash but not in HEAD,
         # and restore files to their HEAD versions
         # and restore files to their HEAD versions
+        changes = tree_changes(self._repo.object_store, stash_tree_id, head_tree_id)
         update_working_tree(
         update_working_tree(
             self._repo,
             self._repo,
             old_tree_id=stash_tree_id,
             old_tree_id=stash_tree_id,
             new_tree_id=head_tree_id,
             new_tree_id=head_tree_id,
+            change_iterator=changes,
+            allow_overwrite_modified=True,  # We need to overwrite modified files
         )
         )
 
 
         return cid
         return cid

+ 87 - 41
tests/test_index.py

@@ -29,6 +29,7 @@ import sys
 import tempfile
 import tempfile
 from io import BytesIO
 from io import BytesIO
 
 
+from dulwich.diff_tree import tree_changes
 from dulwich.index import (
 from dulwich.index import (
     Index,
     Index,
     IndexEntry,
     IndexEntry,
@@ -1776,10 +1777,12 @@ class TestUpdateWorkingTree(TestCase):
 
 
         # Update working tree with normalizer
         # Update working tree with normalizer
         normalizer = TestBlobNormalizer()
         normalizer = TestBlobNormalizer()
+        changes = tree_changes(self.repo.object_store, None, tree.id)
         update_working_tree(
         update_working_tree(
             self.repo,
             self.repo,
             None,  # old_tree_id
             None,  # old_tree_id
             tree.id,  # new_tree_id
             tree.id,  # new_tree_id
+            change_iterator=changes,
             blob_normalizer=normalizer,
             blob_normalizer=normalizer,
         )
         )
 
 
@@ -1806,10 +1809,12 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree)
         self.repo.object_store.add_object(tree)
 
 
         # Update working tree without normalizer
         # Update working tree without normalizer
+        changes = tree_changes(self.repo.object_store, None, tree.id)
         update_working_tree(
         update_working_tree(
             self.repo,
             self.repo,
             None,  # old_tree_id
             None,  # old_tree_id
             tree.id,  # new_tree_id
             tree.id,  # new_tree_id
+            change_iterator=changes,
             blob_normalizer=None,
             blob_normalizer=None,
         )
         )
 
 
@@ -1841,7 +1846,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1 (create directory with files)
         # Update to tree1 (create directory with files)
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify directory and files exist
         # Verify directory and files exist
         dir_path = os.path.join(self.tempdir, "dir")
         dir_path = os.path.join(self.tempdir, "dir")
@@ -1854,7 +1860,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update to empty tree
         # Update to empty tree
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Verify directory was removed
         # Verify directory was removed
         self.assertFalse(os.path.exists(dir_path))
         self.assertFalse(os.path.exists(dir_path))
@@ -1868,7 +1875,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree with submodule
         # Update to tree with submodule
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify submodule directory exists with .git file
         # Verify submodule directory exists with .git file
         submodule_path = os.path.join(self.tempdir, "submodule")
         submodule_path = os.path.join(self.tempdir, "submodule")
@@ -1885,7 +1893,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update to tree with file (should remove submodule directory and create file)
         # Update to tree with file (should remove submodule directory and create file)
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Verify it's now a file
         # Verify it's now a file
         self.assertTrue(os.path.isfile(submodule_path))
         self.assertTrue(os.path.isfile(submodule_path))
@@ -1904,7 +1913,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify nested structure exists
         # Verify nested structure exists
         path_a = os.path.join(self.tempdir, "a")
         path_a = os.path.join(self.tempdir, "a")
@@ -1919,7 +1929,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update to empty tree
         # Update to empty tree
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Verify all directories were removed
         # Verify all directories were removed
         self.assertFalse(os.path.exists(path_a))
         self.assertFalse(os.path.exists(path_a))
@@ -1936,7 +1947,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify file exists
         # Verify file exists
         file_path = os.path.join(self.tempdir, "path")
         file_path = os.path.join(self.tempdir, "path")
@@ -1953,7 +1965,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should succeed but leave the directory alone
         # Update should succeed but leave the directory alone
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Directory should still exist with its contents
         # Directory should still exist with its contents
         self.assertTrue(os.path.isdir(file_path))
         self.assertTrue(os.path.isdir(file_path))
@@ -1971,7 +1984,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify file exists
         # Verify file exists
         file_path = os.path.join(self.tempdir, "path")
         file_path = os.path.join(self.tempdir, "path")
@@ -1986,7 +2000,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should remove the empty directory
         # Update should remove the empty directory
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Directory should be gone
         # Directory should be gone
         self.assertFalse(os.path.exists(file_path))
         self.assertFalse(os.path.exists(file_path))
@@ -2007,7 +2022,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree with symlink
         # Update to tree with symlink
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         link_path = os.path.join(self.tempdir, "link")
         link_path = os.path.join(self.tempdir, "link")
         self.assertTrue(os.path.islink(link_path))
         self.assertTrue(os.path.islink(link_path))
@@ -2022,7 +2038,8 @@ class TestUpdateWorkingTree(TestCase):
         tree2[b"link"] = (0o100644, blob2.id)
         tree2[b"link"] = (0o100644, blob2.id)
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         self.assertFalse(os.path.islink(link_path))
         self.assertFalse(os.path.islink(link_path))
         self.assertTrue(os.path.isfile(link_path))
         self.assertTrue(os.path.isfile(link_path))
@@ -2030,7 +2047,8 @@ class TestUpdateWorkingTree(TestCase):
             self.assertEqual(b"file content", f.read())
             self.assertEqual(b"file content", f.read())
 
 
         # Test 2: Replace file with symlink
         # Test 2: Replace file with symlink
-        update_working_tree(self.repo, tree2.id, tree1.id)
+        changes = tree_changes(self.repo.object_store, tree2.id, tree1.id)
+        update_working_tree(self.repo, tree2.id, tree1.id, change_iterator=changes)
 
 
         self.assertTrue(os.path.islink(link_path))
         self.assertTrue(os.path.islink(link_path))
         self.assertEqual(b"target/path", os.readlink(link_path).encode())
         self.assertEqual(b"target/path", os.readlink(link_path).encode())
@@ -2044,7 +2062,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree3)
         self.repo.object_store.add_object(tree3)
 
 
         # Should remove empty directory
         # Should remove empty directory
-        update_working_tree(self.repo, tree1.id, tree3.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree3.id)
+        update_working_tree(self.repo, tree1.id, tree3.id, change_iterator=changes)
         self.assertFalse(os.path.exists(link_path))
         self.assertFalse(os.path.exists(link_path))
 
 
     def test_update_working_tree_modified_file_to_dir_transition(self):
     def test_update_working_tree_modified_file_to_dir_transition(self):
@@ -2059,7 +2078,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         file_path = os.path.join(self.tempdir, "path")
         file_path = os.path.join(self.tempdir, "path")
 
 
@@ -2078,7 +2098,8 @@ class TestUpdateWorkingTree(TestCase):
 
 
         # Update should fail because can't create directory where modified file exists
         # Update should fail because can't create directory where modified file exists
         with self.assertRaises(IOError):
         with self.assertRaises(IOError):
-            update_working_tree(self.repo, tree1.id, tree2.id)
+            changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+            update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # File should still exist with modifications
         # File should still exist with modifications
         self.assertTrue(os.path.isfile(file_path))
         self.assertTrue(os.path.isfile(file_path))
@@ -2101,7 +2122,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         script_path = os.path.join(self.tempdir, "script.sh")
         script_path = os.path.join(self.tempdir, "script.sh")
         self.assertTrue(os.path.isfile(script_path))
         self.assertTrue(os.path.isfile(script_path))
@@ -2116,7 +2138,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update to tree2
         # Update to tree2
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Check it's now executable
         # Check it's now executable
         mode = os.stat(script_path).st_mode
         mode = os.stat(script_path).st_mode
@@ -2133,7 +2156,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree with submodule
         # Update to tree with submodule
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Add untracked file to submodule directory
         # Add untracked file to submodule directory
         submodule_path = os.path.join(self.tempdir, "submodule")
         submodule_path = os.path.join(self.tempdir, "submodule")
@@ -2146,7 +2170,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should not remove submodule directory with untracked files
         # Update should not remove submodule directory with untracked files
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Directory should still exist with untracked file
         # Directory should still exist with untracked file
         self.assertTrue(os.path.isdir(submodule_path))
         self.assertTrue(os.path.isdir(submodule_path))
@@ -2169,7 +2194,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify structure exists
         # Verify structure exists
         dir_path = os.path.join(self.tempdir, "dir")
         dir_path = os.path.join(self.tempdir, "dir")
@@ -2191,7 +2217,8 @@ class TestUpdateWorkingTree(TestCase):
 
 
         # Update should fail because directory is not empty
         # Update should fail because directory is not empty
         with self.assertRaises(IsADirectoryError):
         with self.assertRaises(IsADirectoryError):
-            update_working_tree(self.repo, tree1.id, tree2.id)
+            changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+            update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Directory should still exist
         # Directory should still exist
         self.assertTrue(os.path.isdir(dir_path))
         self.assertTrue(os.path.isdir(dir_path))
@@ -2208,7 +2235,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Create tree with uppercase file (different content)
         # Create tree with uppercase file (different content)
         blob2 = Blob()
         blob2 = Blob()
@@ -2220,7 +2248,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update to tree2
         # Update to tree2
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Check what exists (behavior depends on filesystem)
         # Check what exists (behavior depends on filesystem)
         lowercase_path = os.path.join(self.tempdir, "readme.txt")
         lowercase_path = os.path.join(self.tempdir, "readme.txt")
@@ -2246,7 +2275,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Verify deep structure exists
         # Verify deep structure exists
         current_path = self.tempdir
         current_path = self.tempdir
@@ -2259,7 +2289,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should remove all empty directories
         # Update should remove all empty directories
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Verify top level directory is gone
         # Verify top level directory is gone
         top_level = os.path.join(self.tempdir, "level0")
         top_level = os.path.join(self.tempdir, "level0")
@@ -2277,7 +2308,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Make file read-only
         # Make file read-only
         file_path = os.path.join(self.tempdir, "readonly.txt")
         file_path = os.path.join(self.tempdir, "readonly.txt")
@@ -2293,7 +2325,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should handle read-only file
         # Update should handle read-only file
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         # Verify content was updated
         # Verify content was updated
         with open(file_path, "rb") as f:
         with open(file_path, "rb") as f:
@@ -2316,7 +2349,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree)
         self.repo.object_store.add_object(tree)
 
 
         # Update should skip invalid files based on validation
         # Update should skip invalid files based on validation
-        update_working_tree(self.repo, None, tree.id)
+        changes = tree_changes(self.repo.object_store, None, tree.id)
+        update_working_tree(self.repo, None, tree.id, change_iterator=changes)
 
 
         # Valid file should exist
         # Valid file should exist
         self.assertTrue(os.path.exists(os.path.join(self.tempdir, "valid.txt")))
         self.assertTrue(os.path.exists(os.path.join(self.tempdir, "valid.txt")))
@@ -2342,7 +2376,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         link_path = os.path.join(self.tempdir, "link")
         link_path = os.path.join(self.tempdir, "link")
         self.assertTrue(os.path.islink(link_path))
         self.assertTrue(os.path.islink(link_path))
@@ -2357,7 +2392,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
         # Update should replace symlink with actual directory
         # Update should replace symlink with actual directory
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
 
 
         self.assertFalse(os.path.islink(link_path))
         self.assertFalse(os.path.islink(link_path))
         self.assertTrue(os.path.isdir(link_path))
         self.assertTrue(os.path.isdir(link_path))
@@ -2393,10 +2429,12 @@ class TestUpdateWorkingTree(TestCase):
         tree2[b"item"] = (S_IFGITLINK, submodule_sha)
         tree2[b"item"] = (S_IFGITLINK, submodule_sha)
         self.repo.object_store.add_object(tree2)
         self.repo.object_store.add_object(tree2)
 
 
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
         self.assertTrue(os.path.isfile(os.path.join(self.tempdir, "item")))
         self.assertTrue(os.path.isfile(os.path.join(self.tempdir, "item")))
 
 
-        update_working_tree(self.repo, tree1.id, tree2.id)
+        changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+        update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
         self.assertTrue(os.path.isdir(os.path.join(self.tempdir, "item")))
         self.assertTrue(os.path.isdir(os.path.join(self.tempdir, "item")))
 
 
         # Test 2: Submodule → Executable file
         # Test 2: Submodule → Executable file
@@ -2404,7 +2442,8 @@ class TestUpdateWorkingTree(TestCase):
         tree3[b"item"] = (0o100755, exec_blob.id)
         tree3[b"item"] = (0o100755, exec_blob.id)
         self.repo.object_store.add_object(tree3)
         self.repo.object_store.add_object(tree3)
 
 
-        update_working_tree(self.repo, tree2.id, tree3.id)
+        changes = tree_changes(self.repo.object_store, tree2.id, tree3.id)
+        update_working_tree(self.repo, tree2.id, tree3.id, change_iterator=changes)
         item_path = os.path.join(self.tempdir, "item")
         item_path = os.path.join(self.tempdir, "item")
         self.assertTrue(os.path.isfile(item_path))
         self.assertTrue(os.path.isfile(item_path))
         if sys.platform != "win32":
         if sys.platform != "win32":
@@ -2415,7 +2454,8 @@ class TestUpdateWorkingTree(TestCase):
         tree4[b"item"] = (0o120000, link_blob.id)
         tree4[b"item"] = (0o120000, link_blob.id)
         self.repo.object_store.add_object(tree4)
         self.repo.object_store.add_object(tree4)
 
 
-        update_working_tree(self.repo, tree3.id, tree4.id)
+        changes = tree_changes(self.repo.object_store, tree3.id, tree4.id)
+        update_working_tree(self.repo, tree3.id, tree4.id, change_iterator=changes)
         self.assertTrue(os.path.islink(item_path))
         self.assertTrue(os.path.islink(item_path))
 
 
         # Test 4: Symlink → Submodule
         # Test 4: Symlink → Submodule
@@ -2423,14 +2463,16 @@ class TestUpdateWorkingTree(TestCase):
         tree5[b"item"] = (S_IFGITLINK, submodule_sha)
         tree5[b"item"] = (S_IFGITLINK, submodule_sha)
         self.repo.object_store.add_object(tree5)
         self.repo.object_store.add_object(tree5)
 
 
-        update_working_tree(self.repo, tree4.id, tree5.id)
+        changes = tree_changes(self.repo.object_store, tree4.id, tree5.id)
+        update_working_tree(self.repo, tree4.id, tree5.id, change_iterator=changes)
         self.assertTrue(os.path.isdir(item_path))
         self.assertTrue(os.path.isdir(item_path))
 
 
         # Test 5: Clean up - Submodule → absent
         # Test 5: Clean up - Submodule → absent
         tree6 = Tree()
         tree6 = Tree()
         self.repo.object_store.add_object(tree6)
         self.repo.object_store.add_object(tree6)
 
 
-        update_working_tree(self.repo, tree5.id, tree6.id)
+        changes = tree_changes(self.repo.object_store, tree5.id, tree6.id)
+        update_working_tree(self.repo, tree5.id, tree6.id, change_iterator=changes)
         self.assertFalse(os.path.exists(item_path))
         self.assertFalse(os.path.exists(item_path))
 
 
         # Test 6: Symlink → Executable file
         # Test 6: Symlink → Executable file
@@ -2438,7 +2480,8 @@ class TestUpdateWorkingTree(TestCase):
         tree7[b"item2"] = (0o120000, link_blob.id)
         tree7[b"item2"] = (0o120000, link_blob.id)
         self.repo.object_store.add_object(tree7)
         self.repo.object_store.add_object(tree7)
 
 
-        update_working_tree(self.repo, tree6.id, tree7.id)
+        changes = tree_changes(self.repo.object_store, tree6.id, tree7.id)
+        update_working_tree(self.repo, tree6.id, tree7.id, change_iterator=changes)
         item2_path = os.path.join(self.tempdir, "item2")
         item2_path = os.path.join(self.tempdir, "item2")
         self.assertTrue(os.path.islink(item2_path))
         self.assertTrue(os.path.islink(item2_path))
 
 
@@ -2446,7 +2489,8 @@ class TestUpdateWorkingTree(TestCase):
         tree8[b"item2"] = (0o100755, exec_blob.id)
         tree8[b"item2"] = (0o100755, exec_blob.id)
         self.repo.object_store.add_object(tree8)
         self.repo.object_store.add_object(tree8)
 
 
-        update_working_tree(self.repo, tree7.id, tree8.id)
+        changes = tree_changes(self.repo.object_store, tree7.id, tree8.id)
+        update_working_tree(self.repo, tree7.id, tree8.id, change_iterator=changes)
         self.assertTrue(os.path.isfile(item2_path))
         self.assertTrue(os.path.isfile(item2_path))
         if sys.platform != "win32":
         if sys.platform != "win32":
             self.assertTrue(os.access(item2_path, os.X_OK))
             self.assertTrue(os.access(item2_path, os.X_OK))
@@ -2468,7 +2512,8 @@ class TestUpdateWorkingTree(TestCase):
         self.repo.object_store.add_object(tree1)
         self.repo.object_store.add_object(tree1)
 
 
         # Update to tree1
         # Update to tree1
-        update_working_tree(self.repo, None, tree1.id)
+        changes = tree_changes(self.repo.object_store, None, tree1.id)
+        update_working_tree(self.repo, None, tree1.id, change_iterator=changes)
 
 
         # Create a directory where file2.txt is, to cause a conflict
         # Create a directory where file2.txt is, to cause a conflict
         file2_path = os.path.join(self.tempdir, "file2.txt")
         file2_path = os.path.join(self.tempdir, "file2.txt")
@@ -2494,7 +2539,8 @@ class TestUpdateWorkingTree(TestCase):
 
 
         # Update should partially succeed - file1 updated, file2 blocked
         # Update should partially succeed - file1 updated, file2 blocked
         try:
         try:
-            update_working_tree(self.repo, tree1.id, tree2.id)
+            changes = tree_changes(self.repo.object_store, tree1.id, tree2.id)
+            update_working_tree(self.repo, tree1.id, tree2.id, change_iterator=changes)
         except IsADirectoryError:
         except IsADirectoryError:
             # Expected to fail on file2 because it's a directory
             # Expected to fail on file2 because it's a directory
             pass
             pass

+ 213 - 13
tests/test_porcelain.py

@@ -3540,11 +3540,12 @@ class ResetTests(PorcelainTestCase):
         with open(file2, "w") as f:
         with open(file2, "w") as f:
             f.write("new content")
             f.write("new content")
 
 
-        # Reset to commit that has file2 removed - should delete untracked file2
+        # Reset to commit that has file2 removed - untracked file2 should remain
         porcelain.reset(self.repo, "hard", sha2)
         porcelain.reset(self.repo, "hard", sha2)
 
 
         self.assertTrue(os.path.exists(file1))
         self.assertTrue(os.path.exists(file1))
-        self.assertFalse(os.path.exists(file2))
+        # Untracked files are not removed by reset --hard
+        self.assertTrue(os.path.exists(file2))
 
 
     def test_hard_reset_to_remote_branch(self) -> None:
     def test_hard_reset_to_remote_branch(self) -> None:
         """Test reset --hard to remote branch deletes local files not in remote."""
         """Test reset --hard to remote branch deletes local files not in remote."""
@@ -4024,16 +4025,30 @@ class CheckoutTests(PorcelainTestCase):
             [{"add": [], "delete": [], "modify": [b"nee"]}, [], []], status
             [{"add": [], "delete": [], "modify": [b"nee"]}, [], []], status
         )
         )
 
 
-        # The new checkout behavior allows switching if the file doesn't exist in target branch
-        # (changes can be preserved)
-        porcelain.checkout(self.repo, b"uni")
-        self.assertEqual(b"uni", porcelain.active_branch(self.repo))
+        # Checkout should fail when there are staged changes that would be lost
+        # This matches Git's behavior to prevent data loss
+        from dulwich.errors import WorkingTreeModifiedError
+
+        with self.assertRaises(WorkingTreeModifiedError) as cm:
+            porcelain.checkout(self.repo, b"uni")
+
+        self.assertIn("nee", str(cm.exception))
+
+        # Should still be on master branch
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
 
 
-        # The staged changes are lost and the file is removed from working tree
-        # because it doesn't exist in the target branch
+        # The staged changes should still be present
         status = list(porcelain.status(self.repo))
         status = list(porcelain.status(self.repo))
-        # File 'nee' is gone completely
-        self.assertEqual([{"add": [], "delete": [], "modify": []}, [], []], status)
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": [b"nee"]}, [], []], status
+        )
+        self.assertTrue(os.path.exists(nee_path))
+
+        # Force checkout should work and lose the changes
+        porcelain.checkout(self.repo, b"uni", force=True)
+        self.assertEqual(b"uni", porcelain.active_branch(self.repo))
+
+        # Now the file should be gone
         self.assertFalse(os.path.exists(nee_path))
         self.assertFalse(os.path.exists(nee_path))
 
 
     def test_checkout_to_branch_with_modified_file_not_present_forced(self) -> None:
     def test_checkout_to_branch_with_modified_file_not_present_forced(self) -> None:
@@ -4454,7 +4469,7 @@ class GeneralCheckoutTests(PorcelainTestCase):
         self.assertEqual(b"master", porcelain.active_branch(self.repo))
         self.assertEqual(b"master", porcelain.active_branch(self.repo))
 
 
     def test_checkout_force(self) -> None:
     def test_checkout_force(self) -> None:
-        """Test forced checkout discards local changes."""
+        """Test forced checkout discards local changes for files that differ between branches."""
         # Modify a file
         # Modify a file
         with open(self._foo_path, "w") as f:
         with open(self._foo_path, "w") as f:
             f.write("modified content\n")
             f.write("modified content\n")
@@ -4464,10 +4479,11 @@ class GeneralCheckoutTests(PorcelainTestCase):
 
 
         self.assertEqual(b"feature", porcelain.active_branch(self.repo))
         self.assertEqual(b"feature", porcelain.active_branch(self.repo))
 
 
-        # Local changes should be discarded
+        # Since foo has the same content in master and feature branches,
+        # checkout should NOT restore it - the modified content should remain
         with open(self._foo_path) as f:
         with open(self._foo_path) as f:
             content = f.read()
             content = f.read()
-        self.assertEqual("initial content\n", content)
+        self.assertEqual("modified content\n", content)
 
 
     def test_checkout_nonexistent_ref(self) -> None:
     def test_checkout_nonexistent_ref(self) -> None:
         """Test checkout of non-existent branch/commit."""
         """Test checkout of non-existent branch/commit."""
@@ -5294,6 +5310,190 @@ class PullTests(PorcelainTestCase):
         with Repo(self.target_path) as r:
         with Repo(self.target_path) as r:
             self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
             self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
 
 
+    def test_pull_protects_modified_files(self) -> None:
+        """Test that pull refuses to overwrite uncommitted changes by default."""
+        from dulwich.errors import WorkingTreeModifiedError
+
+        outstream = BytesIO()
+        errstream = BytesIO()
+
+        # Create a file with content in the source repo
+        test_file = os.path.join(self.repo.path, "testfile.txt")
+        with open(test_file, "w") as f:
+            f.write("original content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Add test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Pull this change to target first
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+        )
+
+        # Now modify the file in source repo
+        with open(test_file, "w") as f:
+            f.write("updated content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Update test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Modify the same file in target working directory (uncommitted)
+        target_file = os.path.join(self.target_path, "testfile.txt")
+        with open(target_file, "w") as f:
+            f.write("local modifications")
+
+        # Pull should fail because of uncommitted changes
+        with self.assertRaises(WorkingTreeModifiedError) as cm:
+            porcelain.pull(
+                self.target_path,
+                self.repo.path,
+                b"refs/heads/master",
+                outstream=outstream,
+                errstream=errstream,
+            )
+
+        self.assertIn("Your local changes", str(cm.exception))
+        self.assertIn("testfile.txt", str(cm.exception))
+
+        # Verify the file still has local modifications
+        with open(target_file) as f:
+            self.assertEqual(f.read(), "local modifications")
+
+    def test_pull_force_overwrites_modified_files(self) -> None:
+        """Test that pull with force=True overwrites uncommitted changes."""
+        outstream = BytesIO()
+        errstream = BytesIO()
+
+        # Create a file with content in the source repo
+        test_file = os.path.join(self.repo.path, "testfile.txt")
+        with open(test_file, "w") as f:
+            f.write("original content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Add test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Pull this change to target first
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+        )
+
+        # Now modify the file in source repo
+        with open(test_file, "w") as f:
+            f.write("updated content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Update test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Modify the same file in target working directory (uncommitted)
+        target_file = os.path.join(self.target_path, "testfile.txt")
+        with open(target_file, "w") as f:
+            f.write("local modifications")
+
+        # Pull with force=True should succeed and overwrite local changes
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+            force=True,
+        )
+
+        # Verify the file now has the remote content
+        with open(target_file) as f:
+            self.assertEqual(f.read(), "updated content")
+
+        # Check the HEAD is updated too
+        with Repo(self.target_path) as r:
+            self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
+
+    def test_pull_allows_unmodified_files(self) -> None:
+        """Test that pull allows updating files that haven't been modified locally."""
+        outstream = BytesIO()
+        errstream = BytesIO()
+
+        # Create a file with content in the source repo
+        test_file = os.path.join(self.repo.path, "testfile.txt")
+        with open(test_file, "w") as f:
+            f.write("original content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Add test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Pull this change to target first
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+        )
+
+        # Now modify the file in source repo
+        with open(test_file, "w") as f:
+            f.write("updated content")
+
+        porcelain.add(repo=self.repo.path, paths=[test_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Update test file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Don't modify the file in target - it should be safe to update
+        target_file = os.path.join(self.target_path, "testfile.txt")
+
+        # Pull should succeed since the file wasn't modified locally
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+        )
+
+        # Verify the file now has the remote content
+        with open(target_file) as f:
+            self.assertEqual(f.read(), "updated content")
+
+        # Check the HEAD is updated too
+        with Repo(self.target_path) as r:
+            self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
+
 
 
 class StatusTests(PorcelainTestCase):
 class StatusTests(PorcelainTestCase):
     def test_empty(self) -> None:
     def test_empty(self) -> None: