Browse Source

Merge remote-tracking branch 'dtrifiro/add-get-refs-patterns'

Jelmer Vernooij 1 year ago
parent
commit
e1f37c3e0a
3 changed files with 140 additions and 1 deletions
  1. 10 0
      dulwich/cli.py
  2. 61 1
      dulwich/porcelain.py
  3. 69 0
      dulwich/tests/test_porcelain.py

+ 10 - 0
dulwich/cli.py

@@ -133,6 +133,15 @@ class cmd_fetch(Command):
             print("{} -> {}".format(*item))
             print("{} -> {}".format(*item))
 
 
 
 
+class cmd_for_each_ref(Command):
+    def run(self, args):
+        parser = argparse.ArgumentParser()
+        parser.add_argument("pattern", type=str, nargs="?")
+        args = parser.parse_args(args)
+        for sha, object_type, ref in porcelain.for_each_ref(".", args.pattern):
+            print(f"{sha.decode()} {object_type.decode()}\t{ref.decode()}")
+
+
 class cmd_fsck(Command):
 class cmd_fsck(Command):
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
@@ -765,6 +774,7 @@ commands = {
     "dump-index": cmd_dump_index,
     "dump-index": cmd_dump_index,
     "fetch-pack": cmd_fetch_pack,
     "fetch-pack": cmd_fetch_pack,
     "fetch": cmd_fetch,
     "fetch": cmd_fetch,
+    "for-each-ref": cmd_for_each_ref,
     "fsck": cmd_fsck,
     "fsck": cmd_fsck,
     "help": cmd_help,
     "help": cmd_help,
     "init": cmd_init,
     "init": cmd_init,

+ 61 - 1
dulwich/porcelain.py

@@ -33,6 +33,7 @@ Currently implemented:
  * describe
  * describe
  * diff-tree
  * diff-tree
  * fetch
  * fetch
+ * for-each-ref
  * init
  * init
  * ls-files
  * ls-files
  * ls-remote
  * ls-remote
@@ -64,6 +65,7 @@ Functions should generally accept both unicode strings and bytestrings
 """
 """
 
 
 import datetime
 import datetime
+import fnmatch
 import os
 import os
 import posixpath
 import posixpath
 import stat
 import stat
@@ -73,7 +75,7 @@ from collections import namedtuple
 from contextlib import closing, contextmanager
 from contextlib import closing, contextmanager
 from io import BytesIO, RawIOBase
 from io import BytesIO, RawIOBase
 from pathlib import Path
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 
 from .archive import tar_stream
 from .archive import tar_stream
 from .client import get_transport_and_path
 from .client import get_transport_and_path
@@ -1700,6 +1702,64 @@ def fetch(
     return fetch_result
     return fetch_result
 
 
 
 
+def for_each_ref(
+    repo: Union[Repo, str] = ".",
+    pattern: Optional[Union[str, bytes]] = None,
+    **kwargs,
+) -> List[Tuple[bytes, bytes, bytes]]:
+    """Iterate over all refs that match the (optional) pattern.
+
+    Args:
+      repo: Path to the repository
+      pattern: Optional glob (7) patterns to filter the refs with
+    Returns:
+      List of bytes tuples with: (sha, object_type, ref_name)
+    """
+    if kwargs:
+        raise NotImplementedError(f"{''.join(kwargs.keys())}")
+
+    if isinstance(pattern, str):
+        pattern = os.fsencode(pattern)
+
+    with open_repo_closing(repo) as r:
+        refs = r.get_refs()
+
+    if pattern:
+        matching_refs: Dict[bytes, bytes] = {}
+        pattern_parts = pattern.split(b"/")
+        for ref, sha in refs.items():
+            matches = False
+
+            # git for-each-ref uses glob (7) style patterns, but fnmatch
+            # is greedy and also matches slashes, unlike glob.glob.
+            # We have to check parts of the pattern individually.
+            # See https://github.com/python/cpython/issues/72904
+            ref_parts = ref.split(b"/")
+            if len(ref_parts) > len(pattern_parts):
+                continue
+
+            for pat, ref_part in zip(pattern_parts, ref_parts):
+                matches = fnmatch.fnmatchcase(ref_part, pat)
+                if not matches:
+                    break
+
+            if matches:
+                matching_refs[ref] = sha
+
+        refs = matching_refs
+
+    ret: List[Tuple[bytes, bytes, bytes]] = [
+        (sha, r.get_object(sha).type_name, ref)
+        for ref, sha in sorted(
+            refs.items(),
+            key=lambda ref_sha: ref_sha[0],
+        )
+        if ref != b"HEAD"
+    ]
+
+    return ret
+
+
 def ls_remote(remote, config: Optional[Config] = None, **kwargs):
 def ls_remote(remote, config: Optional[Config] = None, **kwargs):
     """List the refs in a remote.
     """List the refs in a remote.
 
 

+ 69 - 0
dulwich/tests/test_porcelain.py

@@ -3596,3 +3596,72 @@ class ServerTests(PorcelainTestCase):
 
 
         with self._serving() as url:
         with self._serving() as url:
             porcelain.push(self.repo, url, "master")
             porcelain.push(self.repo, url, "master")
+
+
+class ForEachTests(PorcelainTestCase):
+    def setUp(self):
+        super().setUp()
+        c1, c2, c3, c4 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2], [4]]
+        )
+        porcelain.tag_create(
+            self.repo.path,
+            b"v0.1",
+            objectish=c1.id,
+            annotated=True,
+            message=b"0.1",
+        )
+        porcelain.tag_create(
+            self.repo.path,
+            b"v1.0",
+            objectish=c2.id,
+            annotated=True,
+            message=b"1.0",
+        )
+        porcelain.tag_create(self.repo.path, b"simple-tag", objectish=c3.id)
+        porcelain.tag_create(
+            self.repo.path,
+            b"v1.1",
+            objectish=c4.id,
+            annotated=True,
+            message=b"1.1",
+        )
+        porcelain.branch_create(
+            self.repo.path, b"feat", objectish=c2.id.decode("ascii")
+        )
+        self.repo.refs[b"HEAD"] = c4.id
+
+    def test_for_each_ref(self):
+        refs = porcelain.for_each_ref(self.repo)
+
+        self.assertEqual(
+            [(object_type, tag) for _, object_type, tag in refs],
+            [
+                (b"commit", b"refs/heads/feat"),
+                (b"commit", b"refs/heads/master"),
+                (b"commit", b"refs/tags/simple-tag"),
+                (b"tag", b"refs/tags/v0.1"),
+                (b"tag", b"refs/tags/v1.0"),
+                (b"tag", b"refs/tags/v1.1"),
+            ],
+        )
+
+    def test_for_each_ref_pattern(self):
+        versions = porcelain.for_each_ref(self.repo, pattern="refs/tags/v*")
+        self.assertEqual(
+            [(object_type, tag) for _, object_type, tag in versions],
+            [
+                (b"tag", b"refs/tags/v0.1"),
+                (b"tag", b"refs/tags/v1.0"),
+                (b"tag", b"refs/tags/v1.1"),
+            ],
+        )
+
+        versions = porcelain.for_each_ref(self.repo, pattern="refs/tags/v1.?")
+        self.assertEqual(
+            [(object_type, tag) for _, object_type, tag in versions],
+            [
+                (b"tag", b"refs/tags/v1.0"),
+                (b"tag", b"refs/tags/v1.1"),
+            ],
+        )