Просмотр исходного кода

Add convenience functions for dealing with refs (#1937)

Jelmer Vernooij 3 месяцев назад
Родитель
Сommit
94c702f291
8 измененных файлов с 198 добавлено и 42 удалено
  1. 4 4
      dulwich/filter_branch.py
  2. 4 3
      dulwich/objectspec.py
  3. 25 21
      dulwich/porcelain.py
  4. 2 1
      dulwich/rebase.py
  5. 91 7
      dulwich/refs.py
  6. 6 3
      dulwich/repo.py
  7. 2 3
      dulwich/worktree.py
  8. 64 0
      tests/test_refs.py

+ 4 - 4
dulwich/filter_branch.py

@@ -30,7 +30,7 @@ from typing import Callable, Optional, TypedDict
 from .index import Index, build_index_from_tree
 from .object_store import BaseObjectStore
 from .objects import Commit, Tag, Tree
-from .refs import RefsContainer
+from .refs import RefsContainer, local_tag_name
 
 
 class CommitData(TypedDict, total=False):
@@ -483,12 +483,12 @@ def filter_refs(
                                 new_tag.tag_timezone = tag_obj.tag_timezone
                                 object_store.add_object(new_tag)
                                 # Update ref to point to new tag object
-                                refs[b"refs/tags/" + new_tag_name] = new_tag.id
+                                refs[local_tag_name(new_tag_name)] = new_tag.id
                                 # Delete old tag
                                 del refs[ref]
                             else:
                                 # Just rename the tag
-                                new_ref = b"refs/tags/" + new_tag_name
+                                new_ref = local_tag_name(new_tag_name)
                                 tag_callback(ref, new_ref)
                 elif isinstance(tag_obj, Commit):
                     # Lightweight tag - points directly to a commit
@@ -496,7 +496,7 @@ def filter_refs(
                     if tag_sha in mapping or commit_filter.tag_name_filter is not None:
                         new_tag_name = commit_filter.tag_name_filter(tag_name)
                         if new_tag_name and new_tag_name != tag_name:
-                            new_ref = b"refs/tags/" + new_tag_name
+                            new_ref = local_tag_name(new_tag_name)
                             if tag_sha in mapping:
                                 # Point to rewritten commit
                                 refs[new_ref] = mapping[tag_sha]

+ 4 - 3
dulwich/objectspec.py

@@ -25,6 +25,7 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional, Union
 
 from .objects import Commit, ShaFile, Tag, Tree
+from .refs import local_branch_name, local_tag_name
 from .repo import BaseRepo
 
 if TYPE_CHECKING:
@@ -232,8 +233,8 @@ def parse_ref(
     possible_refs = [
         refspec,
         b"refs/" + refspec,
-        b"refs/tags/" + refspec,
-        b"refs/heads/" + refspec,
+        local_tag_name(refspec),
+        local_branch_name(refspec),
         b"refs/remotes/" + refspec,
         b"refs/remotes/" + refspec + b"/HEAD",
     ]
@@ -282,7 +283,7 @@ def parse_reftuple(
         except KeyError:
             # TODO: check force?
             if b"/" not in rh:
-                rh = b"refs/heads/" + rh
+                rh = local_branch_name(rh)
     return (lh, rh, force)
 
 

+ 25 - 21
dulwich/porcelain.py

@@ -193,6 +193,8 @@ from .refs import (
     SymrefLoop,
     _import_remote_refs,
     filter_ref_prefix,
+    local_branch_name,
+    local_tag_name,
     shorten_ref_name,
 )
 from .repo import BaseRepo, Repo, get_user_identity
@@ -3354,13 +3356,13 @@ def receive_pack(
 def _make_branch_ref(name: Union[str, bytes]) -> Ref:
     if isinstance(name, str):
         name = name.encode(DEFAULT_ENCODING)
-    return LOCAL_BRANCH_PREFIX + name
+    return local_branch_name(name)
 
 
 def _make_tag_ref(name: Union[str, bytes]) -> Ref:
     if isinstance(name, str):
         name = name.encode(DEFAULT_ENCODING)
-    return LOCAL_TAG_PREFIX + name
+    return local_tag_name(name)
 
 
 def branch_delete(
@@ -3406,10 +3408,11 @@ def branch_create(
             if isinstance(objectish, str)
             else objectish
         )
+
         if b"refs/remotes/" + objectish_bytes in r.refs:
             objectish = b"refs/remotes/" + objectish_bytes
-        elif b"refs/heads/" + objectish_bytes in r.refs:
-            objectish = b"refs/heads/" + objectish_bytes
+        elif local_branch_name(objectish_bytes) in r.refs:
+            objectish = local_branch_name(objectish_bytes)
 
         object = parse_object(r, objectish)
         refname = _make_branch_ref(name)
@@ -3441,12 +3444,13 @@ def branch_create(
                 if isinstance(original_objectish, str)
                 else original_objectish
             )
+
             if objectish_bytes in r.refs:
                 objectish_ref = objectish_bytes
             elif b"refs/remotes/" + objectish_bytes in r.refs:
                 objectish_ref = b"refs/remotes/" + objectish_bytes
-            elif b"refs/heads/" + objectish_bytes in r.refs:
-                objectish_ref = b"refs/heads/" + objectish_bytes
+            elif local_branch_name(objectish_bytes) in r.refs:
+                objectish_ref = local_branch_name(objectish_bytes)
         else:
             # HEAD might point to a remote-tracking branch
             head_ref = r.refs.follow(b"HEAD")[0][1]
@@ -3466,7 +3470,7 @@ def branch_create(
                 parts = objectish_ref[len(b"refs/remotes/") :].split(b"/", 1)
                 if len(parts) == 2:
                     remote_name = parts[0]
-                    remote_branch = b"refs/heads/" + parts[1]
+                    remote_branch = local_branch_name(parts[1])
 
                     # Set up tracking
                     repo_config = r.get_config()
@@ -3529,7 +3533,7 @@ def branch_list(repo: RepoPath) -> list[bytes]:
         elif sort_key in ("committerdate", "authordate"):
             # Sort by date
             def get_commit_date(branch_name: bytes) -> int:
-                ref = LOCAL_BRANCH_PREFIX + branch_name
+                ref = local_branch_name(branch_name)
                 sha = r.refs[ref]
                 commit = r.object_store[sha]
                 assert isinstance(commit, Commit)
@@ -4051,16 +4055,16 @@ def show_branch(
                 # Try as full ref name first
                 if branch_bytes in refs:
                     branch_refs[branch_bytes] = refs[branch_bytes]
-                # Try as branch name
-                elif LOCAL_BRANCH_PREFIX + branch_bytes in refs:
-                    branch_refs[LOCAL_BRANCH_PREFIX + branch_bytes] = refs[
-                        LOCAL_BRANCH_PREFIX + branch_bytes
-                    ]
-                # Try as remote branch
-                elif LOCAL_REMOTE_PREFIX + branch_bytes in refs:
-                    branch_refs[LOCAL_REMOTE_PREFIX + branch_bytes] = refs[
-                        LOCAL_REMOTE_PREFIX + branch_bytes
-                    ]
+                else:
+                    # Try as branch name
+                    branch_ref = local_branch_name(branch_bytes)
+                    if branch_ref in refs:
+                        branch_refs[branch_ref] = refs[branch_ref]
+                    # Try as remote branch
+                    elif LOCAL_REMOTE_PREFIX + branch_bytes in refs:
+                        branch_refs[LOCAL_REMOTE_PREFIX + branch_bytes] = refs[
+                            LOCAL_REMOTE_PREFIX + branch_bytes
+                        ]
         else:
             # Default behavior: show local branches
             if all_branches:
@@ -4817,7 +4821,7 @@ def checkout(
             update_head(r, new_branch)
 
             # Set up tracking if creating from a remote branch
-            from .refs import LOCAL_REMOTE_PREFIX, parse_remote_ref
+            from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
 
             if isinstance(original_target, bytes) and target_bytes.startswith(
                 LOCAL_REMOTE_PREFIX
@@ -4826,7 +4830,7 @@ def checkout(
                     remote_name, branch_name = parse_remote_ref(target_bytes)
                     # Set tracking to refs/heads/<branch> on the remote
                     set_branch_tracking(
-                        r, new_branch, remote_name, b"refs/heads/" + branch_name
+                        r, new_branch, remote_name, local_branch_name(branch_name)
                     )
                 except ValueError:
                     # Invalid remote ref format, skip tracking setup
@@ -6575,7 +6579,7 @@ def filter_branch(
             else:
                 # Convert branch name to full ref if needed
                 if not branch.startswith(b"refs/"):
-                    branch = b"refs/heads/" + branch
+                    branch = local_branch_name(branch)
                 refs = [branch]
 
         # Convert subdirectory filter to bytes if needed

+ 2 - 1
dulwich/rebase.py

@@ -33,6 +33,7 @@ from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.objects import Commit
 from dulwich.objectspec import parse_commit
+from dulwich.refs import local_branch_name
 from dulwich.repo import BaseRepo, Repo
 
 
@@ -762,7 +763,7 @@ class Rebaser:
                 self._rebasing_branch = branch
             else:
                 # Assume it's a branch name
-                self._rebasing_branch = b"refs/heads/" + branch
+                self._rebasing_branch = local_branch_name(branch)
         else:
             # Use current branch
             if self._original_head is not None and self._original_head.startswith(

+ 91 - 7
dulwich/refs.py

@@ -1468,6 +1468,90 @@ def is_local_branch(x: bytes) -> bool:
     return x.startswith(LOCAL_BRANCH_PREFIX)
 
 
+def local_branch_name(name: bytes) -> bytes:
+    """Build a full branch ref from a short name.
+
+    Args:
+      name: Short branch name (e.g., b"master") or full ref
+
+    Returns:
+      Full branch ref name (e.g., b"refs/heads/master")
+
+    Examples:
+      >>> local_branch_name(b"master")
+      b'refs/heads/master'
+      >>> local_branch_name(b"refs/heads/master")
+      b'refs/heads/master'
+    """
+    if name.startswith(LOCAL_BRANCH_PREFIX):
+        return name
+    return LOCAL_BRANCH_PREFIX + name
+
+
+def local_tag_name(name: bytes) -> bytes:
+    """Build a full tag ref from a short name.
+
+    Args:
+      name: Short tag name (e.g., b"v1.0") or full ref
+
+    Returns:
+      Full tag ref name (e.g., b"refs/tags/v1.0")
+
+    Examples:
+      >>> local_tag_name(b"v1.0")
+      b'refs/tags/v1.0'
+      >>> local_tag_name(b"refs/tags/v1.0")
+      b'refs/tags/v1.0'
+    """
+    if name.startswith(LOCAL_TAG_PREFIX):
+        return name
+    return LOCAL_TAG_PREFIX + name
+
+
+def extract_branch_name(ref: bytes) -> bytes:
+    """Extract branch name from a full branch ref.
+
+    Args:
+      ref: Full branch ref (e.g., b"refs/heads/master")
+
+    Returns:
+      Short branch name (e.g., b"master")
+
+    Raises:
+      ValueError: If ref is not a local branch
+
+    Examples:
+      >>> extract_branch_name(b"refs/heads/master")
+      b'master'
+      >>> extract_branch_name(b"refs/heads/feature/foo")
+      b'feature/foo'
+    """
+    if not ref.startswith(LOCAL_BRANCH_PREFIX):
+        raise ValueError(f"Not a local branch ref: {ref!r}")
+    return ref[len(LOCAL_BRANCH_PREFIX) :]
+
+
+def extract_tag_name(ref: bytes) -> bytes:
+    """Extract tag name from a full tag ref.
+
+    Args:
+      ref: Full tag ref (e.g., b"refs/tags/v1.0")
+
+    Returns:
+      Short tag name (e.g., b"v1.0")
+
+    Raises:
+      ValueError: If ref is not a local tag
+
+    Examples:
+      >>> extract_tag_name(b"refs/tags/v1.0")
+      b'v1.0'
+    """
+    if not ref.startswith(LOCAL_TAG_PREFIX):
+        raise ValueError(f"Not a local tag ref: {ref!r}")
+    return ref[len(LOCAL_TAG_PREFIX) :]
+
+
 def shorten_ref_name(ref: bytes) -> bytes:
     """Convert a full ref name to its short form.
 
@@ -1527,7 +1611,7 @@ def _set_origin_head(
     origin_base = b"refs/remotes/" + origin + b"/"
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
         origin_ref = origin_base + HEADREF
-        target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+        target_ref = origin_base + extract_branch_name(origin_head)
         if target_ref in refs:
             refs.set_symbolic_ref(origin_ref, target_ref)
 
@@ -1544,17 +1628,17 @@ def _set_default_branch(
     if branch:
         origin_ref = origin_base + branch
         if origin_ref in refs:
-            local_ref = LOCAL_BRANCH_PREFIX + branch
+            local_ref = local_branch_name(branch)
             refs.add_if_new(local_ref, refs[origin_ref], ref_message)
             head_ref = local_ref
-        elif LOCAL_TAG_PREFIX + branch in refs:
-            head_ref = LOCAL_TAG_PREFIX + branch
+        elif local_tag_name(branch) in refs:
+            head_ref = local_tag_name(branch)
         else:
             raise ValueError(f"{os.fsencode(branch)!r} is not a valid branch or tag")
     elif origin_head:
         head_ref = origin_head
         if origin_head.startswith(LOCAL_BRANCH_PREFIX):
-            origin_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+            origin_ref = origin_base + extract_branch_name(origin_head)
         else:
             origin_ref = origin_head
         try:
@@ -1598,7 +1682,7 @@ def _import_remote_refs(
 ) -> None:
     stripped_refs = strip_peeled_refs(refs)
     branches = {
-        n[len(LOCAL_BRANCH_PREFIX) :]: v
+        extract_branch_name(n): v
         for (n, v) in stripped_refs.items()
         if n.startswith(LOCAL_BRANCH_PREFIX) and v is not None
     }
@@ -1609,7 +1693,7 @@ def _import_remote_refs(
         prune=prune,
     )
     tags = {
-        n[len(LOCAL_TAG_PREFIX) :]: v
+        extract_tag_name(n): v
         for (n, v) in stripped_refs.items()
         if n.startswith(LOCAL_TAG_PREFIX)
         and not n.endswith(PEELED_TAG_SUFFIX)

+ 6 - 3
dulwich/repo.py

@@ -104,7 +104,6 @@ from .objects import (
 from .pack import generate_unpacked_objects
 from .refs import (
     ANNOTATED_TAG_SUFFIX,  # noqa: F401
-    LOCAL_BRANCH_PREFIX,
     LOCAL_TAG_PREFIX,  # noqa: F401
     SYMREF,  # noqa: F401
     DictRefsContainer,
@@ -116,7 +115,9 @@ from .refs import (
     _set_head,
     _set_origin_head,
     check_ref_format,  # noqa: F401
+    extract_branch_name,
     is_per_worktree_ref,
+    local_branch_name,
     read_packed_refs,  # noqa: F401
     read_packed_refs_with_peeled,  # noqa: F401
     serialize_refs,
@@ -1762,7 +1763,9 @@ class Repo(BaseRepo):
             else:
                 if head_ref and head_ref.startswith(b"refs/heads/"):
                     # Extract branch name from ref
-                    branch = head_ref[11:].decode("utf-8", errors="replace")
+                    branch = extract_branch_name(head_ref).decode(
+                        "utf-8", errors="replace"
+                    )
                     return match_glob_pattern(branch, pattern)
             return False
 
@@ -1877,7 +1880,7 @@ class Repo(BaseRepo):
                 default_branch = config.get("init", "defaultBranch")
             except KeyError:
                 default_branch = DEFAULT_BRANCH
-        ret.refs.set_symbolic_ref(b"HEAD", LOCAL_BRANCH_PREFIX + default_branch)
+        ret.refs.set_symbolic_ref(b"HEAD", local_branch_name(default_branch))
         ret._init_files(bare=bare, symlinks=symlinks, format=format)
         return ret
 

+ 2 - 3
dulwich/worktree.py

@@ -38,7 +38,7 @@ from typing import Any, Callable, Union
 
 from .errors import CommitError, HookError
 from .objects import Blob, Commit, ObjectID, Tag, Tree
-from .refs import SYMREF, Ref
+from .refs import SYMREF, Ref, local_branch_name
 from .repo import (
     GITDIR,
     WORKTREES,
@@ -921,8 +921,7 @@ def add_worktree(
     if branch is not None:
         if isinstance(branch, str):
             branch = branch.encode()
-        if not branch.startswith(b"refs/heads/"):
-            branch = b"refs/heads/" + branch
+        branch = local_branch_name(branch)
 
     # Check if branch is already checked out in another worktree
     if branch and not force:

+ 64 - 0
tests/test_refs.py

@@ -1237,3 +1237,67 @@ class ShortenRefNameTests(TestCase):
         # Refs that don't match any standard prefix are returned as-is
         self.assertEqual(b"refs/stash", shorten_ref_name(b"refs/stash"))
         self.assertEqual(b"refs/bisect/good", shorten_ref_name(b"refs/bisect/good"))
+
+
+class RefUtilityFunctionsTests(TestCase):
+    """Tests for the new ref utility functions."""
+
+    def test_local_branch_name(self) -> None:
+        """Test local_branch_name function."""
+        from dulwich.refs import local_branch_name
+
+        # Test adding prefix to branch name
+        self.assertEqual(b"refs/heads/master", local_branch_name(b"master"))
+        self.assertEqual(b"refs/heads/develop", local_branch_name(b"develop"))
+        self.assertEqual(
+            b"refs/heads/feature/new-ui", local_branch_name(b"feature/new-ui")
+        )
+
+        # Test idempotency - already has prefix
+        self.assertEqual(b"refs/heads/master", local_branch_name(b"refs/heads/master"))
+
+    def test_local_tag_name(self) -> None:
+        """Test local_tag_name function."""
+        from dulwich.refs import local_tag_name
+
+        # Test adding prefix to tag name
+        self.assertEqual(b"refs/tags/v1.0", local_tag_name(b"v1.0"))
+        self.assertEqual(b"refs/tags/release-2.0", local_tag_name(b"release-2.0"))
+
+        # Test idempotency - already has prefix
+        self.assertEqual(b"refs/tags/v1.0", local_tag_name(b"refs/tags/v1.0"))
+
+    def test_extract_branch_name(self) -> None:
+        """Test extract_branch_name function."""
+        from dulwich.refs import extract_branch_name
+
+        # Test extracting branch name from full ref
+        self.assertEqual(b"master", extract_branch_name(b"refs/heads/master"))
+        self.assertEqual(b"develop", extract_branch_name(b"refs/heads/develop"))
+        self.assertEqual(
+            b"feature/new-ui", extract_branch_name(b"refs/heads/feature/new-ui")
+        )
+
+        # Test error on invalid ref
+        with self.assertRaises(ValueError) as cm:
+            extract_branch_name(b"refs/tags/v1.0")
+        self.assertIn("Not a local branch ref", str(cm.exception))
+
+        with self.assertRaises(ValueError):
+            extract_branch_name(b"master")
+
+    def test_extract_tag_name(self) -> None:
+        """Test extract_tag_name function."""
+        from dulwich.refs import extract_tag_name
+
+        # Test extracting tag name from full ref
+        self.assertEqual(b"v1.0", extract_tag_name(b"refs/tags/v1.0"))
+        self.assertEqual(b"release-2.0", extract_tag_name(b"refs/tags/release-2.0"))
+
+        # Test error on invalid ref
+        with self.assertRaises(ValueError) as cm:
+            extract_tag_name(b"refs/heads/master")
+        self.assertIn("Not a local tag ref", str(cm.exception))
+
+        with self.assertRaises(ValueError):
+            extract_tag_name(b"v1.0")