Explorar o código

Add -a argument to commit (#1701)

Jelmer Vernooij hai 1 mes
pai
achega
91b8e74651
Modificáronse 5 ficheiros con 226 adicións e 41 borrados
  1. 3 0
      NEWS
  2. 7 1
      dulwich/cli.py
  3. 52 40
      dulwich/porcelain.py
  4. 86 0
      tests/test_cli.py
  5. 78 0
      tests/test_porcelain.py

+ 3 - 0
NEWS

@@ -1,5 +1,8 @@
 0.23.3	UNRELEASED
 
+ * Add support for ``-a`` argument to
+   ``dulwich.cli.commit``. (Jelmer Vernooij)
+
  * Add support for merge drivers.
    (Jelmer Vernooij)
 

+ 7 - 1
dulwich/cli.py

@@ -463,8 +463,14 @@ class cmd_commit(Command):
     def run(self, args) -> None:
         parser = argparse.ArgumentParser()
         parser.add_argument("--message", "-m", required=True, help="Commit message")
+        parser.add_argument(
+            "-a",
+            "--all",
+            action="store_true",
+            help="Automatically stage all tracked files that have been modified",
+        )
         args = parser.parse_args(args)
-        porcelain.commit(".", message=args.message)
+        porcelain.commit(".", message=args.message, all=args.all)
 
 
 class cmd_commit_tree(Command):

+ 52 - 40
dulwich/porcelain.py

@@ -78,6 +78,7 @@ Functions should generally accept both unicode strings and bytestrings
 
 import datetime
 import fnmatch
+import logging
 import os
 import posixpath
 import stat
@@ -473,6 +474,7 @@ def commit(
     encoding=None,
     no_verify=False,
     signoff=False,
+    all=False,
 ):
     """Create a new commit.
 
@@ -487,9 +489,9 @@ def commit(
       signoff: GPG Sign the commit (bool, defaults to False,
         pass True to use default GPG key,
         pass a str containing Key ID to use a specific GPG key)
+      all: Automatically stage all tracked files that have been modified
     Returns: SHA1 of the new commit
     """
-    # FIXME: Support --all argument
     if getattr(message, "encode", None):
         message = message.encode(encoding or DEFAULT_ENCODING)
     if getattr(author, "encode", None):
@@ -501,7 +503,27 @@ def commit(
         author_timezone = local_timezone[0]
     if commit_timezone is None:
         commit_timezone = local_timezone[1]
+
     with open_repo_closing(repo) as r:
+        # If -a flag is used, stage all modified tracked files
+        if all:
+            index = r.open_index()
+            normalizer = r.get_blob_normalizer()
+            filter_callback = normalizer.checkin_normalize
+            unstaged_changes = list(
+                get_unstaged_changes(index, r.path, filter_callback)
+            )
+
+            if unstaged_changes:
+                # Convert bytes paths to strings for add function
+                modified_files = []
+                for path in unstaged_changes:
+                    if isinstance(path, bytes):
+                        path = path.decode()
+                    modified_files.append(path)
+
+                add(r, paths=modified_files)
+
         return r.do_commit(
             message=message,
             author=author,
@@ -644,13 +666,9 @@ def clone(
             submodule_update(repo, init=True)
         except FileNotFoundError as e:
             # .gitmodules file doesn't exist - no submodules to process
-            import logging
-
             logging.debug("No .gitmodules file found: %s", e)
         except KeyError as e:
             # Submodule configuration missing
-            import logging
-
             logging.warning("Submodule configuration error: %s", e)
             if errstream:
                 errstream.write(
@@ -3130,7 +3148,10 @@ def checkout(
                         r.object_store.__getitem__, path
                     )
                     obj = r[sha]
-
+                except KeyError:
+                    # Path doesn't exist in target tree
+                    pass
+                else:
                     # Create directories if needed
                     # Handle path as string
                     if isinstance(path, bytes):
@@ -3156,10 +3177,6 @@ def checkout(
                     # Update the index
                     r.stage(path)
 
-                except KeyError:
-                    # Path doesn't exist in target tree
-                    pass
-
             return
 
         # Normal checkout (switching branches/commits)
@@ -3212,14 +3229,15 @@ def checkout(
 
                     try:
                         target_tree.lookup_path(r.object_store.__getitem__, change)
+                    except KeyError:
+                        # File doesn't exist in target tree - change can be preserved
+                        pass
+                    else:
                         # File exists in target tree - would overwrite local changes
                         raise CheckoutError(
                             f"Your local changes to '{change.decode()}' would be "
                             "overwritten by checkout. Please commit or stash before switching."
                         )
-                    except KeyError:
-                        # File doesn't exist in target tree - change can be preserved
-                        pass
 
         # Get configuration for working directory update
         config = r.get_config()
@@ -4046,12 +4064,9 @@ def cherry_pick(
         parent_commit = r[cherry_pick_commit.parents[0]]
 
         # Perform three-way merge
-        try:
-            merged_tree, conflicts = three_way_merge(
-                r.object_store, parent_commit, head_commit, cherry_pick_commit
-            )
-        except Exception as e:
-            raise Error(f"Cherry-pick failed: {e}")
+        merged_tree, conflicts = three_way_merge(
+            r.object_store, parent_commit, head_commit, cherry_pick_commit
+        )
 
         # Add merged tree to object store
         r.object_store.add_object(merged_tree)
@@ -5339,7 +5354,10 @@ def lfs_fetch(repo=".", remote="origin", refs=None):
             for entry in r.object_store.iter_tree_contents(commit.tree):
                 try:
                     obj = r.object_store[entry.sha]
-                    if obj.type_name == b"blob":
+                except KeyError:
+                    pass
+                else:
+                    if isinstance(obj, Blob):
                         pointer = LFSPointer.from_bytes(obj.data)
                         if pointer and pointer.is_valid_oid():
                             # Check if we already have it
@@ -5347,19 +5365,13 @@ def lfs_fetch(repo=".", remote="origin", refs=None):
                                 store.open_object(pointer.oid)
                             except KeyError:
                                 pointers_to_fetch.append((pointer.oid, pointer.size))
-                except KeyError:
-                    pass
 
         # Fetch missing objects
         fetched = 0
         for oid, size in pointers_to_fetch:
-            try:
-                content = client.download(oid, size)
-                store.write_object([content])
-                fetched += 1
-            except Exception as e:
-                # Log error but continue
-                print(f"Failed to fetch {oid}: {e}")
+            content = client.download(oid, size)
+            store.write_object([content])
+            fetched += 1
 
         return fetched
 
@@ -5466,12 +5478,13 @@ def lfs_push(repo=".", remote="origin", refs=None):
             for entry in r.object_store.iter_tree_contents(commit.tree):
                 try:
                     obj = r.object_store[entry.sha]
-                    if obj.type_name == b"blob":
+                except KeyError:
+                    pass
+                else:
+                    if isinstance(obj, Blob):
                         pointer = LFSPointer.from_bytes(obj.data)
                         if pointer and pointer.is_valid_oid():
                             objects_to_push.add((pointer.oid, pointer.size))
-                except KeyError:
-                    pass
 
         # Push objects
         pushed = 0
@@ -5479,14 +5492,12 @@ def lfs_push(repo=".", remote="origin", refs=None):
             try:
                 with store.open_object(oid) as f:
                     content = f.read()
-                client.upload(oid, size, content)
-                pushed += 1
             except KeyError:
                 # Object not in local store
-                print(f"Warning: LFS object {oid} not found locally")
-            except Exception as e:
-                # Log error but continue
-                print(f"Failed to push {oid}: {e}")
+                logging.warn("LFS object %s not found locally", oid)
+            else:
+                client.upload(oid, size, content)
+                pushed += 1
 
         return pushed
 
@@ -5536,11 +5547,12 @@ def lfs_status(repo="."):
                     # Check if file has been modified
                     try:
                         staged_obj = r.object_store[entry.binsha]
+                    except KeyError:
+                        pass
+                    else:
                         staged_pointer = LFSPointer.from_bytes(staged_obj.data)
                         if staged_pointer and staged_pointer.oid != pointer.oid:
                             status["not_staged"].append(path_str)
-                    except KeyError:
-                        pass
 
         # TODO: Check for not committed and not pushed files
 

+ 86 - 0
tests/test_cli.py

@@ -179,6 +179,92 @@ class CommitCommandTest(DulwichCliTestCase):
         # Check that HEAD points to a commit
         self.assertIsNotNone(self.repo.head())
 
+    def test_commit_all_flag(self):
+        # Create initial commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("initial content")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial commit")
+
+        # Modify the file (don't stage it)
+        with open(test_file, "w") as f:
+            f.write("modified content")
+
+        # Create another file and don't add it (untracked)
+        untracked_file = os.path.join(self.repo_path, "untracked.txt")
+        with open(untracked_file, "w") as f:
+            f.write("untracked content")
+
+        # Commit with -a flag should stage and commit the modified file,
+        # but not the untracked file
+        result, stdout, stderr = self._run_cli(
+            "commit", "-a", "--message=Modified commit"
+        )
+        self.assertIsNotNone(self.repo.head())
+
+        # Check that the modification was committed
+        with open(test_file) as f:
+            content = f.read()
+        self.assertEqual(content, "modified content")
+
+        # Check that untracked file is still untracked
+        self.assertTrue(os.path.exists(untracked_file))
+
+    def test_commit_all_flag_no_changes(self):
+        # Create initial commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("initial content")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial commit")
+
+        # Try to commit with -a when there are no changes
+        # This should still work (git allows this)
+        result, stdout, stderr = self._run_cli(
+            "commit", "-a", "--message=No changes commit"
+        )
+        self.assertIsNotNone(self.repo.head())
+
+    def test_commit_all_flag_multiple_files(self):
+        # Create initial commit with multiple files
+        file1 = os.path.join(self.repo_path, "file1.txt")
+        file2 = os.path.join(self.repo_path, "file2.txt")
+
+        with open(file1, "w") as f:
+            f.write("content1")
+        with open(file2, "w") as f:
+            f.write("content2")
+
+        self._run_cli("add", "file1.txt", "file2.txt")
+        self._run_cli("commit", "--message=Initial commit")
+
+        # Modify both files
+        with open(file1, "w") as f:
+            f.write("modified content1")
+        with open(file2, "w") as f:
+            f.write("modified content2")
+
+        # Create an untracked file
+        untracked_file = os.path.join(self.repo_path, "untracked.txt")
+        with open(untracked_file, "w") as f:
+            f.write("untracked content")
+
+        # Commit with -a should stage both modified files but not untracked
+        result, stdout, stderr = self._run_cli(
+            "commit", "-a", "--message=Modified both files"
+        )
+        self.assertIsNotNone(self.repo.head())
+
+        # Verify modifications were committed
+        with open(file1) as f:
+            self.assertEqual(f.read(), "modified content1")
+        with open(file2) as f:
+            self.assertEqual(f.read(), "modified content2")
+
+        # Verify untracked file still exists
+        self.assertTrue(os.path.exists(untracked_file))
+
 
 class LogCommandTest(DulwichCliTestCase):
     """Tests for log command."""

+ 78 - 0
tests/test_porcelain.py

@@ -462,6 +462,84 @@ class CommitTests(PorcelainTestCase):
         self.assertEqual(commit._author_timezone, local_timezone)
         self.assertEqual(commit._commit_timezone, local_timezone)
 
+    def test_commit_all(self) -> None:
+        # Create initial commit
+        filename = os.path.join(self.repo.path, "test.txt")
+        with open(filename, "wb") as f:
+            f.write(b"initial content")
+        porcelain.add(self.repo.path, paths=["test.txt"])
+        initial_sha = porcelain.commit(self.repo.path, message=b"Initial commit")
+
+        # Modify the file without staging
+        with open(filename, "wb") as f:
+            f.write(b"modified content")
+
+        # Create an untracked file
+        untracked_file = os.path.join(self.repo.path, "untracked.txt")
+        with open(untracked_file, "wb") as f:
+            f.write(b"untracked content")
+
+        # Commit with all=True should stage modified files but not untracked
+        sha = porcelain.commit(self.repo.path, message=b"Modified commit", all=True)
+        self.assertIsInstance(sha, bytes)
+        self.assertEqual(len(sha), 40)
+        self.assertNotEqual(sha, initial_sha)
+
+        # Verify the commit contains the modification
+        commit = self.repo.get_object(sha)
+        assert isinstance(commit, Commit)
+        tree = self.repo.get_object(commit.tree)
+        # The modified file should be in the commit
+        self.assertIn(b"test.txt", tree)
+        # The untracked file should not be in the commit
+        self.assertNotIn(b"untracked.txt", tree)
+
+    def test_commit_all_no_changes(self) -> None:
+        # Create initial commit
+        filename = os.path.join(self.repo.path, "test.txt")
+        with open(filename, "wb") as f:
+            f.write(b"initial content")
+        porcelain.add(self.repo.path, paths=["test.txt"])
+        initial_sha = porcelain.commit(self.repo.path, message=b"Initial commit")
+
+        # Try to commit with all=True when there are no unstaged changes
+        sha = porcelain.commit(self.repo.path, message=b"No changes commit", all=True)
+        self.assertIsInstance(sha, bytes)
+        self.assertEqual(len(sha), 40)
+        self.assertNotEqual(sha, initial_sha)
+
+    def test_commit_all_multiple_files(self) -> None:
+        # Create initial commit with multiple files
+        file1 = os.path.join(self.repo.path, "file1.txt")
+        file2 = os.path.join(self.repo.path, "file2.txt")
+
+        with open(file1, "wb") as f:
+            f.write(b"content1")
+        with open(file2, "wb") as f:
+            f.write(b"content2")
+
+        porcelain.add(self.repo.path, paths=["file1.txt", "file2.txt"])
+        initial_sha = porcelain.commit(self.repo.path, message=b"Initial commit")
+
+        # Modify both files
+        with open(file1, "wb") as f:
+            f.write(b"modified content1")
+        with open(file2, "wb") as f:
+            f.write(b"modified content2")
+
+        # Commit with all=True should stage both modified files
+        sha = porcelain.commit(self.repo.path, message=b"Modified both files", all=True)
+        self.assertIsInstance(sha, bytes)
+        self.assertEqual(len(sha), 40)
+        self.assertNotEqual(sha, initial_sha)
+
+        # Verify both modifications are in the commit
+        commit = self.repo.get_object(sha)
+        assert isinstance(commit, Commit)
+        tree = self.repo.get_object(commit.tree)
+        self.assertIn(b"file1.txt", tree)
+        self.assertIn(b"file2.txt", tree)
+
 
 @skipIf(
     platform.python_implementation() == "PyPy" or sys.platform == "win32",