Explorar el Código

refs: handle per-worktree and shared refs on read and on reflog writes (#1749)

This commit adds proper support for reading and iterating over
references from the shared and per-worktree "refs" directories,
depending on the reference name.

It also updates reflog reading and writing to use the correct directory
based on the reference name.

Fixes https://github.com/iterative/dvc/issues/10821.
Related documentation: https://git-scm.com/docs/git-worktree#_refs

### TODO

- [x] Add tests
Jelmer Vernooij hace 5 meses
padre
commit
1658e0597c
Se han modificado 4 ficheros con 277 adiciones y 31 borrados
  1. 69 29
      dulwich/refs.py
  2. 10 2
      dulwich/repo.py
  3. 188 0
      tests/test_refs.py
  4. 10 0
      tests/test_stash.py

+ 69 - 29
dulwich/refs.py

@@ -831,21 +831,55 @@ class DiskRefsContainer(RefsContainer):
         """Return string representation of DiskRefsContainer."""
         return f"{self.__class__.__name__}({self.path!r})"
 
-    def subkeys(self, base: bytes) -> set[bytes]:
-        """Return subkeys under a given base reference path."""
-        subkeys = set()
-        path = self.refpath(base)
-        for root, unused_dirs, files in os.walk(path):
-            directory = root[len(path) :]
+    def _iter_dir(
+        self,
+        path: bytes,
+        base: bytes,
+        dir_filter: Optional[Callable[[bytes], bool]] = None,
+    ) -> Iterator[bytes]:
+        refspath = os.path.join(path, base.rstrip(b"/"))
+        prefix_len = len(os.path.join(path, b""))
+
+        for root, dirs, files in os.walk(refspath):
+            directory = root[prefix_len:]
             if os.path.sep != "/":
                 directory = directory.replace(os.fsencode(os.path.sep), b"/")
-            directory = directory.strip(b"/")
+            if dir_filter is not None:
+                dirs[:] = [
+                    d for d in dirs if dir_filter(b"/".join([directory, d, b""]))
+                ]
+
             for filename in files:
-                refname = b"/".join(([directory] if directory else []) + [filename])
-                # check_ref_format requires at least one /, so we prepend the
-                # base before calling it.
-                if check_ref_format(base + b"/" + refname):
-                    subkeys.add(refname)
+                refname = b"/".join([directory, filename])
+                if check_ref_format(refname):
+                    yield refname
+
+    def _iter_loose_refs(self, base: bytes = b"refs/") -> Iterator[bytes]:
+        base = base.rstrip(b"/") + b"/"
+        search_paths: list[tuple[bytes, Optional[Callable[[bytes], bool]]]] = []
+        if base != b"refs/":
+            path = self.worktree_path if is_per_worktree_ref(base) else self.path
+            search_paths.append((path, None))
+        elif self.worktree_path == self.path:
+            # Iterate through all the refs from the main worktree
+            search_paths.append((self.path, None))
+        else:
+            # Iterate through all the shared refs from the commondir, excluding per-worktree refs
+            search_paths.append((self.path, lambda r: not is_per_worktree_ref(r)))
+            # Iterate through all the per-worktree refs from the worktree's gitdir
+            search_paths.append((self.worktree_path, is_per_worktree_ref))
+
+        for path, dir_filter in search_paths:
+            yield from self._iter_dir(path, base, dir_filter=dir_filter)
+
+    def subkeys(self, base: bytes) -> set[bytes]:
+        """Return subkeys under a given base reference path."""
+        subkeys = set()
+
+        for key in self._iter_loose_refs(base):
+            if key.startswith(base):
+                subkeys.add(key[len(base) :].strip(b"/"))
+
         for key in self.get_packed_refs():
             if key.startswith(base):
                 subkeys.add(key[len(base) :].strip(b"/"))
@@ -856,29 +890,19 @@ class DiskRefsContainer(RefsContainer):
         allkeys = set()
         if os.path.exists(self.refpath(HEADREF)):
             allkeys.add(HEADREF)
-        path = self.refpath(b"")
-        refspath = self.refpath(b"refs")
-        for root, unused_dirs, files in os.walk(refspath):
-            directory = root[len(path) :]
-            if os.path.sep != "/":
-                directory = directory.replace(os.fsencode(os.path.sep), b"/")
-            for filename in files:
-                refname = b"/".join([directory, filename])
-                if check_ref_format(refname):
-                    allkeys.add(refname)
+
+        allkeys.update(self._iter_loose_refs())
         allkeys.update(self.get_packed_refs())
         return allkeys
 
     def refpath(self, name: bytes) -> bytes:
         """Return the disk path of a ref."""
+        path = name
         if os.path.sep != "/":
-            name = name.replace(b"/", os.fsencode(os.path.sep))
-        # TODO: as the 'HEAD' reference is working tree specific, it
-        # should actually not be a part of RefsContainer
-        if name == HEADREF:
-            return os.path.join(self.worktree_path, name)
-        else:
-            return os.path.join(self.path, name)
+            path = path.replace(b"/", os.fsencode(os.path.sep))
+
+        root_dir = self.worktree_path if is_per_worktree_ref(name) else self.path
+        return os.path.join(root_dir, path)
 
     def get_packed_refs(self) -> dict[bytes, bytes]:
         """Get contents of the packed-refs file.
@@ -1741,3 +1765,19 @@ def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
     """
     filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
     return cast(T, filtered)
+
+
+def is_per_worktree_ref(ref: bytes) -> bool:
+    """Returns whether a reference is stored per worktree or not.
+
+    Per-worktree references are:
+    - all pseudorefs, e.g. HEAD
+    - all references stored inside "refs/bisect/", "refs/worktree/" and "refs/rewritten/"
+
+    All refs starting with "refs/" are shared, except for the ones listed above.
+
+    See https://git-scm.com/docs/git-worktree#_refs.
+    """
+    return not ref.startswith(b"refs/") or ref.startswith(
+        (b"refs/bisect/", b"refs/worktree/", b"refs/rewritten/")
+    )

+ 10 - 2
dulwich/repo.py

@@ -112,6 +112,7 @@ from .refs import (
     _set_head,
     _set_origin_head,
     check_ref_format,  # noqa: F401
+    is_per_worktree_ref,
     read_packed_refs,  # noqa: F401
     read_packed_refs_with_peeled,  # noqa: F401
     serialize_refs,
@@ -1315,7 +1316,7 @@ class Repo(BaseRepo):
     ) -> None:
         from .reflog import format_reflog_line
 
-        path = os.path.join(self.controldir(), "logs", os.fsdecode(ref))
+        path = self._reflog_path(ref)
         try:
             os.makedirs(os.path.dirname(path))
         except FileExistsError:
@@ -1336,6 +1337,13 @@ class Repo(BaseRepo):
                 + b"\n"
             )
 
+    def _reflog_path(self, ref: bytes) -> str:
+        if ref.startswith((b"main-worktree/", b"worktrees/")):
+            raise NotImplementedError(f"refs {ref.decode()} are not supported")
+
+        base = self.controldir() if is_per_worktree_ref(ref) else self.commondir()
+        return os.path.join(base, "logs", os.fsdecode(ref))
+
     def read_reflog(self, ref):
         """Read reflog entries for a reference.
 
@@ -1347,7 +1355,7 @@ class Repo(BaseRepo):
         """
         from .reflog import read_reflog
 
-        path = os.path.join(self.controldir(), "logs", os.fsdecode(ref))
+        path = self._reflog_path(ref)
         try:
             with open(path, "rb") as f:
                 yield from read_reflog(f)

+ 188 - 0
tests/test_refs.py

@@ -36,6 +36,7 @@ from dulwich.refs import (
     SymrefLoop,
     _split_ref_line,
     check_ref_format,
+    is_per_worktree_ref,
     parse_remote_ref,
     parse_symref_value,
     read_packed_refs,
@@ -46,6 +47,7 @@ from dulwich.refs import (
 )
 from dulwich.repo import Repo
 from dulwich.tests.utils import open_repo, tear_down_repo
+from dulwich.worktree import add_worktree
 
 from . import SkipTest, TestCase
 
@@ -838,6 +840,192 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
         self._refs[ref_name] = ref_value
 
 
+class IsPerWorktreeRefsTests(TestCase):
+    def test(self) -> None:
+        cases = [
+            (b"HEAD", True),
+            (b"refs/bisect/good", True),
+            (b"refs/worktree/foo", True),
+            (b"refs/rewritten/onto", True),
+            (b"refs/stash", False),
+            (b"refs/heads/main", False),
+            (b"refs/tags/v1.0", False),
+            (b"refs/remotes/origin/main", False),
+            (b"refs/custom/foo", False),
+            (b"refs/replace/aaaaaa", False),
+        ]
+        for ref, expected in cases:
+            with self.subTest(ref=ref, expected=expected):
+                self.assertEqual(is_per_worktree_ref(ref), expected)
+
+
+class DiskRefsContainerWorktreeRefsTest(TestCase):
+    def setUp(self) -> None:
+        # Create temporary directories
+        temp_dir = tempfile.mkdtemp()
+        test_dir = os.path.join(temp_dir, "main")
+        os.makedirs(test_dir)
+
+        repo = Repo.init(test_dir, default_branch=b"main")
+        main_worktree = repo.get_worktree()
+        with open(os.path.join(test_dir, "test.txt"), "wb") as f:
+            f.write(b"test content")
+        main_worktree.stage(["test.txt"])
+        self.first_commit = main_worktree.commit(message=b"Initial commit")
+
+        worktree_dir = os.path.join(temp_dir, "worktree")
+        wt_repo = add_worktree(repo, worktree_dir, branch="wt-main")
+        linked_worktree = wt_repo.get_worktree()
+        with open(os.path.join(test_dir, "test2.txt"), "wb") as f:
+            f.write(b"test content")
+        linked_worktree.stage(["test2.txt"])
+        self.second_commit = linked_worktree.commit(message=b"second commit")
+
+        self.refs = repo.refs
+        self.wt_refs = wt_repo.refs
+
+    def test_refpath(self) -> None:
+        main_path = self.refs.path
+        common = self.wt_refs.path
+        wt_path = self.wt_refs.worktree_path
+
+        cases = [
+            (self.refs, b"HEAD", main_path),
+            (self.refs, b"refs/heads/main", main_path),
+            (self.refs, b"refs/heads/wt-main", main_path),
+            (self.refs, b"refs/worktree/foo", main_path),
+            (self.refs, b"refs/bisect/good", main_path),
+            (self.wt_refs, b"HEAD", wt_path),
+            (self.wt_refs, b"refs/heads/main", common),
+            (self.wt_refs, b"refs/heads/wt-main", common),
+            (self.wt_refs, b"refs/worktree/foo", wt_path),
+            (self.wt_refs, b"refs/bisect/good", wt_path),
+        ]
+
+        for refs, refname, git_dir in cases:
+            with self.subTest(refs=refs, refname=refname, git_dir=git_dir):
+                refpath = refs.refpath(refname)
+                expected_path = os.path.join(
+                    git_dir, refname.replace(b"/", os.fsencode(os.sep))
+                )
+                self.assertEqual(refpath, expected_path)
+
+    def test_shared_ref(self) -> None:
+        self.assertEqual(self.refs[b"refs/heads/main"], self.first_commit)
+        self.assertEqual(self.refs[b"refs/heads/wt-main"], self.second_commit)
+        self.assertEqual(self.wt_refs[b"refs/heads/main"], self.first_commit)
+        self.assertEqual(self.wt_refs[b"refs/heads/wt-main"], self.second_commit)
+
+        expected = {b"HEAD", b"refs/heads/main", b"refs/heads/wt-main"}
+        self.assertEqual(expected, self.refs.keys())
+        self.assertEqual(expected, self.wt_refs.keys())
+
+        self.assertEqual({b"main", b"wt-main"}, set(self.refs.keys(b"refs/heads/")))
+        self.assertEqual({b"main", b"wt-main"}, set(self.wt_refs.keys(b"refs/heads/")))
+
+        ref_path = os.path.join(self.refs.path, b"refs", b"heads", b"main")
+        self.assertTrue(os.path.exists(ref_path))
+
+        ref_path = os.path.join(self.wt_refs.worktree_path, b"refs", b"heads", b"main")
+        self.assertFalse(os.path.exists(ref_path))
+
+    def test_per_worktree_ref(self) -> None:
+        path = self.refs.path
+        wt_path = self.wt_refs.worktree_path
+
+        self.assertEqual(self.refs[b"HEAD"], self.first_commit)
+        self.assertEqual(self.wt_refs[b"HEAD"], self.second_commit)
+
+        self.refs[b"refs/bisect/good"] = self.first_commit
+        self.wt_refs[b"refs/bisect/good"] = self.second_commit
+
+        self.refs[b"refs/bisect/start"] = self.first_commit
+        self.wt_refs[b"refs/bisect/bad"] = self.second_commit
+
+        self.assertEqual(self.refs[b"refs/bisect/good"], self.first_commit)
+        self.assertEqual(self.wt_refs[b"refs/bisect/good"], self.second_commit)
+
+        self.assertTrue(os.path.exists(os.path.join(path, b"refs", b"bisect", b"good")))
+        self.assertTrue(
+            os.path.exists(os.path.join(wt_path, b"refs", b"bisect", b"good"))
+        )
+
+        self.assertEqual(self.refs[b"refs/bisect/start"], self.first_commit)
+        with self.assertRaises(KeyError):
+            self.wt_refs[b"refs/bisect/start"]
+        self.assertTrue(
+            os.path.exists(os.path.join(path, b"refs", b"bisect", b"start"))
+        )
+        self.assertFalse(
+            os.path.exists(os.path.join(wt_path, b"refs", b"bisect", b"start"))
+        )
+
+        with self.assertRaises(KeyError):
+            self.refs[b"refs/bisect/bad"]
+        self.assertEqual(self.wt_refs[b"refs/bisect/bad"], self.second_commit)
+        self.assertFalse(os.path.exists(os.path.join(path, b"refs", b"bisect", b"bad")))
+        self.assertTrue(
+            os.path.exists(os.path.join(wt_path, b"refs", b"bisect", b"bad"))
+        )
+
+        expected_refs = {
+            b"HEAD",
+            b"refs/heads/main",
+            b"refs/heads/wt-main",
+            b"refs/bisect/good",
+            b"refs/bisect/start",
+        }
+        self.assertEqual(self.refs.keys(), expected_refs)
+        self.assertEqual({b"good", b"start"}, self.refs.keys(b"refs/bisect/"))
+
+        expected_wt_refs = {
+            b"HEAD",
+            b"refs/heads/main",
+            b"refs/heads/wt-main",
+            b"refs/bisect/good",
+            b"refs/bisect/bad",
+        }
+        self.assertEqual(self.wt_refs.keys(), expected_wt_refs)
+        self.assertEqual({b"good", b"bad"}, self.wt_refs.keys(b"refs/bisect/"))
+
+    def test_delete_per_worktree_ref(self) -> None:
+        self.refs[b"refs/worktree/foo"] = self.first_commit
+        self.wt_refs[b"refs/worktree/foo"] = self.second_commit
+
+        del self.wt_refs[b"refs/worktree/foo"]
+        with self.assertRaises(KeyError):
+            self.wt_refs[b"refs/worktree/foo"]
+
+        del self.refs[b"refs/worktree/foo"]
+        with self.assertRaises(KeyError):
+            self.refs[b"refs/worktree/foo"]
+
+    def test_delete_shared_ref(self) -> None:
+        self.refs[b"refs/heads/branch"] = self.first_commit
+
+        del self.wt_refs[b"refs/heads/branch"]
+
+        with self.assertRaises(KeyError):
+            self.wt_refs[b"refs/heads/branch"]
+        with self.assertRaises(KeyError):
+            self.refs[b"refs/heads/branch"]
+
+    def test_contains_shared_ref(self):
+        self.assertIn(b"refs/heads/main", self.refs)
+        self.assertIn(b"refs/heads/main", self.wt_refs)
+        self.assertIn(b"refs/heads/wt-main", self.refs)
+        self.assertIn(b"refs/heads/wt-main", self.wt_refs)
+
+    def test_contains_per_worktree_ref(self):
+        self.refs[b"refs/worktree/foo"] = self.first_commit
+        self.wt_refs[b"refs/worktree/bar"] = self.second_commit
+
+        self.assertIn(b"refs/worktree/foo", self.refs)
+        self.assertNotIn(b"refs/worktree/bar", self.refs)
+        self.assertNotIn(b"refs/worktree/foo", self.wt_refs)
+        self.assertIn(b"refs/worktree/bar", self.wt_refs)
+
+
 _TEST_REFS_SERIALIZED = (
     b"42d06bd4b77fed026b154d16493e5deab78f02ec\t"
     b"refs/heads/40-char-ref-aaaaaaaaaaaaaaaaaa\n"

+ 10 - 0
tests/test_stash.py

@@ -28,6 +28,7 @@ import tempfile
 from dulwich.objects import Blob, Tree
 from dulwich.repo import Repo
 from dulwich.stash import DEFAULT_STASH_REF, Stash
+from dulwich.worktree import add_worktree
 
 from . import TestCase
 
@@ -221,3 +222,12 @@ class StashTests(TestCase):
         # Verify index has the staged changes
         index = self.repo.open_index()
         self.assertIn(b"new.txt", index)
+
+
+class StashInWorktreeTest(StashTests):
+    """Tests for stash in a worktree."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.repo_dir = os.path.join(self.test_dir, "wt")
+        self.repo = add_worktree(self.repo, self.repo_dir, "worktree")