فهرست منبع

Refactor checkout/switch functions to reduce code duplication

- _get_current_head_tree()
- _get_worktree_update_config()
- _perform_tree_switch()
Jelmer Vernooij 2 ماه پیش
والد
کامیت
fb04d15ce5
1فایلهای تغییر یافته به همراه112 افزوده شده و 119 حذف شده
  1. 112 119
      dulwich/porcelain.py

+ 112 - 119
dulwich/porcelain.py

@@ -5082,6 +5082,26 @@ def check_ignore(
                 yield _quote_path(output_path) if quote_path else output_path
 
 
+def _get_current_head_tree(repo: Repo) -> bytes | None:
+    """Get the current HEAD tree ID.
+
+    Args:
+      repo: Repository object
+
+    Returns:
+      Tree ID of current HEAD, or None if no HEAD exists (empty repo)
+    """
+    try:
+        current_head = repo.refs[b"HEAD"]
+        current_commit = repo[current_head]
+        assert isinstance(current_commit, Commit), "Expected a Commit object"
+        tree_id: bytes = current_commit.tree
+        return tree_id
+    except KeyError:
+        # No HEAD yet (empty repo)
+        return None
+
+
 def _check_uncommitted_changes(
     repo: Repo, target_tree_id: bytes, force: bool = False
 ) -> None:
@@ -5099,11 +5119,8 @@ def _check_uncommitted_changes(
         return
 
     # Get current HEAD tree for comparison
-    try:
-        current_head = repo.refs[b"HEAD"]
-        current_commit = repo[current_head]
-        assert isinstance(current_commit, Commit), "Expected a Commit object"
-    except KeyError:
+    current_tree_id = _get_current_head_tree(repo)
+    if current_tree_id is None:
         # No HEAD yet (empty repo)
         return
 
@@ -5139,6 +5156,92 @@ def _check_uncommitted_changes(
                 )
 
 
+def _get_worktree_update_config(
+    repo: Repo,
+) -> tuple[
+    bool,
+    Callable[[bytes], bool],
+    Callable[[str | bytes | os.PathLike[str], str | bytes | os.PathLike[str]], None],
+]:
+    """Get configuration for working tree updates.
+
+    Args:
+      repo: Repository object
+
+    Returns:
+      Tuple of (honor_filemode, validate_path_element, symlink_fn)
+    """
+    config = repo.get_config()
+    honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+
+    if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
+        validate_path_element = validate_path_element_ntfs
+    else:
+        validate_path_element = validate_path_element_default
+
+    if config.get_boolean(b"core", b"symlinks", True):
+
+        def symlink_wrapper(
+            source: str | bytes | os.PathLike[str],
+            target: str | bytes | os.PathLike[str],
+        ) -> None:
+            symlink(source, target)  # type: ignore[arg-type,unused-ignore]
+
+        symlink_fn = symlink_wrapper
+    else:
+
+        def symlink_fallback(
+            source: str | bytes | os.PathLike[str],
+            target: str | bytes | os.PathLike[str],
+        ) -> None:
+            mode = "w" + ("b" if isinstance(source, bytes) else "")
+            with open(target, mode) as f:
+                f.write(source)
+
+        symlink_fn = symlink_fallback
+
+    return honor_filemode, validate_path_element, symlink_fn
+
+
+def _perform_tree_switch(
+    repo: Repo,
+    current_tree_id: bytes | None,
+    target_tree_id: bytes,
+    force: bool = False,
+) -> None:
+    """Perform the actual working tree switch.
+
+    Args:
+      repo: Repository object
+      current_tree_id: Current tree ID (or None for empty repo)
+      target_tree_id: Target tree ID to switch to
+      force: If True, force removal of untracked files and allow overwriting modified files
+    """
+    honor_filemode, validate_path_element, symlink_fn = _get_worktree_update_config(
+        repo
+    )
+
+    # Get blob normalizer for line ending conversion
+    blob_normalizer = repo.get_blob_normalizer()
+
+    # Update working tree
+    tree_change_iterator: Iterator[TreeChange] = tree_changes(
+        repo.object_store, current_tree_id, target_tree_id
+    )
+    update_working_tree(
+        repo,
+        current_tree_id,
+        target_tree_id,
+        change_iterator=tree_change_iterator,
+        honor_filemode=honor_filemode,
+        validate_path_element=validate_path_element,
+        symlink_fn=symlink_fn,
+        force_remove_untracked=force,
+        blob_normalizer=blob_normalizer,
+        allow_overwrite_modified=force,
+    )
+
+
 def update_head(
     repo: RepoPath,
     target: str | bytes,
@@ -5293,68 +5396,14 @@ def checkout(
         target_tree_id = target_commit.tree
 
         # Get current HEAD tree for comparison
-        try:
-            current_head = r.refs[b"HEAD"]
-            current_commit = r[current_head]
-            assert isinstance(current_commit, Commit), "Expected a Commit object"
-            current_tree_id = current_commit.tree
-        except KeyError:
-            # No HEAD yet (empty repo)
-            current_tree_id = None
+        current_tree_id = _get_current_head_tree(r)
 
         # Check for uncommitted changes if not forcing
         if current_tree_id is not None:
             _check_uncommitted_changes(r, target_tree_id, force)
 
-        # Get configuration for working directory update
-        config = r.get_config()
-        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
-
-        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
-            validate_path_element = validate_path_element_ntfs
-        else:
-            validate_path_element = validate_path_element_default
-
-        if config.get_boolean(b"core", b"symlinks", True):
-
-            def symlink_wrapper(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                symlink(source, target)  # type: ignore[arg-type,unused-ignore]
-
-            symlink_fn = symlink_wrapper
-        else:
-
-            def symlink_fallback(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
-
-            symlink_fn = symlink_fallback
-
-        # Get blob normalizer for line ending conversion
-        blob_normalizer = r.get_blob_normalizer()
-
         # Update working tree
-        tree_change_iterator: Iterator[TreeChange] = tree_changes(
-            r.object_store, current_tree_id, target_tree_id
-        )
-        update_working_tree(
-            r,
-            current_tree_id,
-            target_tree_id,
-            change_iterator=tree_change_iterator,
-            honor_filemode=honor_filemode,
-            validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
-            force_remove_untracked=force,
-            blob_normalizer=blob_normalizer,
-            allow_overwrite_modified=force,
-        )
+        _perform_tree_switch(r, current_tree_id, target_tree_id, force)
 
         # Update HEAD
         if new_branch:
@@ -5433,8 +5482,6 @@ def restore(
     with open_repo_closing(repo) as r:
         from .index import _fs_to_tree_path, build_file_from_blob
 
-        wt = r.get_worktree()
-
         # Determine the source tree
         if source is None:
             if staged:
@@ -5586,68 +5633,14 @@ def switch(
         target_tree_id = target_commit.tree
 
         # Get current HEAD tree for comparison
-        try:
-            current_head = r.refs[b"HEAD"]
-            current_commit = r[current_head]
-            assert isinstance(current_commit, Commit), "Expected a Commit object"
-            current_tree_id = current_commit.tree
-        except KeyError:
-            # No HEAD yet (empty repo)
-            current_tree_id = None
+        current_tree_id = _get_current_head_tree(r)
 
         # Check for uncommitted changes if not forcing
         if current_tree_id is not None:
             _check_uncommitted_changes(r, target_tree_id, force)
 
-        # Get configuration for working directory update
-        config = r.get_config()
-        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
-
-        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
-            validate_path_element = validate_path_element_ntfs
-        else:
-            validate_path_element = validate_path_element_default
-
-        if config.get_boolean(b"core", b"symlinks", True):
-
-            def symlink_wrapper(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                symlink(source, target)  # type: ignore[arg-type,unused-ignore]
-
-            symlink_fn = symlink_wrapper
-        else:
-
-            def symlink_fallback(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
-
-            symlink_fn = symlink_fallback
-
-        # Get blob normalizer for line ending conversion
-        blob_normalizer = r.get_blob_normalizer()
-
         # Update working tree
-        tree_change_iterator: Iterator[TreeChange] = tree_changes(
-            r.object_store, current_tree_id, target_tree_id
-        )
-        update_working_tree(
-            r,
-            current_tree_id,
-            target_tree_id,
-            change_iterator=tree_change_iterator,
-            honor_filemode=honor_filemode,
-            validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
-            force_remove_untracked=force,
-            blob_normalizer=blob_normalizer,
-            allow_overwrite_modified=force,
-        )
+        _perform_tree_switch(r, current_tree_id, target_tree_id, force)
 
         # Update HEAD
         if create: