瀏覽代碼

Add support for branch.<name>.merge configuration

- Add get_branch_merge() to read upstream branch configuration
- Add set_branch_tracking() to set both remote and merge configs
- Enhance checkout() to automatically set up tracking when creating
  a new branch from a remote branch
- Add comprehensive tests for branch tracking functionality

This implements the branch.<name>.merge configuration variable.
Jelmer Vernooij 1 月之前
父節點
當前提交
1d89c0b435
共有 4 個文件被更改,包括 137 次插入11 次删除
  1. 50 0
      dulwich/porcelain.py
  2. 7 7
      dulwich/refs.py
  3. 76 0
      tests/test_porcelain.py
  4. 4 4
      tests/test_refs.py

+ 50 - 0
dulwich/porcelain.py

@@ -2232,6 +2232,42 @@ def get_branch_remote(repo):
     return remote_name
 
 
+def get_branch_merge(repo, branch_name=None):
+    """Return the branch's merge reference (upstream branch), if any.
+
+    Args:
+      repo: Repository to open
+      branch_name: Name of the branch (defaults to active branch)
+
+    Returns:
+      merge reference name (e.g. b"refs/heads/main")
+
+    Raises:
+      KeyError: if the branch does not have a merge configuration
+    """
+    with open_repo_closing(repo) as r:
+        if branch_name is None:
+            branch_name = active_branch(r.path)
+        config = r.get_config()
+        return config.get((b"branch", branch_name), b"merge")
+
+
+def set_branch_tracking(repo, branch_name, remote_name, remote_ref):
+    """Set up branch tracking configuration.
+
+    Args:
+      repo: Repository to open
+      branch_name: Name of the local branch
+      remote_name: Name of the remote (e.g. b"origin")
+      remote_ref: Remote reference to track (e.g. b"refs/heads/main")
+    """
+    with open_repo_closing(repo) as r:
+        config = r.get_config()
+        config.set((b"branch", branch_name), b"remote", remote_name)
+        config.set((b"branch", branch_name), b"merge", remote_ref)
+        config.write_to_path()
+
+
 def fetch(
     repo,
     remote_location=None,
@@ -2698,6 +2734,20 @@ def checkout(
             # Create new branch and switch to it
             branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
             update_head(r, new_branch)
+
+            # Set up tracking if creating from a remote branch
+            from .refs import LOCAL_REMOTE_PREFIX, parse_remote_ref
+
+            if target.startswith(LOCAL_REMOTE_PREFIX):
+                try:
+                    remote_name, branch_name = parse_remote_ref(target)
+                    # Set tracking to refs/heads/<branch> on the remote
+                    set_branch_tracking(
+                        r, new_branch, remote_name, b"refs/heads/" + branch_name
+                    )
+                except ValueError:
+                    # Invalid remote ref format, skip tracking setup
+                    pass
         else:
             # Check if target is a branch name (with or without refs/heads/ prefix)
             branch_ref = None

+ 7 - 7
dulwich/refs.py

@@ -104,27 +104,27 @@ def check_ref_format(refname: Ref) -> bool:
 
 def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
     """Parse a remote ref into remote name and branch name.
-    
+
     Args:
       ref: Remote ref like b"refs/remotes/origin/main"
-      
+
     Returns:
       Tuple of (remote_name, branch_name)
-      
+
     Raises:
       ValueError: If ref is not a valid remote ref
     """
     if not ref.startswith(LOCAL_REMOTE_PREFIX):
         raise ValueError(f"Not a remote ref: {ref!r}")
-    
+
     # Remove the prefix
-    remainder = ref[len(LOCAL_REMOTE_PREFIX):]
-    
+    remainder = ref[len(LOCAL_REMOTE_PREFIX) :]
+
     # Split into remote name and branch name
     parts = remainder.split(b"/", 1)
     if len(parts) != 2:
         raise ValueError(f"Invalid remote ref format: {ref!r}")
-    
+
     remote_name, branch_name = parts
     return (remote_name, branch_name)
 

+ 76 - 0
tests/test_porcelain.py

@@ -3229,6 +3229,44 @@ class CheckoutTests(PorcelainTestCase):
 
         target_repo.close()
 
+    def test_checkout_new_branch_from_remote_sets_tracking(self) -> None:
+        # Create a "remote" repository
+        remote_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, remote_path)
+        remote_repo = porcelain.init(remote_path)
+
+        # Add a commit to the remote
+        remote_sha, _ = _commit_file_with_content(
+            remote_repo, "bar", "remote content\n"
+        )
+
+        # Clone the remote repository
+        target_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target_path)
+        target_repo = porcelain.clone(remote_path, target_path)
+
+        # Create a remote tracking branch reference
+        remote_branch_ref = b"refs/remotes/origin/feature"
+        target_repo.refs[remote_branch_ref] = remote_sha
+
+        # Checkout a new branch from the remote branch
+        porcelain.checkout(target_repo, remote_branch_ref, new_branch=b"local-feature")
+
+        # Verify the branch was created and is active
+        self.assertEqual(b"local-feature", porcelain.active_branch(target_repo))
+
+        # Verify tracking configuration was set
+        config = target_repo.get_config()
+        self.assertEqual(
+            b"origin", config.get((b"branch", b"local-feature"), b"remote")
+        )
+        self.assertEqual(
+            b"refs/heads/feature", config.get((b"branch", b"local-feature"), b"merge")
+        )
+
+        target_repo.close()
+        remote_repo.close()
+
 
 class GeneralCheckoutTests(PorcelainTestCase):
     """Tests for the general checkout function that handles branches, tags, and commits."""
@@ -5204,6 +5242,44 @@ class ActiveBranchTests(PorcelainTestCase):
         self.assertEqual(b"master", porcelain.active_branch(self.repo))
 
 
+class BranchTrackingTests(PorcelainTestCase):
+    def test_get_branch_merge(self) -> None:
+        # Set up branch tracking configuration
+        config = self.repo.get_config()
+        config.set((b"branch", b"master"), b"remote", b"origin")
+        config.set((b"branch", b"master"), b"merge", b"refs/heads/main")
+        config.write_to_path()
+
+        # Test getting merge ref for current branch
+        merge_ref = porcelain.get_branch_merge(self.repo)
+        self.assertEqual(b"refs/heads/main", merge_ref)
+
+        # Test getting merge ref for specific branch
+        merge_ref = porcelain.get_branch_merge(self.repo, b"master")
+        self.assertEqual(b"refs/heads/main", merge_ref)
+
+        # Test branch without merge config
+        with self.assertRaises(KeyError):
+            porcelain.get_branch_merge(self.repo, b"nonexistent")
+
+    def test_set_branch_tracking(self) -> None:
+        # Create a new branch
+        sha, _ = _commit_file_with_content(self.repo, "foo", "content\n")
+        porcelain.branch_create(self.repo, "feature")
+
+        # Set up tracking
+        porcelain.set_branch_tracking(
+            self.repo, b"feature", b"upstream", b"refs/heads/main"
+        )
+
+        # Verify configuration was written
+        config = self.repo.get_config()
+        self.assertEqual(b"upstream", config.get((b"branch", b"feature"), b"remote"))
+        self.assertEqual(
+            b"refs/heads/main", config.get((b"branch", b"feature"), b"merge")
+        )
+
+
 class FindUniqueAbbrevTests(PorcelainTestCase):
     def test_simple(self) -> None:
         c1, c2, c3 = build_commit_graph(

+ 4 - 4
tests/test_refs.py

@@ -901,24 +901,24 @@ class ParseRemoteRefTests(TestCase):
         remote, branch = parse_remote_ref(b"refs/remotes/origin/main")
         self.assertEqual(b"origin", remote)
         self.assertEqual(b"main", branch)
-        
+
         # Test with branch containing slashes
         remote, branch = parse_remote_ref(b"refs/remotes/upstream/feature/new-ui")
         self.assertEqual(b"upstream", remote)
         self.assertEqual(b"feature/new-ui", branch)
-    
+
     def test_invalid_not_remote_ref(self) -> None:
         # Not a remote ref
         with self.assertRaises(ValueError) as cm:
             parse_remote_ref(b"refs/heads/main")
         self.assertIn("Not a remote ref", str(cm.exception))
-    
+
     def test_invalid_format(self) -> None:
         # Missing branch name
         with self.assertRaises(ValueError) as cm:
             parse_remote_ref(b"refs/remotes/origin")
         self.assertIn("Invalid remote ref format", str(cm.exception))
-        
+
         # Just the prefix
         with self.assertRaises(ValueError) as cm:
             parse_remote_ref(b"refs/remotes/")