Przeglądaj źródła

Support branch..merge configuration options (#1648)

Jelmer Vernooij 1 miesiąc temu
rodzic
commit
ded2ea977f
4 zmienionych plików z 184 dodań i 0 usunięć
  1. 50 0
      dulwich/porcelain.py
  2. 27 0
      dulwich/refs.py
  3. 76 0
      tests/test_porcelain.py
  4. 31 0
      tests/test_refs.py

+ 50 - 0
dulwich/porcelain.py

@@ -2232,6 +2232,42 @@ def get_branch_remote(repo):
     return remote_name
 
 
+def get_branch_merge(repo, branch_name=None):
+    """Return the branch's merge reference (upstream branch), if any.
+
+    Args:
+      repo: Repository to open
+      branch_name: Name of the branch (defaults to active branch)
+
+    Returns:
+      merge reference name (e.g. b"refs/heads/main")
+
+    Raises:
+      KeyError: if the branch does not have a merge configuration
+    """
+    with open_repo_closing(repo) as r:
+        if branch_name is None:
+            branch_name = active_branch(r.path)
+        config = r.get_config()
+        return config.get((b"branch", branch_name), b"merge")
+
+
+def set_branch_tracking(repo, branch_name, remote_name, remote_ref):
+    """Set up branch tracking configuration.
+
+    Args:
+      repo: Repository to open
+      branch_name: Name of the local branch
+      remote_name: Name of the remote (e.g. b"origin")
+      remote_ref: Remote reference to track (e.g. b"refs/heads/main")
+    """
+    with open_repo_closing(repo) as r:
+        config = r.get_config()
+        config.set((b"branch", branch_name), b"remote", remote_name)
+        config.set((b"branch", branch_name), b"merge", remote_ref)
+        config.write_to_path()
+
+
 def fetch(
     repo,
     remote_location=None,
@@ -2698,6 +2734,20 @@ def checkout(
             # Create new branch and switch to it
             branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
             update_head(r, new_branch)
+
+            # Set up tracking if creating from a remote branch
+            from .refs import LOCAL_REMOTE_PREFIX, parse_remote_ref
+
+            if target.startswith(LOCAL_REMOTE_PREFIX):
+                try:
+                    remote_name, branch_name = parse_remote_ref(target)
+                    # Set tracking to refs/heads/<branch> on the remote
+                    set_branch_tracking(
+                        r, new_branch, remote_name, b"refs/heads/" + branch_name
+                    )
+                except ValueError:
+                    # Invalid remote ref format, skip tracking setup
+                    pass
         else:
             # Check if target is a branch name (with or without refs/heads/ prefix)
             branch_ref = None

+ 27 - 0
dulwich/refs.py

@@ -102,6 +102,33 @@ def check_ref_format(refname: Ref) -> bool:
     return True
 
 
+def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
+    """Parse a remote ref into remote name and branch name.
+
+    Args:
+      ref: Remote ref like b"refs/remotes/origin/main"
+
+    Returns:
+      Tuple of (remote_name, branch_name)
+
+    Raises:
+      ValueError: If ref is not a valid remote ref
+    """
+    if not ref.startswith(LOCAL_REMOTE_PREFIX):
+        raise ValueError(f"Not a remote ref: {ref!r}")
+
+    # Remove the prefix
+    remainder = ref[len(LOCAL_REMOTE_PREFIX) :]
+
+    # Split into remote name and branch name
+    parts = remainder.split(b"/", 1)
+    if len(parts) != 2:
+        raise ValueError(f"Invalid remote ref format: {ref!r}")
+
+    remote_name, branch_name = parts
+    return (remote_name, branch_name)
+
+
 class RefsContainer:
     """A container for refs."""
 

+ 76 - 0
tests/test_porcelain.py

@@ -3229,6 +3229,44 @@ class CheckoutTests(PorcelainTestCase):
 
         target_repo.close()
 
+    def test_checkout_new_branch_from_remote_sets_tracking(self) -> None:
+        # Create a "remote" repository
+        remote_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, remote_path)
+        remote_repo = porcelain.init(remote_path)
+
+        # Add a commit to the remote
+        remote_sha, _ = _commit_file_with_content(
+            remote_repo, "bar", "remote content\n"
+        )
+
+        # Clone the remote repository
+        target_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target_path)
+        target_repo = porcelain.clone(remote_path, target_path)
+
+        # Create a remote tracking branch reference
+        remote_branch_ref = b"refs/remotes/origin/feature"
+        target_repo.refs[remote_branch_ref] = remote_sha
+
+        # Checkout a new branch from the remote branch
+        porcelain.checkout(target_repo, remote_branch_ref, new_branch=b"local-feature")
+
+        # Verify the branch was created and is active
+        self.assertEqual(b"local-feature", porcelain.active_branch(target_repo))
+
+        # Verify tracking configuration was set
+        config = target_repo.get_config()
+        self.assertEqual(
+            b"origin", config.get((b"branch", b"local-feature"), b"remote")
+        )
+        self.assertEqual(
+            b"refs/heads/feature", config.get((b"branch", b"local-feature"), b"merge")
+        )
+
+        target_repo.close()
+        remote_repo.close()
+
 
 class GeneralCheckoutTests(PorcelainTestCase):
     """Tests for the general checkout function that handles branches, tags, and commits."""
@@ -5204,6 +5242,44 @@ class ActiveBranchTests(PorcelainTestCase):
         self.assertEqual(b"master", porcelain.active_branch(self.repo))
 
 
+class BranchTrackingTests(PorcelainTestCase):
+    def test_get_branch_merge(self) -> None:
+        # Set up branch tracking configuration
+        config = self.repo.get_config()
+        config.set((b"branch", b"master"), b"remote", b"origin")
+        config.set((b"branch", b"master"), b"merge", b"refs/heads/main")
+        config.write_to_path()
+
+        # Test getting merge ref for current branch
+        merge_ref = porcelain.get_branch_merge(self.repo)
+        self.assertEqual(b"refs/heads/main", merge_ref)
+
+        # Test getting merge ref for specific branch
+        merge_ref = porcelain.get_branch_merge(self.repo, b"master")
+        self.assertEqual(b"refs/heads/main", merge_ref)
+
+        # Test branch without merge config
+        with self.assertRaises(KeyError):
+            porcelain.get_branch_merge(self.repo, b"nonexistent")
+
+    def test_set_branch_tracking(self) -> None:
+        # Create a new branch
+        sha, _ = _commit_file_with_content(self.repo, "foo", "content\n")
+        porcelain.branch_create(self.repo, "feature")
+
+        # Set up tracking
+        porcelain.set_branch_tracking(
+            self.repo, b"feature", b"upstream", b"refs/heads/main"
+        )
+
+        # Verify configuration was written
+        config = self.repo.get_config()
+        self.assertEqual(b"upstream", config.get((b"branch", b"feature"), b"remote"))
+        self.assertEqual(
+            b"refs/heads/main", config.get((b"branch", b"feature"), b"merge")
+        )
+
+
 class FindUniqueAbbrevTests(PorcelainTestCase):
     def test_simple(self) -> None:
         c1, c2, c3 = build_commit_graph(

+ 31 - 0
tests/test_refs.py

@@ -36,6 +36,7 @@ from dulwich.refs import (
     SymrefLoop,
     _split_ref_line,
     check_ref_format,
+    parse_remote_ref,
     parse_symref_value,
     read_packed_refs,
     read_packed_refs_with_peeled,
@@ -894,6 +895,36 @@ class ParseSymrefValueTests(TestCase):
         self.assertRaises(ValueError, parse_symref_value, b"foobar")
 
 
+class ParseRemoteRefTests(TestCase):
+    def test_valid(self) -> None:
+        # Test simple case
+        remote, branch = parse_remote_ref(b"refs/remotes/origin/main")
+        self.assertEqual(b"origin", remote)
+        self.assertEqual(b"main", branch)
+
+        # Test with branch containing slashes
+        remote, branch = parse_remote_ref(b"refs/remotes/upstream/feature/new-ui")
+        self.assertEqual(b"upstream", remote)
+        self.assertEqual(b"feature/new-ui", branch)
+
+    def test_invalid_not_remote_ref(self) -> None:
+        # Not a remote ref
+        with self.assertRaises(ValueError) as cm:
+            parse_remote_ref(b"refs/heads/main")
+        self.assertIn("Not a remote ref", str(cm.exception))
+
+    def test_invalid_format(self) -> None:
+        # Missing branch name
+        with self.assertRaises(ValueError) as cm:
+            parse_remote_ref(b"refs/remotes/origin")
+        self.assertIn("Invalid remote ref format", str(cm.exception))
+
+        # Just the prefix
+        with self.assertRaises(ValueError) as cm:
+            parse_remote_ref(b"refs/remotes/")
+        self.assertIn("Invalid remote ref format", str(cm.exception))
+
+
 class StripPeeledRefsTests(TestCase):
     all_refs: ClassVar[dict[bytes, bytes]] = {
         b"refs/heads/master": b"8843d7f92416211de9ebb963ff4ce28125932878",