Jelajahi Sumber

start unifying porcelain.clone() & Repo.clone()

Peter Rowlands 4 tahun lalu
induk
melakukan
14ec9a7972
3 mengubah file dengan 163 tambahan dan 112 penghapusan
  1. 22 76
      dulwich/porcelain.py
  2. 140 36
      dulwich/repo.py
  3. 1 0
      dulwich/tests/test_repository.py

+ 22 - 76
dulwich/porcelain.py

@@ -144,7 +144,7 @@ from dulwich.refs import (
     strip_peeled_refs,
     RefsContainer,
 )
-from dulwich.repo import DEFAULT_REF, BaseRepo, Repo
+from dulwich.repo import BaseRepo, Repo
 from dulwich.server import (
     FileSystemBackend,
     TCPGitServer,
@@ -418,6 +418,8 @@ def clone(
       outstream: Optional stream to write progress to (deprecated)
       origin: Name of remote from the repository used to clone
       depth: Depth to fetch at
+      branch: Optional branch or tag to be used as HEAD in the new repository
+        instead of the cloned repository's HEAD.
     Returns: The new repository
     """
     # TODO(jelmer): This code overlaps quite a bit with Repo.clone
@@ -442,94 +444,38 @@ def clone(
     if not os.path.exists(target):
         os.mkdir(target)
 
-    if bare:
-        r = Repo.init_bare(target)
-    else:
-        r = Repo.init(target)
+    if not isinstance(source, bytes):
+        source = source.encode(DEFAULT_ENCODING)
 
-    reflog_message = b"clone: from " + source.encode("utf-8")
-    try:
-        target_config = r.get_config()
-        if not isinstance(source, bytes):
-            source = source.encode(DEFAULT_ENCODING)
-        target_config.set((b"remote", origin), b"url", source)
-        target_config.set(
-            (b"remote", origin),
-            b"fetch",
-            b"+refs/heads/*:refs/remotes/" + origin + b"/*",
-        )
-        target_config.write_to_path()
+    def clone_refs(target_repo, ref_message):
         fetch_result = fetch(
-            r,
+            target_repo,
             origin,
             errstream=errstream,
-            message=reflog_message,
+            message=ref_message,
             depth=depth,
             **kwargs
         )
         for key, target_ref in fetch_result.symrefs.items():
-            r.refs.set_symbolic_ref(key, target_ref)
-
-        head_ref = b"HEAD" if b"HEAD" in fetch_result.refs else None
-        if branch:
-            for ref in (_make_branch_ref(branch), _make_tag_ref(branch)):
-                if ref in fetch_result.refs:
-                    head_ref = ref
-                    break
+            target_repo.refs.set_symbolic_ref(key, target_ref)
+        return fetch_result.symrefs.get(b"HEAD", None)
 
-        if head_ref:
-            head = _clone_update_head(r, source, origin, head_ref, fetch_result)
-        else:
-            head = None
-
-        if checkout and not bare and head is not None:
-            errstream.write(b"Checking out " + head.id + b"\n")
-            r.reset_index(head.tree)
+    try:
+        return Repo.do_clone(
+            source,
+            target,
+            clone_refs=clone_refs,
+            mkdir=False,
+            bare=bare,
+            origin=origin,
+            checkout=checkout,
+            errstream=errstream,
+            branch=branch,
+        )
     except BaseException:
         shutil.rmtree(target)
-        r.close()
         raise
 
-    return r
-
-
-def _clone_update_head(r, source, origin, new_ref, fetch_result):
-    ref_message = b"clone: from " + source
-
-    # set refs/remotes/origin/HEAD
-    origin_head = fetch_result.symrefs.get(b"HEAD", b"")
-    if origin_head.startswith(LOCAL_BRANCH_PREFIX):
-        origin_base = b"refs/remotes/" + origin + b"/"
-        origin_ref = origin_base + b"HEAD"
-        target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
-        r.refs.set_symbolic_ref(origin_ref, target_ref)
-
-    # set local HEAD
-    if new_ref.startswith(LOCAL_TAG_PREFIX):
-        # ref is a tag, detach HEAD at remote tag
-        _cls, head_sha = r[fetch_result.refs[new_ref]].object
-        head = r[head_sha]
-        del r.refs[b"HEAD"]
-        r.refs.set_if_equals(
-            b"HEAD", None, head_sha, message=ref_message
-        )
-    else:
-        head = r[fetch_result.refs[new_ref]]
-        if new_ref == b"HEAD":
-            # set HEAD to default remote branch if it differs from DEFAULT_REF
-            default_ref = fetch_result.symrefs.get(b"HEAD")
-            if default_ref and default_ref != DEFAULT_REF:
-                del r.refs[DEFAULT_REF]
-                r.refs.set_symbolic_ref(b"HEAD", default_ref)
-        else:
-            # set HEAD to specific remote branch
-            del r.refs[DEFAULT_REF]
-            r.refs.set_symbolic_ref(b"HEAD", new_ref)
-        r.refs.set_if_equals(
-            b"HEAD", None, head.id, message=ref_message
-        )
-    return head
-
 
 def add(repo=".", paths=None):
     """Add files to the staging area.

+ 140 - 36
dulwich/repo.py

@@ -87,6 +87,8 @@ from dulwich.line_ending import BlobNormalizer, TreeBlobNormalizer
 
 from dulwich.refs import (  # noqa: F401
     ANNOTATED_TAG_SUFFIX,
+    LOCAL_BRANCH_PREFIX,
+    LOCAL_TAG_PREFIX,
     check_ref_format,
     RefsContainer,
     DictRefsContainer,
@@ -1370,6 +1372,7 @@ class Repo(BaseRepo):
         bare=False,
         origin=b"origin",
         checkout=None,
+        branch=None,
     ):
         """Clone this repository.
 
@@ -1381,54 +1384,155 @@ class Repo(BaseRepo):
             cloned from this repository
         Returns: Created repository as `Repo`
         """
-        if not bare:
-            target = self.init(target_path, mkdir=mkdir)
-        else:
-            if checkout:
-                raise ValueError("checkout and bare are incompatible")
-            target = self.init_bare(target_path, mkdir=mkdir)
-        self.fetch(target)
+
+        def clone_refs(target_repo, ref_message):
+            self.fetch(target_repo)
+            target_repo.refs.import_refs(
+                b"refs/remotes/" + origin,
+                self.refs.as_dict(b"refs/heads"),
+                message=ref_message,
+            )
+            target_repo.refs.import_refs(
+                b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
+            )
+            try:
+                target_repo.refs.add_if_new(
+                    DEFAULT_REF, self.refs[DEFAULT_REF], message=ref_message
+                )
+            except KeyError:
+                pass
+
+            head_chain, _sha = self.refs.follow(b"HEAD")
+            return head_chain[-1] if head_chain else None
+
         encoded_path = self.path
         if not isinstance(encoded_path, bytes):
             encoded_path = os.fsencode(encoded_path)
-        ref_message = b"clone: from " + encoded_path
-        target.refs.import_refs(
-            b"refs/remotes/" + origin,
-            self.refs.as_dict(b"refs/heads"),
-            message=ref_message,
-        )
-        target.refs.import_refs(
-            b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
-        )
-        try:
-            target.refs.add_if_new(
-                DEFAULT_REF, self.refs[DEFAULT_REF], message=ref_message
-            )
-        except KeyError:
-            pass
-        target_config = target.get_config()
-        target_config.set(("remote", "origin"), "url", encoded_path)
-        target_config.set(
-            ("remote", "origin"),
-            "fetch",
-            "+refs/heads/*:refs/remotes/origin/*",
+
+        return self.do_clone(
+            encoded_path,
+            target_path,
+            clone_refs=clone_refs,
+            mkdir=mkdir,
+            bare=bare,
+            origin=origin,
+            checkout=checkout,
+            branch=branch,
         )
-        target_config.write_to_path()
 
-        # Update target head
-        head_chain, head_sha = self.refs.follow(b"HEAD")
-        if head_chain and head_sha is not None:
-            target.refs.set_symbolic_ref(b"HEAD", head_chain[-1], message=ref_message)
-            target[b"HEAD"] = head_sha
+    @classmethod
+    def do_clone(
+        cls,
+        source_path,
+        target_path,
+        clone_refs=None,
+        mkdir=True,
+        bare=False,
+        origin=b"origin",
+        checkout=None,
+        errstream=None,
+        branch=None,
+    ):
+        """Clone this repository.
+
+        Args:
+          target_path: Target path
+          mkdir: Create the target directory
+          bare: Whether to create a bare repository
+          origin: Base name for refs in target repository
+            cloned from this repository
+        Returns: Created repository as `Repo`
+        """
+        if not clone_refs:
+            raise ValueError("clone_refs callback is required")
 
+        if not bare:
+            target = cls.init(target_path, mkdir=mkdir)
             if checkout is None:
-                checkout = not bare
+                checkout = True
+        else:
             if checkout:
-                # Checkout HEAD to target dir
+                raise ValueError("checkout and bare are incompatible")
+            target = cls.init_bare(target_path, mkdir=mkdir)
+
+        try:
+            target_config = target.get_config()
+            target_config.set((b"remote", origin), b"url", source_path)
+            target_config.set(
+                (b"remote", origin),
+                b"fetch",
+                b"+refs/heads/*:refs/remotes/" + origin + b"/*",
+            )
+            target_config.write_to_path()
+
+            ref_message = b"clone: from " + source_path
+            origin_head = clone_refs(target, ref_message)
+
+            # set refs/remotes/origin/HEAD
+            if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
+                cls._clone_set_origin_head(target, origin, origin_head)
+
+            head_ref = b"HEAD" if origin_head else None
+            if branch:
+                for ref in (LOCAL_BRANCH_PREFIX + branch, LOCAL_TAG_PREFIX + branch):
+                    if ref in target.refs:
+                        head_ref = ref
+                        break
+
+            # Update target head
+            if head_ref:
+                head = cls._clone_set_head(target, head_ref, ref_message)
+            else:
+                head = None
+
+            if checkout and head is not None:
+                if errstream:
+                    errstream.write(b"Checking out " + head.id + b"\n")
                 target.reset_index()
+        except BaseException:
+            target.close()
+            raise
 
         return target
 
+    @staticmethod
+    def _clone_set_origin_head(r, origin, origin_head):
+        origin_base = b"refs/remotes/" + origin + b"/"
+        origin_ref = origin_base + b"HEAD"
+        target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+        if target_ref in r.refs:
+            r.refs.set_symbolic_ref(origin_ref, target_ref)
+
+            # set HEAD to default remote branch if it differs from DEFAULT_REF
+            if origin_head != DEFAULT_REF:
+                origin_ref = origin_base + DEFAULT_REF[len(LOCAL_BRANCH_PREFIX) :]
+                if origin_ref not in r.refs:
+                    del r.refs[DEFAULT_REF]
+                r.refs.set_symbolic_ref(b"HEAD", origin_head)
+
+    @staticmethod
+    def _clone_set_head(r, head_ref, ref_message):
+        if head_ref.startswith(LOCAL_TAG_PREFIX):
+            print(r.refs)
+            # detach HEAD at specified tag
+            _cls, head_sha = r.refs[head_ref].object
+            head = r[head_sha]
+            del r.refs[b"HEAD"]
+            r.refs.set_if_equals(
+                b"HEAD", None, head_sha, message=ref_message
+            )
+        else:
+            if head_ref == b"HEAD":
+                _chain, head = r.refs.follow(head_ref)
+            else:
+                # set HEAD to specific remote branch
+                r.refs.set_symbolic_ref(b"HEAD", head_ref)
+                head = r.refs[head_ref]
+                r.refs.set_if_equals(
+                    b"HEAD", None, head.id, message=ref_message
+                )
+        return head
+
     def reset_index(self, tree=None):
         """Reset the index back to a specific tree.
 

+ 1 - 0
dulwich/tests/test_repository.py

@@ -385,6 +385,7 @@ class RepositoryRootTests(TestCase):
                 {
                     b"HEAD": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/remotes/origin/master": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
+                    b"refs/remotes/origin/HEAD": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/heads/master": b"a90fa2d900a17e99b433217e988c4eb4a2e9a097",
                     b"refs/tags/mytag": b"28237f4dc30d0d462658d6b937b08a0f0b6ef55a",
                     b"refs/tags/mytag-packed": b"b0931cadc54336e78a1d980420e3268903b57a50",