Browse Source

Merge branch 'clone-branch' of git://github.com/pmrowla/dulwich

Jelmer Vernooij 3 năm trước cách đây
mục cha
commit
3d3fcbc2f9
5 tập tin đã thay đổi với 312 bổ sung86 xóa
  1. 187 0
      dulwich/clone.py
  2. 30 38
      dulwich/porcelain.py
  3. 40 44
      dulwich/repo.py
  4. 12 4
      dulwich/tests/test_porcelain.py
  5. 43 0
      dulwich/tests/test_repository.py

+ 187 - 0
dulwich/clone.py

@@ -0,0 +1,187 @@
+# clone.py
+# Copyright (C) 2021 Jelmer Vernooij <jelmer@samba.org>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Repository clone handling."""
+
+import os
+import shutil
+from typing import TYPE_CHECKING, Callable, Tuple
+
+from dulwich.objects import (
+    Tag,
+)
+from dulwich.refs import (
+    LOCAL_BRANCH_PREFIX,
+    LOCAL_TAG_PREFIX,
+)
+
+if TYPE_CHECKING:
+    from dulwich.repo import Repo
+
+
+def do_clone(
+    source_path,
+    target_path,
+    clone_refs: Callable[["Repo", bytes], Tuple[bytes, bytes]] = None,
+    mkdir=True,
+    bare=False,
+    origin=b"origin",
+    checkout=None,
+    errstream=None,
+    branch=None,
+):
+    """Clone a repository.
+
+    Args:
+      source_path: Source repository path
+      target_path: Target repository path
+      clone_refs: Callback to handle setting up cloned remote refs in
+        the target repo
+      mkdir: Create the target directory
+      bare: Whether to create a bare repository
+      checkout: Whether or not to check-out HEAD after cloning
+      origin: Base name for refs in target repository
+        cloned from this repository
+      branch: Optional branch or tag to be used as HEAD in the new repository
+        instead of the source repository's HEAD.
+    Returns: Created repository as `Repo`
+    """
+    from dulwich.repo import Repo
+
+    if not clone_refs:
+        raise ValueError("clone_refs callback is required")
+
+    if mkdir:
+        os.mkdir(target_path)
+
+    try:
+        target = None
+        if not bare:
+            target = Repo.init(target_path)
+            if checkout is None:
+                checkout = True
+        else:
+            if checkout:
+                raise ValueError("checkout and bare are incompatible")
+            target = Repo.init_bare(target_path)
+
+        target_config = target.get_config()
+        target_config.set((b"remote", origin), b"url", source_path)
+        target_config.set(
+            (b"remote", origin),
+            b"fetch",
+            b"+refs/heads/*:refs/remotes/" + origin + b"/*",
+        )
+        target_config.write_to_path()
+
+        ref_message = b"clone: from " + source_path
+        origin_head, origin_sha = clone_refs(target, ref_message)
+        if origin_sha and not origin_head:
+            # set detached HEAD
+            target.refs[b"HEAD"] = origin_sha
+
+        _set_origin_head(target, origin, origin_head)
+        head_ref = _set_default_branch(
+            target, origin, origin_head, branch, ref_message
+        )
+
+        # Update target head
+        if head_ref:
+            head = _set_head(target, head_ref, ref_message)
+        else:
+            head = None
+
+        if checkout and head is not None:
+            if errstream:
+                errstream.write(b"Checking out " + head + b"\n")
+            target.reset_index()
+    except BaseException:
+        if target is not None:
+            target.close()
+        if mkdir:
+            shutil.rmtree(target_path)
+        raise
+
+    return target
+
+
+def _set_origin_head(r, origin, origin_head):
+    # set refs/remotes/origin/HEAD
+    origin_base = b"refs/remotes/" + origin + b"/"
+    if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
+        origin_ref = origin_base + b"HEAD"
+        target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+        if target_ref in r.refs:
+            r.refs.set_symbolic_ref(origin_ref, target_ref)
+
+
+def _set_default_branch(r, origin, origin_head, branch, ref_message):
+    origin_base = b"refs/remotes/" + origin + b"/"
+    if branch:
+        origin_ref = origin_base + branch
+        if origin_ref in r.refs:
+            local_ref = LOCAL_BRANCH_PREFIX + branch
+            r.refs.add_if_new(
+                local_ref, r.refs[origin_ref], ref_message
+            )
+            head_ref = local_ref
+        elif LOCAL_TAG_PREFIX + branch in r.refs:
+            head_ref = LOCAL_TAG_PREFIX + branch
+        else:
+            raise ValueError(
+                "%s is not a valid branch or tag" % os.fsencode(branch)
+            )
+    elif origin_head:
+        head_ref = origin_head
+        if origin_head.startswith(LOCAL_BRANCH_PREFIX):
+            origin_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+        else:
+            origin_ref = origin_head
+        try:
+            r.refs.add_if_new(
+                head_ref, r.refs[origin_ref], ref_message
+            )
+        except KeyError:
+            pass
+    return head_ref
+
+
+def _set_head(r, head_ref, ref_message):
+    if head_ref.startswith(LOCAL_TAG_PREFIX):
+        # detach HEAD at specified tag
+        head = r.refs[head_ref]
+        if isinstance(head, Tag):
+            _cls, obj = head.object
+            head = obj.get_object(obj).id
+        del r.refs[b"HEAD"]
+        r.refs.set_if_equals(
+            b"HEAD", None, head, message=ref_message
+        )
+    else:
+        # set HEAD to specific branch
+        try:
+            head = r.refs[head_ref]
+            r.refs.set_symbolic_ref(b"HEAD", head_ref)
+            r.refs.set_if_equals(
+                b"HEAD", None, head, message=ref_message
+            )
+        except KeyError:
+            head = None
+    return head

+ 30 - 38
dulwich/porcelain.py

@@ -70,7 +70,6 @@ import datetime
 import os
 from pathlib import Path
 import posixpath
-import shutil
 import stat
 import sys
 import time
@@ -87,6 +86,9 @@ from dulwich.archive import (
 from dulwich.client import (
     get_transport_and_path,
 )
+from dulwich.clone import (
+    do_clone,
+)
 from dulwich.config import (
     StackedConfig,
 )
@@ -140,6 +142,7 @@ from dulwich.protocol import (
 from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
     LOCAL_BRANCH_PREFIX,
+    LOCAL_TAG_PREFIX,
     strip_peeled_refs,
     RefsContainer,
 )
@@ -403,6 +406,7 @@ def clone(
     outstream=None,
     origin=b"origin",
     depth=None,
+    branch=None,
     **kwargs
 ):
     """Clone a local or remote git repository.
@@ -416,9 +420,10 @@ def clone(
       outstream: Optional stream to write progress to (deprecated)
       origin: Name of remote from the repository used to clone
       depth: Depth to fetch at
+      branch: Optional branch or tag to be used as HEAD in the new repository
+        instead of the cloned repository's HEAD.
     Returns: The new repository
     """
-    # TODO(jelmer): This code overlaps quite a bit with Repo.clone
     if outstream is not None:
         import warnings
 
@@ -437,51 +442,38 @@ def clone(
     if target is None:
         target = source.split("/")[-1]
 
-    if not os.path.exists(target):
-        os.mkdir(target)
+    mkdir = not os.path.exists(target)
 
-    if bare:
-        r = Repo.init_bare(target)
-    else:
-        r = Repo.init(target)
+    if not isinstance(source, bytes):
+        source = source.encode(DEFAULT_ENCODING)
 
-    reflog_message = b"clone: from " + source.encode("utf-8")
-    try:
-        target_config = r.get_config()
-        if not isinstance(source, bytes):
-            source = source.encode(DEFAULT_ENCODING)
-        target_config.set((b"remote", origin), b"url", source)
-        target_config.set(
-            (b"remote", origin),
-            b"fetch",
-            b"+refs/heads/*:refs/remotes/" + origin + b"/*",
-        )
-        target_config.write_to_path()
+    def clone_refs(target_repo, ref_message):
         fetch_result = fetch(
-            r,
+            target_repo,
             origin,
             errstream=errstream,
-            message=reflog_message,
+            message=ref_message,
             depth=depth,
             **kwargs
         )
-        for key, target in fetch_result.symrefs.items():
-            r.refs.set_symbolic_ref(key, target)
+        head_ref = fetch_result.symrefs.get(b"HEAD", None)
         try:
-            head = r[fetch_result.refs[b"HEAD"]]
+            head_sha = target_repo[fetch_result.refs[b"HEAD"]].id
         except KeyError:
-            head = None
-        else:
-            r[b"HEAD"] = head.id
-        if checkout and not bare and head is not None:
-            errstream.write(b"Checking out " + head.id + b"\n")
-            r.reset_index(head.tree)
-    except BaseException:
-        shutil.rmtree(target)
-        r.close()
-        raise
-
-    return r
+            head_sha = None
+        return head_ref, head_sha
+
+    return do_clone(
+        source,
+        target,
+        clone_refs=clone_refs,
+        mkdir=mkdir,
+        bare=bare,
+        origin=origin,
+        checkout=checkout,
+        errstream=errstream,
+        branch=branch,
+    )
 
 
 def add(repo=".", paths=None):
@@ -1430,7 +1422,7 @@ def _make_branch_ref(name):
 def _make_tag_ref(name):
     if getattr(name, "encode", None):
         name = name.encode(DEFAULT_ENCODING)
-    return b"refs/tags/" + name
+    return LOCAL_TAG_PREFIX + name
 
 
 def branch_delete(repo, name):

+ 40 - 44
dulwich/repo.py

@@ -42,6 +42,9 @@ if TYPE_CHECKING:
     from dulwich.config import StackedConfig, ConfigFile
     from dulwich.index import Index
 
+from dulwich.clone import (
+    do_clone,
+)
 from dulwich.errors import (
     NoIndexPresent,
     NotBlobError,
@@ -87,6 +90,8 @@ from dulwich.line_ending import BlobNormalizer, TreeBlobNormalizer
 
 from dulwich.refs import (  # noqa: F401
     ANNOTATED_TAG_SUFFIX,
+    LOCAL_BRANCH_PREFIX,
+    LOCAL_TAG_PREFIX,
     check_ref_format,
     RefsContainer,
     DictRefsContainer,
@@ -1383,6 +1388,7 @@ class Repo(BaseRepo):
         bare=False,
         origin=b"origin",
         checkout=None,
+        branch=None,
     ):
         """Clone this repository.
 
@@ -1390,57 +1396,43 @@ class Repo(BaseRepo):
           target_path: Target path
           mkdir: Create the target directory
           bare: Whether to create a bare repository
+          checkout: Whether or not to check-out HEAD after cloning
           origin: Base name for refs in target repository
             cloned from this repository
+          branch: Optional branch or tag to be used as HEAD in the new repository
+            instead of this repository's HEAD.
         Returns: Created repository as `Repo`
         """
-        if not bare:
-            target = self.init(target_path, mkdir=mkdir)
-        else:
-            if checkout:
-                raise ValueError("checkout and bare are incompatible")
-            target = self.init_bare(target_path, mkdir=mkdir)
-        self.fetch(target)
-        encoded_path = self.path
-        if not isinstance(encoded_path, bytes):
-            encoded_path = os.fsencode(encoded_path)
-        ref_message = b"clone: from " + encoded_path
-        target.refs.import_refs(
-            b"refs/remotes/" + origin,
-            self.refs.as_dict(b"refs/heads"),
-            message=ref_message,
-        )
-        target.refs.import_refs(
-            b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
-        )
-        try:
-            target.refs.add_if_new(
-                DEFAULT_REF, self.refs[DEFAULT_REF], message=ref_message
+
+        def clone_refs(target_repo, ref_message):
+            self.fetch(target_repo)
+            target_repo.refs.import_refs(
+                b"refs/remotes/" + origin,
+                self.refs.as_dict(b"refs/heads"),
+                message=ref_message,
+            )
+            target_repo.refs.import_refs(
+                b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
             )
-        except KeyError:
-            pass
-        target_config = target.get_config()
-        target_config.set(("remote", "origin"), "url", encoded_path)
-        target_config.set(
-            ("remote", "origin"),
-            "fetch",
-            "+refs/heads/*:refs/remotes/origin/*",
-        )
-        target_config.write_to_path()
 
-        # Update target head
-        head_chain, head_sha = self.refs.follow(b"HEAD")
-        if head_chain and head_sha is not None:
-            target.refs.set_symbolic_ref(b"HEAD", head_chain[-1], message=ref_message)
-            target[b"HEAD"] = head_sha
+            head_chain, sha = self.refs.follow(b"HEAD")
+            head_chain = head_chain[-1] if head_chain else None
+            return head_chain, sha
 
-            if checkout is None:
-                checkout = not bare
-            if checkout:
-                # Checkout HEAD to target dir
-                target.reset_index()
+        encoded_path = self.path
+        if not isinstance(encoded_path, bytes):
+            encoded_path = os.fsencode(encoded_path)
 
-        return target
+        return do_clone(
+            encoded_path,
+            target_path,
+            clone_refs=clone_refs,
+            mkdir=mkdir,
+            bare=bare,
+            origin=origin,
+            checkout=checkout,
+            branch=branch,
+        )
 
     def reset_index(self, tree=None):
         """Reset the index back to a specific tree.
@@ -1455,7 +1447,11 @@ class Repo(BaseRepo):
         )
 
         if tree is None:
-            tree = self[b"HEAD"].tree
+            head = self[b"HEAD"]
+            if isinstance(head, Tag):
+                _cls, obj = head.object
+                head = self.get_object(obj)
+            tree = head.tree
         config = self.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
         if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):

+ 12 - 4
dulwich/tests/test_porcelain.py

@@ -630,9 +630,12 @@ class CloneTests(PorcelainTestCase):
         r.close()
 
     def test_source_broken(self):
-        target_path = tempfile.mkdtemp()
-        self.assertRaises(Exception, porcelain.clone, "/nonexistant/repo", target_path)
-        self.assertFalse(os.path.exists(target_path))
+        with tempfile.TemporaryDirectory() as parent:
+            target_path = os.path.join(parent, "target")
+            self.assertRaises(
+                Exception, porcelain.clone, "/nonexistant/repo", target_path
+            )
+            self.assertFalse(os.path.exists(target_path))
 
     def test_fetch_symref(self):
         f1_1 = make_object(Blob, data=b"f1")
@@ -652,7 +655,10 @@ class CloneTests(PorcelainTestCase):
         self.assertEqual(0, len(target_repo.open_index()))
         self.assertEqual(c1.id, target_repo.refs[b"refs/heads/else"])
         self.assertEqual(c1.id, target_repo.refs[b"HEAD"])
-        self.assertEqual({b"HEAD": b"refs/heads/else"}, target_repo.refs.get_symrefs())
+        self.assertEqual(
+            {b"HEAD": b"refs/heads/else", b"refs/remotes/origin/HEAD": b"refs/remotes/origin/else"},
+            target_repo.refs.get_symrefs(),
+        )
 
 
 class InitTests(TestCase):
@@ -2385,6 +2391,8 @@ class FetchTests(PorcelainTestCase):
             for k, v in remote_refs.items()
             if k.startswith(local_ref_prefix)
         }
+        if b"HEAD" in locally_known_remote_refs and b"HEAD" in remote_refs:
+            normalized_remote_refs[b"HEAD"] = remote_refs[b"HEAD"]
 
         self.assertEqual(locally_known_remote_refs, normalized_remote_refs)
 

+ 43 - 0
dulwich/tests/test_repository.py

@@ -385,6 +385,7 @@ class RepositoryRootTests(TestCase):
                 {
                     b"HEAD": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/remotes/origin/master": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
+                    b"refs/remotes/origin/HEAD": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/heads/master": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/tags/mytag": b"28237f4dc30d0d462658d6b937b08a0f0b6ef55a",
                     b"refs/tags/mytag-packed": b"b0931cadc54336e78a1d980420e3268903b57a50",
@@ -451,6 +452,48 @@ class RepositoryRootTests(TestCase):
             ValueError, r.clone, tmp_dir, mkdir=False, checkout=True, bare=True
         )
 
+    def test_clone_branch(self):
+        r = self.open_repo("a.git")
+        r.refs[b"refs/heads/mybranch"] = b"28237f4dc30d0d462658d6b937b08a0f0b6ef55a"
+        tmp_dir = self.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        with r.clone(tmp_dir, mkdir=False, branch=b"mybranch") as t:
+            # HEAD should point to specified branch and not origin HEAD
+            chain, sha = t.refs.follow(b"HEAD")
+            self.assertEqual(chain[-1], b"refs/heads/mybranch")
+            self.assertEqual(sha, b"28237f4dc30d0d462658d6b937b08a0f0b6ef55a")
+            self.assertEqual(
+                t.refs[b"refs/remotes/origin/HEAD"],
+                b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
+            )
+
+    def test_clone_tag(self):
+        r = self.open_repo("a.git")
+        tmp_dir = self.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        with r.clone(tmp_dir, mkdir=False, branch=b"mytag") as t:
+            # HEAD should be detached (and not a symbolic ref) at tag
+            self.assertEqual(
+                t.refs.read_ref(b"HEAD"),
+                b"28237f4dc30d0d462658d6b937b08a0f0b6ef55a",
+            )
+            self.assertEqual(
+                t.refs[b"refs/remotes/origin/HEAD"],
+                b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
+            )
+
+    def test_clone_invalid_branch(self):
+        r = self.open_repo("a.git")
+        tmp_dir = self.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        self.assertRaises(
+            ValueError,
+            r.clone,
+            tmp_dir,
+            mkdir=False,
+            branch=b"mybranch",
+        )
+
     def test_merge_history(self):
         r = self.open_repo("simple_merge.git")
         shas = [e.commit.id for e in r.get_walker()]