Преглед изворни кода

add --contains to branch command (#1875)

Implements the --contains flag for the dulwich branch command, 
which lists branches that contain a specific commit.
https://github.com/jelmer/dulwich/issues/1847
xifOO пре 4 месеци
родитељ
комит
59ed3ab7cf
4 измењених фајлова са 275 додато и 4 уклоњено
  1. 24 0
      dulwich/cli.py
  2. 24 4
      dulwich/porcelain.py
  3. 66 0
      tests/test_cli.py
  4. 161 0
      tests/test_porcelain.py

+ 24 - 0
dulwich/cli.py

@@ -2070,6 +2070,12 @@ class cmd_branch(Command):
         parser.add_argument(
             "--remotes", action="store_true", help="List remotes branches"
         )
+        parser.add_argument(
+            "--contains",
+            nargs="?",
+            const="HEAD",
+            help="List branches that contain a specific commit",
+        )
         args = parser.parse_args(args)
 
         if args.all:
@@ -2082,6 +2088,7 @@ class cmd_branch(Command):
                     sys.stdout.write(f"{branch.decode()}\n")
 
                 return 0
+
             except porcelain.Error as e:
                 sys.stderr.write(f"{e}")
                 return 1
@@ -2110,6 +2117,23 @@ class cmd_branch(Command):
                 sys.stderr.write(f"{e}")
                 return 1
 
+        if args.contains:
+            try:
+                branches_iter = porcelain.branches_containing(".", commit=args.contains)
+
+                for branch in branches_iter:
+                    sys.stdout.write(f"{branch.decode()}\n")
+
+                return 0
+
+            except KeyError as e:
+                sys.stderr.write(f"error: object name {e.args[0].decode()} not found\n")
+                return 1
+
+            except porcelain.Error as e:
+                sys.stderr.write(f"{e}")
+                return 1
+
         if args.remotes:
             try:
                 branches = porcelain.branch_remotes_list(".")

+ 24 - 4
dulwich/porcelain.py

@@ -3321,10 +3321,7 @@ def _get_branch_merge_status(repo: RepoPath) -> Iterator[tuple[bytes, bool]]:
     with open_repo_closing(repo) as r:
         current_sha = r.refs[b"HEAD"]
 
-        for branch_ref in r.refs.keys(base=b"refs/heads/"):
-            full_ref = b"refs/heads/" + branch_ref
-            branch_sha = r.refs[full_ref]
-
+        for branch_ref, branch_sha in r.refs.as_dict(base=b"refs/heads/").items():
             # Check if branch is an ancestor of HEAD (fully merged)
             is_merged = can_fast_forward(r, branch_sha, current_sha)
             yield branch_ref, is_merged
@@ -3358,6 +3355,29 @@ def no_merged_branches(repo: RepoPath) -> Iterator[bytes]:
             yield branch_name
 
 
+def branches_containing(repo: RepoPath, commit: str) -> Iterator[bytes]:
+    """List branches that contain the specified commit.
+
+    Args:
+        repo: Path to the repository
+        commit: Commit-ish string (SHA, branch name, tag, etc.)
+
+    Yields:
+        Branch names (without refs/heads/ prefix) that contain the commit
+
+    Raises:
+        ValueError: If the commit reference is malformed
+        KeyError: If the commit reference does not exist
+    """
+    with open_repo_closing(repo) as r:
+        commit_obj = parse_commit(r, commit)
+        commit_sha = commit_obj.id
+
+        for branch_ref, branch_sha in r.refs.as_dict(base=LOCAL_BRANCH_PREFIX).items():
+            if can_fast_forward(r, commit_sha, branch_sha):
+                yield branch_ref
+
+
 def active_branch(repo: RepoPath) -> bytes:
     """Return the active branch in the repository, if any.
 

+ 66 - 0
tests/test_cli.py

@@ -563,6 +563,72 @@ class BranchCommandTest(DulwichCliTestCase):
 
         self.assertEqual(branches, expected_branches)
 
+    def test_branch_list_contains(self):
+        # Create initial commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial")
+
+        initial_commit_sha = self.repo.refs[b"HEAD"]
+
+        # Create first branch from initial commit
+        self._run_cli("branch", "branch-1")
+
+        # Make a new commit on master
+        test_file2 = os.path.join(self.repo_path, "test2.txt")
+        with open(test_file2, "w") as f:
+            f.write("test2")
+        self._run_cli("add", "test2.txt")
+        self._run_cli("commit", "--message=Second commit")
+
+        second_commit_sha = self.repo.refs[b"HEAD"]
+
+        # Create second branch from current master (contains both commits)
+        self._run_cli("branch", "branch-2")
+
+        # Create third branch that doesn't contain the second commit
+        # Switch to initial commit and create branch from there
+        self.repo.refs[b"HEAD"] = initial_commit_sha
+        self._run_cli("branch", "branch-3")
+
+        # Switch back to master
+        self.repo.refs[b"HEAD"] = second_commit_sha
+
+        # Test --contains with second commit (should include master and branch-2)
+        result, stdout, stderr = self._run_cli(
+            "branch", "--contains", second_commit_sha.decode()
+        )
+        self.assertEqual(result, 0)
+
+        branches = [line.strip() for line in stdout.splitlines()]
+        expected_branches = {"master", "branch-2"}
+        self.assertEqual(set(branches), expected_branches)
+
+        # Test --contains with initial commit (should include all branches)
+        result, stdout, stderr = self._run_cli(
+            "branch", "--contains", initial_commit_sha.decode()
+        )
+        self.assertEqual(result, 0)
+
+        branches = [line.strip() for line in stdout.splitlines()]
+        expected_branches = {"master", "branch-1", "branch-2", "branch-3"}
+        self.assertEqual(set(branches), expected_branches)
+
+        # Test --contains without argument (uses HEAD, which is second commit)
+        result, stdout, stderr = self._run_cli("branch", "--contains")
+        self.assertEqual(result, 0)
+
+        branches = [line.strip() for line in stdout.splitlines()]
+        expected_branches = {"master", "branch-2"}
+        self.assertEqual(set(branches), expected_branches)
+
+        # Test with invalid commit hash
+        result, stdout, stderr = self._run_cli("branch", "--contains", "invalid123")
+        self.assertNotEqual(result, 0)
+        self.assertIn("error: object name invalid123 not found", stderr)
+
 
 class CheckoutCommandTest(DulwichCliTestCase):
     """Tests for checkout command."""

+ 161 - 0
tests/test_porcelain.py

@@ -6877,6 +6877,167 @@ class BranchNoMergedTests(PorcelainTestCase):
         self.assertEqual([], result)
 
 
+class BranchContainsTests(PorcelainTestCase):
+    def test_commit_in_single_branch(self) -> None:
+        """Test commit contained in only one branch."""
+        # Create: c1 → c2 (master), c1 → c3 (feature)
+        [c1, c2, c3] = build_commit_graph(
+            self.repo.object_store,
+            [
+                [1],  # c1
+                [2, 1],  # c2 → c1 (master line)
+                [3, 1],  # c3 → c1 (feature branch)
+            ],
+        )
+
+        self.repo.refs[b"HEAD"] = c2.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+        self.repo.refs[b"refs/heads/feature"] = c3.id
+
+        # c2 is only in master branch
+        result = list(porcelain.branches_containing(self.repo, c2.id.decode()))
+        self.assertEqual([b"master"], result)
+
+        # c3 is only in feature branch
+        result = list(porcelain.branches_containing(self.repo, c3.id.decode()))
+        self.assertEqual([b"feature"], result)
+
+    def test_commit_in_multiple_branches(self) -> None:
+        """Test commit contained in multiple branches."""
+        # Create: c1 → c2 → c3 (master), c2 → c4 (feature)
+        [c1, c2, c3, c4] = build_commit_graph(
+            self.repo.object_store,
+            [
+                [1],  # c1
+                [2, 1],  # c2 → c1
+                [3, 2],  # c3 → c2 (master)
+                [4, 2],  # c4 → c2 (feature)
+            ],
+        )
+
+        self.repo.refs[b"HEAD"] = c3.id
+        self.repo.refs[b"refs/heads/master"] = c3.id
+        self.repo.refs[b"refs/heads/feature"] = c4.id
+
+        # c2 is in both branches (common ancestor)
+        branches = list(porcelain.branches_containing(self.repo, c2.id.decode()))
+        expected = [b"master", b"feature"]
+        expected.sort()
+        branches.sort()
+        self.assertEqual(expected, branches)
+
+        # c1 is in both branches (older common ancestor)
+        branches = list(porcelain.branches_containing(self.repo, c1.id.decode()))
+        expected = [b"master", b"feature"]
+        expected.sort()
+        branches.sort()
+        self.assertEqual(expected, branches)
+
+    def test_commit_in_all_branches(self) -> None:
+        """Test commit contained in all branches."""
+        # Create linear history: c1 → c2 → c3 (HEAD/master)
+        [c1, c2, c3] = build_commit_graph(self.repo.object_store, [[1], [2, 1], [3, 2]])
+
+        self.repo.refs[b"HEAD"] = c3.id
+        self.repo.refs[b"refs/heads/master"] = c3.id
+        self.repo.refs[b"refs/heads/feature-1"] = c3.id  # Same as master
+        self.repo.refs[b"refs/heads/feature-2"] = c2.id  # Ancestor
+
+        # c1 is in all branches
+        branches = list(porcelain.branches_containing(self.repo, c1.id.decode()))
+        expected = [b"master", b"feature-1", b"feature-2"]
+        expected.sort()
+        branches.sort()
+        self.assertEqual(expected, branches)
+
+    def test_commit_in_no_branches(self) -> None:
+        """Test commit not contained in any branch."""
+        # Create: c1 → c2 (master), c1 → c3 (feature), orphan c4
+        [c1, c2, c3, c4] = build_commit_graph(
+            self.repo.object_store,
+            [
+                [1],  # c1
+                [2, 1],  # c2 → c1 (master)
+                [3, 1],  # c3 → c1 (feature)
+                [4],  # c4 (orphan commit)
+            ],
+        )
+
+        self.repo.refs[b"HEAD"] = c2.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+        self.repo.refs[b"refs/heads/feature"] = c3.id
+
+        # c4 is not in any branch
+        result = list(porcelain.branches_containing(self.repo, c4.id.decode()))
+        self.assertEqual([], result)
+
+    def test_commit_ref_by_branch_name(self) -> None:
+        """Test using branch name as commit reference."""
+        # Create: c1 → c2 (master), c1 → c3 (feature)
+        [c1, c2, c3] = build_commit_graph(
+            self.repo.object_store,
+            [
+                [1],  # c1
+                [2, 1],  # c2 → c1 (master)
+                [3, 1],  # c3 → c1 (feature)
+            ],
+        )
+
+        self.repo.refs[b"HEAD"] = c2.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+        self.repo.refs[b"refs/heads/feature"] = c3.id
+
+        # Use "master" as commit reference - should find master branch
+        result = list(porcelain.branches_containing(self.repo, "master"))
+        self.assertEqual([b"master"], result)
+
+        # Use "feature" as commit reference - should find feature branch
+        result = list(porcelain.branches_containing(self.repo, "feature"))
+        self.assertEqual([b"feature"], result)
+
+    def test_commit_ref_by_head(self) -> None:
+        """Test using HEAD as commit reference."""
+        # Create: c1 → c2 → c3 (HEAD/master)
+        [c1, c2, c3] = build_commit_graph(self.repo.object_store, [[1], [2, 1], [3, 2]])
+
+        self.repo.refs[b"HEAD"] = c3.id
+        self.repo.refs[b"refs/heads/master"] = c3.id
+        self.repo.refs[b"refs/heads/feature"] = c2.id  # Ancestor
+
+        # Use "HEAD" as commit reference
+        result = list(porcelain.branches_containing(self.repo, "HEAD"))
+        self.assertEqual([b"master"], result)
+
+    def test_invalid_commit_ref(self) -> None:
+        """Test with invalid commit reference."""
+        [c1] = build_commit_graph(self.repo.object_store, [[1]])
+        self.repo.refs[b"HEAD"] = c1.id
+        self.repo.refs[b"refs/heads/master"] = c1.id
+
+        # Test with non-existent commit
+        with self.assertRaises(KeyError) as cm:
+            list(porcelain.branches_containing(self.repo, "nonexistent"))
+        self.assertEqual(b"nonexistent", cm.exception.args[0])
+
+        # Test with invalid SHA
+        with self.assertRaises(KeyError) as cm:
+            list(porcelain.branches_containing(self.repo, "invalid-sha"))
+        self.assertEqual(b"invalid-sha", cm.exception.args[0])
+
+    def test_short_sha_reference(self) -> None:
+        """Test using short SHA as commit reference."""
+        # Create: c1 → c2 (master)
+        [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2, 1]])
+
+        self.repo.refs[b"HEAD"] = c2.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+
+        # Use short SHA (first 7 characters)
+        short_sha = c1.id.decode()[:7]
+        result = list(porcelain.branches_containing(self.repo, short_sha))
+        self.assertEqual([b"master"], result)
+
+
 class BranchCreateTests(PorcelainTestCase):
     def test_branch_exists(self) -> None:
         [c1] = build_commit_graph(self.repo.object_store, [[1]])