Bladeren bron

Return symrefs from ls_refs

Fixes #863
Jelmer Vernooij 1 maand geleden
bovenliggende
commit
e733f26507
8 gewijzigde bestanden met toevoegingen van 165 en 19 verwijderingen
  1. 2 0
      NEWS
  2. 13 3
      dulwich/cli.py
  3. 98 8
      dulwich/client.py
  4. 1 1
      dulwich/porcelain.py
  5. 3 3
      tests/compat/test_client.py
  6. 37 0
      tests/test_cli.py
  7. 4 2
      tests/test_client.py
  8. 7 2
      tests/test_porcelain.py

+ 2 - 0
NEWS

@@ -1,5 +1,7 @@
 0.23.1	UNRELEASED
 
+ * Return symrefs from ls_refs. (Jelmer Vernooij, #863)
+
  * Support short commit hashes in ``porcelain.reset()``.
    (Jelmer Vernooij, #1154)
 

+ 13 - 3
dulwich/cli.py

@@ -675,11 +675,21 @@ class cmd_status(Command):
 class cmd_ls_remote(Command):
     def run(self, args) -> None:
         parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--symref", action="store_true", help="Show symbolic references"
+        )
         parser.add_argument("url", help="Remote URL to list references from")
         args = parser.parse_args(args)
-        refs = porcelain.ls_remote(args.url)
-        for ref in sorted(refs):
-            sys.stdout.write(f"{ref}\t{refs[ref]}\n")
+        result = porcelain.ls_remote(args.url)
+
+        if args.symref:
+            # Show symrefs first, like git does
+            for ref, target in sorted(result.symrefs.items()):
+                sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
+
+        # Show regular refs
+        for ref in sorted(result.refs):
+            sys.stdout.write(f"{result.refs[ref].decode()}\t{ref.decode()}\n")
 
 
 class cmd_ls_tree(Command):

+ 98 - 8
dulwich/client.py

@@ -121,6 +121,7 @@ from .protocol import (
 )
 from .refs import (
     PEELED_TAG_SUFFIX,
+    SYMREF,
     Ref,
     _import_remote_refs,
     _set_default_branch,
@@ -395,6 +396,77 @@ class FetchPackResult:
         return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r}, {self.agent!r})"
 
 
+class LsRemoteResult:
+    """Result of a ls-remote operation.
+
+    Attributes:
+      refs: Dictionary with all remote refs
+      symrefs: Dictionary with remote symrefs
+    """
+
+    _FORWARDED_ATTRS: ClassVar[set[str]] = {
+        "clear",
+        "copy",
+        "fromkeys",
+        "get",
+        "items",
+        "keys",
+        "pop",
+        "popitem",
+        "setdefault",
+        "update",
+        "values",
+        "viewitems",
+        "viewkeys",
+        "viewvalues",
+    }
+
+    def __init__(self, refs, symrefs) -> None:
+        self.refs = refs
+        self.symrefs = symrefs
+
+    def _warn_deprecated(self) -> None:
+        import warnings
+
+        warnings.warn(
+            "Treating LsRemoteResult as a dictionary is deprecated. "
+            "Use result.refs instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, dict):
+            self._warn_deprecated()
+            return self.refs == other
+        return self.refs == other.refs and self.symrefs == other.symrefs
+
+    def __contains__(self, name) -> bool:
+        self._warn_deprecated()
+        return name in self.refs
+
+    def __getitem__(self, name):
+        self._warn_deprecated()
+        return self.refs[name]
+
+    def __len__(self) -> int:
+        self._warn_deprecated()
+        return len(self.refs)
+
+    def __iter__(self):
+        self._warn_deprecated()
+        return iter(self.refs)
+
+    def __getattribute__(self, name):
+        if name in type(self)._FORWARDED_ATTRS:
+            self._warn_deprecated()
+            return getattr(self.refs, name)
+        return super().__getattribute__(name)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r})"
+
+
 class SendPackResult:
     """Result of a upload-pack operation.
 
@@ -1041,11 +1113,16 @@ class GitClient:
         path,
         protocol_version: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
-    ) -> dict[Ref, ObjectID]:
+    ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server.
 
         Args:
           path: Path to the repo to fetch from. (as bytestring)
+          protocol_version: Desired Git protocol version.
+          ref_prefix: Prefix filter for refs.
+
+        Returns:
+          LsRemoteResult object with refs and symrefs
         """
         raise NotImplementedError(self.get_refs)
 
@@ -1484,13 +1561,13 @@ class TraditionalGitClient(GitClient):
             proto.write_pkt_line(None)
             with proto:
                 try:
-                    refs, _symrefs, peeled = read_pkt_refs_v2(proto.read_pkt_seq())
+                    refs, symrefs, peeled = read_pkt_refs_v2(proto.read_pkt_seq())
                 except HangupException as exc:
                     raise _remote_error_from_stderr(stderr) from exc
                 proto.write_pkt_line(None)
                 for refname, refvalue in peeled.items():
                     refs[refname + PEELED_TAG_SUFFIX] = refvalue
-                return refs
+                return LsRemoteResult(refs, symrefs)
         else:
             with proto:
                 try:
@@ -1498,10 +1575,10 @@ class TraditionalGitClient(GitClient):
                 except HangupException as exc:
                     raise _remote_error_from_stderr(stderr) from exc
                 proto.write_pkt_line(None)
-                (_symrefs, _agent) = _extract_symrefs_and_agent(server_capabilities)
+                (symrefs, _agent) = _extract_symrefs_and_agent(server_capabilities)
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
-                return refs
+                return LsRemoteResult(refs, symrefs)
 
     def archive(
         self,
@@ -1932,7 +2009,20 @@ class LocalGitClient(GitClient):
     ):
         """Retrieve the current refs from a local on-disk repository."""
         with self._open_repo(path) as target:
-            return target.get_refs()
+            refs = target.get_refs()
+            # Extract symrefs from the local repository
+            symrefs = {}
+            for ref in refs:
+                try:
+                    # Check if this ref is symbolic by reading it directly
+                    ref_value = target.refs.read_ref(ref)
+                    if ref_value and ref_value.startswith(SYMREF):
+                        # Extract the target from the symref
+                        symrefs[ref] = ref_value[len(SYMREF) :]
+                except (KeyError, ValueError):
+                    # Not a symbolic ref or error reading it
+                    pass
+            return LsRemoteResult(refs, symrefs)
 
 
 # What Git client to use for local access
@@ -2799,7 +2889,7 @@ class AbstractHttpGitClient(GitClient):
     ):
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
-        refs, _, _, _, peeled = self._discover_references(
+        refs, _, _, symrefs, peeled = self._discover_references(
             b"git-upload-pack",
             url,
             protocol_version=protocol_version,
@@ -2807,7 +2897,7 @@ class AbstractHttpGitClient(GitClient):
         )
         for refname, refvalue in peeled.items():
             refs[refname + PEELED_TAG_SUFFIX] = refvalue
-        return refs
+        return LsRemoteResult(refs, symrefs)
 
     def get_url(self, path):
         return self._get_url(path).rstrip("/")

+ 1 - 1
dulwich/porcelain.py

@@ -2110,7 +2110,7 @@ def ls_remote(remote, config: Optional[Config] = None, **kwargs):
       remote: Remote repository location
       config: Configuration to use
     Returns:
-      Dictionary with remote refs
+      LsRemoteResult object with refs and symrefs
     """
     if config is None:
         config = StackedConfig.default()

+ 3 - 3
tests/compat/test_client.py

@@ -269,7 +269,7 @@ class DulwichClientTestBase:
                 b"refs/tags/v1.0",
                 b"refs/tags/v1.0^{}",
             ],
-            sorted(refs.keys()),
+            sorted(refs.refs.keys()),
         )
 
     def test_get_refs_with_ref_prefix(self) -> None:
@@ -282,7 +282,7 @@ class DulwichClientTestBase:
                 b"refs/heads/branch",
                 b"refs/heads/master",
             ],
-            sorted(refs.keys()),
+            sorted(refs.refs.keys()),
         )
 
     def test_fetch_pack_depth(self) -> None:
@@ -405,7 +405,7 @@ class DulwichClientTestBase:
 
         repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as dest:
-            self.assertDictEqual(dest.refs.as_dict(), refs)
+            self.assertDictEqual(dest.refs.as_dict(), refs.refs)
 
 
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):

+ 37 - 0
tests/test_cli.py

@@ -324,6 +324,43 @@ class FetchPackCommandTest(DulwichCliTestCase):
         mock_client.fetch.assert_called_once()
 
 
+class LsRemoteCommandTest(DulwichCliTestCase):
+    """Tests for ls-remote command."""
+
+    def test_ls_remote_basic(self):
+        # Create a commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial")
+
+        # Test basic ls-remote
+        result, stdout, stderr = self._run_cli("ls-remote", self.repo_path)
+        lines = stdout.strip().split("\n")
+        self.assertTrue(any("HEAD" in line for line in lines))
+        self.assertTrue(any("refs/heads/master" in line for line in lines))
+
+    def test_ls_remote_symref(self):
+        # Create a commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial")
+
+        # Test ls-remote with --symref option
+        result, stdout, stderr = self._run_cli("ls-remote", "--symref", self.repo_path)
+        lines = stdout.strip().split("\n")
+        # Should show symref for HEAD in exact format: "ref: refs/heads/master\tHEAD"
+        expected_line = "ref: refs/heads/master\tHEAD"
+        self.assertIn(
+            expected_line,
+            lines,
+            f"Expected line '{expected_line}' not found in output: {lines}",
+        )
+
+
 class PullCommandTest(DulwichCliTestCase):
     """Tests for pull command."""
 

+ 4 - 2
tests/test_client.py

@@ -993,8 +993,10 @@ class LocalGitClientTests(TestCase):
         self.addCleanup(tear_down_repo, local)
 
         client = LocalGitClient()
-        refs = client.get_refs(local.path)
-        self.assertDictEqual(local.refs.as_dict(), refs)
+        result = client.get_refs(local.path)
+        self.assertDictEqual(local.refs.as_dict(), result.refs)
+        # Check that symrefs are detected correctly
+        self.assertIn(b"HEAD", result.symrefs)
 
     def send_and_verify(self, branch, local, target) -> None:
         """Send branch from local to remote repository and verify it worked."""

+ 7 - 2
tests/test_porcelain.py

@@ -4357,7 +4357,9 @@ class LsTreeTests(PorcelainTestCase):
 
 class LsRemoteTests(PorcelainTestCase):
     def test_empty(self) -> None:
-        self.assertEqual({}, porcelain.ls_remote(self.repo.path))
+        result = porcelain.ls_remote(self.repo.path)
+        self.assertEqual({}, result.refs)
+        self.assertEqual({}, result.symrefs)
 
     def test_some(self) -> None:
         cid = porcelain.commit(
@@ -4367,10 +4369,13 @@ class LsRemoteTests(PorcelainTestCase):
             committer=b"committer <email>",
         )
 
+        result = porcelain.ls_remote(self.repo.path)
         self.assertEqual(
             {b"refs/heads/master": cid, b"HEAD": cid},
-            porcelain.ls_remote(self.repo.path),
+            result.refs,
         )
+        # HEAD should be a symref to refs/heads/master
+        self.assertEqual({b"HEAD": b"refs/heads/master"}, result.symrefs)
 
 
 class LsFilesTests(PorcelainTestCase):