浏览代码

Update working tree in pull. Fixes #452

Jelmer Vernooij 2 月之前
父节点
当前提交
84b5ecbd44
共有 6 个文件被更改,包括 392 次插入和 111 次删除
  1. 2 0
      NEWS
  2. 151 0
      dulwich/index.py
  3. 185 97
      dulwich/porcelain.py
  4. 1 3
      tests/test_cli_merge.py
  5. 51 5
      tests/test_porcelain.py
  6. 2 6
      tests/test_porcelain_merge.py

+ 2 - 0
NEWS

@@ -24,6 +24,8 @@
 
  * Add basic merge command. (Jelmer Vernooij)
 
+ * Update working tree in pull. (Jelmer Vernooij, #452)
+
 0.22.8	2025-03-02
 
  * Allow passing in plain strings to ``dulwich.porcelain.tag_create``

+ 151 - 0
dulwich/index.py

@@ -1007,6 +1007,157 @@ def _has_directory_changed(tree_path: bytes, entry) -> bool:
     return False
 
 
+def update_working_tree(
+    repo,
+    old_tree_id,
+    new_tree_id,
+    honor_filemode=True,
+    validate_path_element=None,
+    symlink_fn=None,
+    force_remove_untracked=False,
+):
+    """Update the working tree and index to match a new tree.
+
+    This function handles:
+    - Writing every file of the new tree to disk and into the index
+    - Removing files that were in the old tree but not in the new one
+    - Removing all other files when force_remove_untracked is set
+    - Cleaning up directories left empty by those removals
+
+    Args:
+      repo: Repository object
+      old_tree_id: SHA of the tree before the update, or None if there is none
+      new_tree_id: SHA of the tree to update to
+      honor_filemode: An optional flag to honor core.filemode setting
+      validate_path_element: Function to validate path elements to check out
+      symlink_fn: Function to use for creating symlinks
+      force_remove_untracked: If True, remove files that exist in working
+        directory but not in target tree, even if old_tree_id is None
+    """
+    import os
+
+    # Set default validate_path_element if not provided
+    if validate_path_element is None:
+        validate_path_element = validate_path_element_default
+
+    # Look up the trees; the new tree is only fetched to check that it exists
+    old_tree = repo[old_tree_id] if old_tree_id else None
+    repo[new_tree_id]
+
+    # Open the index
+    index = repo.open_index()
+
+    # Every path of the new tree, including ones skipped below
+    handled_paths = set()
+
+    # Get repo path as string for comparisons
+    repo_path_str = repo.path if isinstance(repo.path, str) else repo.path.decode()
+
+    # First, update/add all files in the new tree
+    for entry in iter_tree_contents(repo.object_store, new_tree_id):
+        handled_paths.add(entry.path)
+
+        # Skip the .git directory itself, but not files such as .gitignore
+        if entry.path == b".git" or entry.path.startswith(b".git/"):
+            continue
+
+        # Validate path element
+        if not validate_path(entry.path, validate_path_element):
+            continue
+
+        # Build full path
+        full_path = os.path.join(repo_path_str, entry.path.decode())
+
+        # Get the blob
+        blob = repo.object_store[entry.sha]
+
+        # Ensure parent directory exists
+        parent_dir = os.path.dirname(full_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+
+        # Write the file
+        st = build_file_from_blob(
+            blob,
+            entry.mode,
+            full_path.encode(),
+            honor_filemode=honor_filemode,
+            symlink_fn=symlink_fn,
+        )
+
+        # Update index
+        index[entry.path] = index_entry_from_stat(st, entry.sha)
+
+    # Remove files that existed in old tree but not in new tree
+    if old_tree is not None:
+        for entry in iter_tree_contents(repo.object_store, old_tree_id):
+            if entry.path not in handled_paths:
+                # Skip the .git directory itself, but not files such as .gitignore
+                if entry.path == b".git" or entry.path.startswith(b".git/"):
+                    continue
+
+                # File was deleted
+                full_path = os.path.join(repo_path_str, entry.path.decode())
+
+                # Remove from working tree (lexists also sees dangling symlinks)
+                if os.path.lexists(full_path):
+                    os.remove(full_path)
+
+                # Remove from index
+                if entry.path in index:
+                    del index[entry.path]
+
+                # Walk upwards removing directories the deletion left empty
+                dir_path = os.path.dirname(full_path)
+                while (
+                    dir_path and dir_path != repo_path_str and os.path.exists(dir_path)
+                ):
+                    try:
+                        if not os.listdir(dir_path):
+                            os.rmdir(dir_path)
+                            dir_path = os.path.dirname(dir_path)
+                        else:
+                            break
+                    except OSError:
+                        break
+
+    # If force_remove_untracked is True, remove any files in working directory
+    # that are not in the target tree (useful for reset --hard)
+    if force_remove_untracked:
+        # Walk through all files in the working directory
+        for root, dirs, files in os.walk(repo_path_str):
+            # Do not descend into any .git directory
+            if ".git" in dirs:
+                dirs.remove(".git")
+
+            for file in files:
+                full_path = os.path.join(root, file)
+                # Tree paths always use "/", so normalize the OS separator
+                rel_path = os.path.relpath(full_path, repo_path_str)
+                rel_path_bytes = rel_path.replace(os.sep, "/").encode()
+
+                # Remove files missing from the target tree (never a .git gitfile)
+                if file != ".git" and rel_path_bytes not in handled_paths:
+                    os.remove(full_path)
+
+                    # Remove from index if present
+                    if rel_path_bytes in index:
+                        del index[rel_path_bytes]
+
+        # Prune empty directories bottom-up; rmdir refuses non-empty ones
+        for root, _dirs, _files in os.walk(repo_path_str, topdown=False):
+            if ".git" in os.path.relpath(root, repo_path_str).split(os.sep):
+                continue
+            if root != repo_path_str:
+                try:
+                    os.rmdir(root)
+                except OSError:
+                    pass
+
+    # Write the updated index
+    index.write()
+
+
 def get_unstaged_changes(
     index: Index, root_path: Union[str, bytes], filter_blob_callback=None
 ):

+ 185 - 97
dulwich/porcelain.py

@@ -102,6 +102,7 @@ from .index import (
     build_file_from_blob,
     get_unstaged_changes,
     index_entry_from_stat,
+    update_working_tree,
 )
 from .object_store import iter_tree_contents, tree_lookup_path
 from .objects import (
@@ -1165,7 +1166,48 @@ def reset(repo, mode, treeish="HEAD") -> None:
 
     with open_repo_closing(repo) as r:
         tree = parse_tree(r, treeish)
-        r.reset_index(tree.id)
+
+        # Get current HEAD tree for comparison
+        try:
+            current_head = r.refs[b"HEAD"]
+            current_tree = r[current_head].tree
+        except KeyError:
+            current_tree = None
+
+        # Get configuration for working directory update
+        config = r.get_config()
+        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+
+        # Import validation functions
+        from .index import validate_path_element_default, validate_path_element_ntfs
+
+        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
+            validate_path_element = validate_path_element_ntfs
+        else:
+            validate_path_element = validate_path_element_default
+
+        # Import symlink function
+        from .index import symlink
+
+        if config.get_boolean(b"core", b"symlinks", True):
+            symlink_fn = symlink
+        else:
+
+            def symlink_fn(source, target) -> None:
+                mode = "w" + ("b" if isinstance(source, bytes) else "")
+                with open(target, mode) as f:
+                    f.write(source)
+
+        # Update working tree and index
+        update_working_tree(
+            r,
+            current_tree,
+            tree.id,
+            honor_filemode=honor_filemode,
+            validate_path_element=validate_path_element,
+            symlink_fn=symlink_fn,
+            force_remove_untracked=True,
+        )
 
 
 def get_remote_repo(
@@ -1281,6 +1323,7 @@ def pull(
     outstream=default_bytes_out_stream,
     errstream=default_bytes_err_stream,
     fast_forward=True,
+    ff_only=False,
     force=False,
     filter_spec=None,
     protocol_version=None,
@@ -1295,6 +1338,9 @@ def pull(
         bytestring/string.
       outstream: A stream file to write to output
       errstream: A stream file to write to errors
+      fast_forward: If True, raise an exception when fast-forward is not possible
+      ff_only: If True, only allow fast-forward merges. Raises DivergedBranches
+        when branches have diverged rather than performing a merge.
       filter_spec: A git-rev-list-style object filter spec, as an ASCII string.
         Only used if the server supports the Git protocol-v2 'filter'
         feature, and ignored otherwise.
@@ -1333,22 +1379,43 @@ def pull(
             filter_spec=filter_spec,
             protocol_version=protocol_version,
         )
+
+        # Store the old HEAD tree before making changes
+        try:
+            old_head = r.refs[b"HEAD"]
+            old_tree_id = r[old_head].tree
+        except KeyError:
+            old_tree_id = None
+
+        merged = False
         for lh, rh, force_ref in selected_refs:
             if not force_ref and rh in r.refs:
                 try:
                     check_diverged(r, r.refs.follow(rh)[1], fetch_result.refs[lh])
                 except DivergedBranches as exc:
-                    if fast_forward:
+                    if ff_only or fast_forward:
                         raise
                     else:
-                        raise NotImplementedError("merge is not yet supported") from exc
+                        # Perform merge
+                        merge_result, conflicts = _do_merge(r, fetch_result.refs[lh])
+                        if conflicts:
+                            raise Error(
+                                f"Merge conflicts occurred: {conflicts}"
+                            ) from exc
+                        merged = True
+                        # Skip updating ref since merge already updated HEAD
+                        continue
             r.refs[rh] = fetch_result.refs[lh]
-        if selected_refs:
+
+        # Only update HEAD if we didn't perform a merge
+        if selected_refs and not merged:
             r[b"HEAD"] = fetch_result.refs[selected_refs[0][1]]
 
-        # Perform 'git checkout .' - syncs staged changes
-        tree = r[b"HEAD"].tree
-        r.reset_index(tree=tree)
+        # Update working tree to match the new HEAD
+        # Skip if merge was performed as merge already updates the working tree
+        if not merged and old_tree_id is not None:
+            new_tree_id = r[b"HEAD"].tree
+            update_working_tree(r, old_tree_id, new_tree_id)
         if remote_name is not None:
             _import_remote_refs(r.refs, remote_name, fetch_result.refs)
 
@@ -2447,6 +2514,115 @@ def write_tree(repo):
         return r.open_index().commit(r.object_store)
 
 
+def _do_merge(
+    r,
+    merge_commit_id,
+    no_commit=False,
+    no_ff=False,
+    message=None,
+    author=None,
+    committer=None,
+):
+    """Internal merge implementation that operates on an open repository.
+
+    Args:
+      r: Open repository object
+      merge_commit_id: SHA of commit to merge
+      no_commit: If True, do not create a merge commit
+      no_ff: If True, force creation of a merge commit
+      message: Optional merge commit message
+      author: Optional author for merge commit
+      committer: Optional committer for merge commit
+
+    Returns:
+      Tuple of (merge_commit_sha, conflicts); merge_commit_sha is None when
+      already up to date, when no_commit=True, or when there were conflicts
+    """
+    from .graph import find_merge_base
+    from .merge import three_way_merge
+
+    # Get HEAD commit; a missing HEAD is reported as Error
+    try:
+        head_commit_id = r.refs[b"HEAD"]
+    except KeyError as exc:
+        raise Error("No HEAD reference found") from exc
+
+    head_commit = r[head_commit_id]
+    merge_commit = r[merge_commit_id]
+
+    # Find the common ancestor(s) of HEAD and the commit being merged
+    merge_bases = find_merge_base(r, [head_commit_id, merge_commit_id])
+
+    if not merge_bases:
+        raise Error("No common ancestor found")
+
+    # Use the first merge base
+    base_commit_id = merge_bases[0]
+
+    # HEAD is itself the merge base: the other commit is strictly ahead
+    if base_commit_id == head_commit_id and not no_ff:
+        # Fast-forward: check out the new tree first, then move HEAD, so a
+        # failed checkout leaves HEAD on the old commit (as in the merge path)
+        update_working_tree(r, head_commit.tree, merge_commit.tree)
+        r.refs[b"HEAD"] = merge_commit_id
+        return (merge_commit_id, [])
+
+    if base_commit_id == merge_commit_id:
+        # Already up to date
+        return (None, [])
+
+    # Perform three-way merge
+    base_commit = r[base_commit_id]
+    merged_tree, conflicts = three_way_merge(
+        r.object_store, base_commit, head_commit, merge_commit
+    )
+
+    # Add merged tree to object store
+    r.object_store.add_object(merged_tree)
+
+    # Update index and working directory (done even when there are conflicts)
+    update_working_tree(r, head_commit.tree, merged_tree.id)
+
+    if conflicts or no_commit:
+        # Don't create a commit if there are conflicts or no_commit is True
+        return (None, conflicts)
+
+    # Create merge commit
+    merge_commit_obj = Commit()
+    merge_commit_obj.tree = merged_tree.id
+    merge_commit_obj.parents = [head_commit_id, merge_commit_id]
+
+    # Set author/committer; the committer defaults to the author
+    if author is None:
+        author = get_user_identity(r.get_config_stack())
+    if committer is None:
+        committer = author
+
+    merge_commit_obj.author = author
+    merge_commit_obj.committer = committer
+
+    # Set timestamps
+    timestamp = int(time.time())
+    timezone = 0  # UTC
+    merge_commit_obj.author_time = timestamp
+    merge_commit_obj.author_timezone = timezone
+    merge_commit_obj.commit_time = timestamp
+    merge_commit_obj.commit_timezone = timezone
+
+    # Set commit message; str messages are encoded, bytes are used as-is
+    if message is None:
+        message = f"Merge commit '{merge_commit_id.decode()[:7]}'\n"
+    merge_commit_obj.message = message.encode() if isinstance(message, str) else message
+
+    # Add commit to object store
+    r.object_store.add_object(merge_commit_obj)
+
+    # Update HEAD
+    r.refs[b"HEAD"] = merge_commit_obj.id
+
+    return (merge_commit_obj.id, [])
+
+
 def merge(
     repo,
     committish,
@@ -2474,101 +2650,13 @@ def merge(
     Raises:
       Error: If there is no HEAD reference or commit cannot be found
     """
-    from .graph import find_merge_base
-    from .index import build_index_from_tree
-    from .merge import three_way_merge
-
     with open_repo_closing(repo) as r:
-        # Get HEAD commit
-        try:
-            head_commit_id = r.refs[b"HEAD"]
-        except KeyError:
-            raise Error("No HEAD reference found")
-
         # Parse the commit to merge
         try:
             merge_commit_id = parse_commit(r, committish)
         except KeyError:
             raise Error(f"Cannot find commit '{committish}'")
 
-        head_commit = r[head_commit_id]
-        merge_commit = r[merge_commit_id]
-
-        # Check if fast-forward is possible
-        merge_bases = find_merge_base(r, [head_commit_id, merge_commit_id])
-
-        if not merge_bases:
-            raise Error("No common ancestor found")
-
-        # Use the first merge base
-        base_commit_id = merge_bases[0]
-
-        # Check for fast-forward
-        if base_commit_id == head_commit_id and not no_ff:
-            # Fast-forward merge
-            r.refs[b"HEAD"] = merge_commit_id
-            # Update the working directory
-            index = r.open_index()
-            tree = r[merge_commit.tree]
-            build_index_from_tree(r.path, index, r.object_store, tree.id)
-            index.write()
-            return (merge_commit_id, [])
-
-        if base_commit_id == merge_commit_id:
-            # Already up to date
-            return (None, [])
-
-        # Perform three-way merge
-        base_commit = r[base_commit_id]
-        merged_tree, conflicts = three_way_merge(
-            r.object_store, base_commit, head_commit, merge_commit
+        return _do_merge(
+            r, merge_commit_id, no_commit, no_ff, message, author, committer
         )
-
-        # Add merged tree to object store
-        r.object_store.add_object(merged_tree)
-
-        # Update index and working directory
-        index = r.open_index()
-        build_index_from_tree(r.path, index, r.object_store, merged_tree.id)
-        index.write()
-
-        if conflicts or no_commit:
-            # Don't create a commit if there are conflicts or no_commit is True
-            return (None, conflicts)
-
-        # Create merge commit
-        merge_commit_obj = Commit()
-        merge_commit_obj.tree = merged_tree.id
-        merge_commit_obj.parents = [head_commit_id, merge_commit_id]
-
-        # Set author/committer
-        if author is None:
-            author = get_user_identity(r.get_config_stack())
-        if committer is None:
-            committer = author
-
-        merge_commit_obj.author = author
-        merge_commit_obj.committer = committer
-
-        # Set timestamps
-        timestamp = int(time())
-        timezone = 0  # UTC
-        merge_commit_obj.author_time = timestamp
-        merge_commit_obj.author_timezone = timezone
-        merge_commit_obj.commit_time = timestamp
-        merge_commit_obj.commit_timezone = timezone
-
-        # Set commit message
-        if message is None:
-            message = f"Merge commit '{merge_commit_id.decode()[:7]}'\n"
-        merge_commit_obj.message = (
-            message.encode() if isinstance(message, str) else message
-        )
-
-        # Add commit to object store
-        r.object_store.add_object(merge_commit_obj)
-
-        # Update HEAD
-        r.refs[b"HEAD"] = merge_commit_obj.id
-
-        return (merge_commit_obj.id, [])

+ 1 - 3
tests/test_cli_merge.py

@@ -94,9 +94,7 @@ class CLIMergeTests(TestCase):
             with open(os.path.join(tmpdir, "file1.txt"), "w") as f:
                 f.write("Feature content\n")
             porcelain.add(tmpdir, paths=["file1.txt"])
-            porcelain.commit(
-                tmpdir, message=b"Modify file1 in feature"
-            )
+            porcelain.commit(tmpdir, message=b"Modify file1 in feature")
 
             # Go back to master and modify file1 differently
             porcelain.checkout_branch(tmpdir, "master")

+ 51 - 5
tests/test_porcelain.py

@@ -2476,9 +2476,8 @@ class PullTests(PorcelainTestCase):
         with Repo(self.target_path) as r:
             self.assertEqual(r[b"refs/heads/master"].id, c3a)
 
-        self.assertRaises(
-            NotImplementedError,
-            porcelain.pull,
+        # Pull with merge should now work
+        porcelain.pull(
             self.target_path,
             self.repo.path,
             b"refs/heads/master",
@@ -2487,9 +2486,16 @@ class PullTests(PorcelainTestCase):
             fast_forward=False,
         )
 
-        # Check the target repo for pushed changes
+        # Check the target repo for merged changes
         with Repo(self.target_path) as r:
-            self.assertEqual(r[b"refs/heads/master"].id, c3a)
+            # HEAD should now be a merge commit
+            head = r[b"HEAD"]
+            # It should have two parents
+            self.assertEqual(len(head.parents), 2)
+            # One parent should be the previous HEAD (c3a)
+            self.assertIn(c3a, head.parents)
+            # The other parent should be from the source repo
+            self.assertIn(self.repo[b"HEAD"].id, head.parents)
 
     def test_no_refspec(self) -> None:
         outstream = BytesIO()
@@ -2523,6 +2529,46 @@ class PullTests(PorcelainTestCase):
         with Repo(self.target_path) as r:
             self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
 
+    def test_pull_updates_working_tree(self) -> None:
+        """Pulling a commit that adds a file writes that file in the target."""
+        outstream = BytesIO()
+        errstream = BytesIO()
+
+        # Write a brand-new file into the source repo's working directory
+        new_file = os.path.join(self.repo.path, "newfile.txt")
+        with open(new_file, "w") as f:
+            f.write("This is new content")
+
+        porcelain.add(repo=self.repo.path, paths=[new_file])
+        porcelain.commit(
+            repo=self.repo.path,
+            message=b"Add new file",
+            author=b"test <email>",
+            committer=b"test <email>",
+        )
+
+        # Precondition: the target working tree does not have the file yet
+        target_file = os.path.join(self.target_path, "newfile.txt")
+        self.assertFalse(os.path.exists(target_file))
+
+        # Pull the source repo's master branch into the target repo
+        porcelain.pull(
+            self.target_path,
+            self.repo.path,
+            b"refs/heads/master",
+            outstream=outstream,
+            errstream=errstream,
+        )
+
+        # The pull must have written the file, with the committed content
+        self.assertTrue(os.path.exists(target_file))
+        with open(target_file) as f:
+            self.assertEqual(f.read(), "This is new content")
+
+        # The target's HEAD must now be the source repo's HEAD commit
+        with Repo(self.target_path) as r:
+            self.assertEqual(r[b"HEAD"].id, self.repo[b"HEAD"].id)
+
 
 class StatusTests(PorcelainTestCase):
     def test_empty(self) -> None:

+ 2 - 6
tests/test_porcelain_merge.py

@@ -143,9 +143,7 @@ class PorcelainMergeTests(TestCase):
             with open(os.path.join(tmpdir, "file1.txt"), "w") as f:
                 f.write("Feature content\n")
             porcelain.add(tmpdir, paths=["file1.txt"])
-            porcelain.commit(
-                tmpdir, message=b"Modify file1 in feature"
-            )
+            porcelain.commit(tmpdir, message=b"Modify file1 in feature")
 
             # Go back to master and modify file2
             porcelain.checkout_branch(tmpdir, "master")
@@ -184,9 +182,7 @@ class PorcelainMergeTests(TestCase):
             with open(os.path.join(tmpdir, "file1.txt"), "w") as f:
                 f.write("Feature content\n")
             porcelain.add(tmpdir, paths=["file1.txt"])
-            porcelain.commit(
-                tmpdir, message=b"Modify file1 in feature"
-            )
+            porcelain.commit(tmpdir, message=b"Modify file1 in feature")
 
             # Go back to master and modify file1 differently
             porcelain.checkout_branch(tmpdir, "master")