Browse Source

unify behavior for setting correct HEAD and default branch

Peter Rowlands 4 years ago
parent
commit
c1c9a9dc6f
2 changed files with 49 additions and 32 deletions
  1. 0 2
      dulwich/porcelain.py
  2. 49 30
      dulwich/repo.py

+ 0 - 2
dulwich/porcelain.py

@@ -455,8 +455,6 @@ def clone(
             depth=depth,
             **kwargs
         )
-        for key, target_ref in fetch_result.symrefs.items():
-            target_repo.refs.set_symbolic_ref(key, target_ref)
         head_ref = fetch_result.symrefs.get(b"HEAD", None)
         try:
             head_sha = target_repo[fetch_result.refs[b"HEAD"]].id

+ 49 - 30
dulwich/repo.py

@@ -1395,12 +1395,6 @@ class Repo(BaseRepo):
             target_repo.refs.import_refs(
                 b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
             )
-            try:
-                target_repo.refs.add_if_new(
-                    DEFAULT_REF, self.refs[DEFAULT_REF], message=ref_message
-                )
-            except KeyError:
-                pass
 
             head_chain, sha = self.refs.follow(b"HEAD")
             head_chain = head_chain[-1] if head_chain else None
@@ -1468,19 +1462,45 @@ class Repo(BaseRepo):
 
             ref_message = b"clone: from " + source_path
             origin_head, origin_sha = clone_refs(target, ref_message)
-            if origin_sha:
+            if origin_sha and not origin_head:
+                # set detached HEAD
                 target.refs[b"HEAD"] = origin_sha
 
+            origin_base = b"refs/remotes/" + origin + b"/"
+
             # set refs/remotes/origin/HEAD
             if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
-                cls._clone_set_origin_head(target, origin, origin_head)
+                origin_ref = origin_base + b"HEAD"
+                target_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+                if target_ref in target.refs:
+                    target.refs.set_symbolic_ref(origin_ref, target_ref)
 
-            head_ref = b"HEAD" if origin_head else None
             if branch:
-                for ref in (LOCAL_BRANCH_PREFIX + branch, LOCAL_TAG_PREFIX + branch):
-                    if ref in target.refs:
-                        head_ref = ref
-                        break
+                origin_ref = origin_base + branch
+                if origin_ref in target.refs:
+                    local_ref = LOCAL_BRANCH_PREFIX + branch
+                    target.refs.add_if_new(
+                        local_ref, target.refs[origin_ref], ref_message
+                    )
+                    head_ref = local_ref
+                elif LOCAL_TAG_PREFIX + branch in target.refs:
+                    head_ref = LOCAL_TAG_PREFIX + branch
+                else:
+                    raise ValueError(
+                        "%s is not a valid branch or tag" % os.fsencode(branch)
+                    )
+            elif origin_head:
+                head_ref = origin_head
+                if origin_head.startswith(LOCAL_BRANCH_PREFIX):
+                    origin_ref = origin_base + origin_head[len(LOCAL_BRANCH_PREFIX) :]
+                else:
+                    origin_ref = origin_head
+                try:
+                    target.refs.add_if_new(
+                        head_ref, target.refs[origin_ref], ref_message
+                    )
+                except KeyError:
+                    pass
 
             # Update target head
             if head_ref:
@@ -1506,33 +1526,28 @@ class Repo(BaseRepo):
         if target_ref in r.refs:
             r.refs.set_symbolic_ref(origin_ref, target_ref)
 
-            # set HEAD to default remote branch if it differs from DEFAULT_REF
-            if origin_head != DEFAULT_REF:
-                origin_ref = origin_base + DEFAULT_REF[len(LOCAL_BRANCH_PREFIX) :]
-                if origin_ref not in r.refs:
-                    del r.refs[DEFAULT_REF]
-                r.refs.set_symbolic_ref(b"HEAD", origin_head)
-
     @staticmethod
     def _clone_set_head(r, head_ref, ref_message):
         if head_ref.startswith(LOCAL_TAG_PREFIX):
             # detach HEAD at specified tag
-            _cls, head_sha = r.refs[head_ref].object
-            head = r[head_sha]
+            head = r.refs[head_ref]
+            if isinstance(head, Tag):
+                _cls, obj = head.object
+                head = obj.get_object(obj).id
             del r.refs[b"HEAD"]
             r.refs.set_if_equals(
-                b"HEAD", None, head_sha, message=ref_message
+                b"HEAD", None, head, message=ref_message
             )
         else:
-            if head_ref == b"HEAD":
-                _chain, head = r.refs.follow(head_ref)
-            else:
-                # set HEAD to specific remote branch
-                r.refs.set_symbolic_ref(b"HEAD", head_ref)
+            # set HEAD to specific branch
+            try:
                 head = r.refs[head_ref]
+                r.refs.set_symbolic_ref(b"HEAD", head_ref)
                 r.refs.set_if_equals(
-                    b"HEAD", None, head.id, message=ref_message
+                    b"HEAD", None, head, message=ref_message
                 )
+            except KeyError:
+                head = None
         return head
 
     def reset_index(self, tree=None):
@@ -1548,7 +1563,11 @@ class Repo(BaseRepo):
         )
 
         if tree is None:
-            tree = self[b"HEAD"].tree
+            head = self[b"HEAD"]
+            if isinstance(head, Tag):
+                _cls, obj = head.object
+                head = self.get_object(obj)
+            tree = head.tree
         config = self.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
         if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):