Browse Source

split do_clone into helper functions

Peter Rowlands 4 years ago
parent
commit
e40fbdbe02
1 changed files with 41 additions and 39 deletions
  1. 41 39
      dulwich/repo.py

+ 41 - 39
dulwich/repo.py

@@ -1459,41 +1459,10 @@ class Repo(BaseRepo):
                 # set detached HEAD
                 target.refs[b"HEAD"] = origin_sha
 
-            origin_base = b"refs/remotes/" + origin + b"/"
-
-            # set refs/remotes/origin/HEAD
-            if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
-                origin_ref = origin_base + b"HEAD"
-                target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
-                if target_ref in target.refs:
-                    target.refs.set_symbolic_ref(origin_ref, target_ref)
-
-            if branch:
-                origin_ref = origin_base + branch
-                if origin_ref in target.refs:
-                    local_ref = LOCAL_BRANCH_PREFIX + branch
-                    target.refs.add_if_new(
-                        local_ref, target.refs[origin_ref], ref_message
-                    )
-                    head_ref = local_ref
-                elif LOCAL_TAG_PREFIX + branch in target.refs:
-                    head_ref = LOCAL_TAG_PREFIX + branch
-                else:
-                    raise ValueError(
-                        "%s is not a valid branch or tag" % os.fsencode(branch)
-                    )
-            elif origin_head:
-                head_ref = origin_head
-                if origin_head.startswith(LOCAL_BRANCH_PREFIX):
-                    origin_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
-                else:
-                    origin_ref = origin_head
-                try:
-                    target.refs.add_if_new(
-                        head_ref, target.refs[origin_ref], ref_message
-                    )
-                except KeyError:
-                    pass
+            cls._clone_set_origin_head(target, origin, origin_head)
+            head_ref = cls._clone_set_default_branch(
+                target, origin, origin_head, branch, ref_message
+            )
 
             # Update target head
             if head_ref:
@@ -1513,11 +1482,44 @@ class Repo(BaseRepo):
 
     @staticmethod
     def _clone_set_origin_head(r, origin, origin_head):
+        # set refs/remotes/origin/HEAD
         origin_base = b"refs/remotes/" + origin + b"/"
-        origin_ref = origin_base + b"HEAD"
-        target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
-        if target_ref in r.refs:
-            r.refs.set_symbolic_ref(origin_ref, target_ref)
+        if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
+            origin_ref = origin_base + b"HEAD"
+            target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+            if target_ref in r.refs:
+                r.refs.set_symbolic_ref(origin_ref, target_ref)
+
+    @staticmethod
+    def _clone_set_default_branch(r, origin, origin_head, branch, ref_message):
+        origin_base = b"refs/remotes/" + origin + b"/"
+        if branch:
+            origin_ref = origin_base + branch
+            if origin_ref in r.refs:
+                local_ref = LOCAL_BRANCH_PREFIX + branch
+                r.refs.add_if_new(
+                    local_ref, r.refs[origin_ref], ref_message
+                )
+                head_ref = local_ref
+            elif LOCAL_TAG_PREFIX + branch in r.refs:
+                head_ref = LOCAL_TAG_PREFIX + branch
+            else:
+                raise ValueError(
+                    "%s is not a valid branch or tag" % os.fsencode(branch)
+                )
+        elif origin_head:
+            head_ref = origin_head
+            if origin_head.startswith(LOCAL_BRANCH_PREFIX):
+                origin_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+            else:
+                origin_ref = origin_head
+            try:
+                r.refs.add_if_new(
+                    head_ref, r.refs[origin_ref], ref_message
+                )
+            except KeyError:
+                pass
+        return head_ref
 
     @staticmethod
     def _clone_set_head(r, head_ref, ref_message):