فهرست منبع

Add support for storing remote refs during push/pull. Fixes #752

Jelmer Vernooij 4 سال پیش
والد
کامیت
cba6ef9506
4فایلهای تغییر یافته به همراه106 افزوده شده و 43 حذف شده
  1. 5 0
      NEWS
  2. 81 35
      dulwich/porcelain.py
  3. 7 3
      dulwich/refs.py
  4. 13 5
      dulwich/tests/test_porcelain.py

+ 5 - 0
NEWS

@@ -1,3 +1,8 @@
+0.21.0	UNRELEASED
+
+ * Add support for remembering remote refs after push/pull.
+   (Jelmer Vernooij, #752)
+
 0.20.2	2020-06-01
 
  * Brown bag release to fix uploads of Windows wheels.

+ 81 - 35
dulwich/porcelain.py

@@ -70,6 +70,12 @@ import shutil
 import stat
 import sys
 import time
+from typing import (
+    Dict,
+    Optional,
+    Tuple,
+    Union,
+    )
 
 from dulwich.archive import (
     tar_stream,
@@ -127,6 +133,7 @@ from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
     LOCAL_BRANCH_PREFIX,
     strip_peeled_refs,
+    RefsContainer,
 )
 from dulwich.repo import (BaseRepo, Repo)
 from dulwich.server import (
@@ -354,9 +361,6 @@ def clone(source, target=None, bare=False, checkout=None,
 
     reflog_message = b'clone: from ' + source.encode('utf-8')
     try:
-        fetch_result = fetch(
-            r, source, origin, errstream=errstream, message=reflog_message,
-            depth=depth, **kwargs)
         target_config = r.get_config()
         if not isinstance(source, bytes):
             source = source.encode(DEFAULT_ENCODING)
@@ -365,10 +369,13 @@ def clone(source, target=None, bare=False, checkout=None,
             (b'remote', origin), b'fetch',
             b'+refs/heads/*:refs/remotes/' + origin + b'/*')
         target_config.write_to_path()
+        fetch_result = fetch(
+            r, origin, errstream=errstream, message=reflog_message,
+            depth=depth, **kwargs)
         # TODO(jelmer): Support symref capability,
         # https://github.com/jelmer/dulwich/issues/485
         try:
-            head = r[fetch_result[b'HEAD']]
+            head = r[fetch_result.refs[b'HEAD']]
         except KeyError:
             head = None
         else:
@@ -869,7 +876,34 @@ def reset(repo, mode, treeish="HEAD"):
         r.reset_index(tree.id)
 
 
-def push(repo, remote_location, refspecs,
+def get_remote_repo(
+        repo: Repo,
+        remote_location: Optional[Union[str, bytes]] = None
+        ) -> Tuple[Optional[str], str]:
+    config = repo.get_config()
+    if remote_location is None:
+        remote_location = get_branch_remote(repo)
+    if isinstance(remote_location, str):
+        encoded_location = remote_location.encode()
+    else:
+        encoded_location = remote_location
+
+    section = (b'remote', encoded_location)
+
+    remote_name: Optional[str]
+
+    if config.has_section(section):
+        remote_name = encoded_location.decode()
+        url = config.get(section, 'url')
+        encoded_location = url
+    else:
+        remote_name = None
+        config = None
+
+    return (remote_name, encoded_location.decode())
+
+
+def push(repo, remote_location=None, refspecs=None,
          outstream=default_bytes_out_stream,
          errstream=default_bytes_err_stream, **kwargs):
     """Remote push with dulwich via dulwich.client
@@ -884,12 +918,14 @@ def push(repo, remote_location, refspecs,
 
     # Open the repo
     with open_repo_closing(repo) as r:
+        (remote_name, remote_location) = get_remote_repo(r, remote_location)
 
         # Get the client and path
         client, path = get_transport_and_path(
                 remote_location, config=r.get_config_stack(), **kwargs)
 
         selected_refs = []
+        remote_changed_refs = {}
 
         def update_refs(refs):
             selected_refs.extend(parse_reftuples(r.refs, refs, refspecs))
@@ -898,8 +934,10 @@ def push(repo, remote_location, refspecs,
             for (lh, rh, force) in selected_refs:
                 if lh is None:
                     new_refs[rh] = ZERO_SHA
+                    remote_changed_refs[rh] = None
                 else:
                     new_refs[rh] = r.refs[lh]
+                    remote_changed_refs[rh] = r.refs[lh]
             return new_refs
 
         err_encoding = getattr(errstream, 'encoding', None) or DEFAULT_ENCODING
@@ -919,6 +957,9 @@ def push(repo, remote_location, refspecs,
             errstream.write(b"Push to " + remote_location_bytes +
                             b" failed -> " + e.args[0] + b"\n")
 
+        if remote_name is not None:
+            _import_remote_refs(r.refs, remote_name, remote_changed_refs)
+
 
 def pull(repo, remote_location=None, refspecs=None,
          outstream=default_bytes_out_stream,
@@ -934,14 +975,7 @@ def pull(repo, remote_location=None, refspecs=None,
     """
     # Open the repo
     with open_repo_closing(repo) as r:
-        if remote_location is None:
-            config = r.get_config()
-            remote_name = get_branch_remote(r.path)
-            section = (b'remote', remote_name)
-
-            if config.has_section(section):
-                url = config.get(section, 'url')
-                remote_location = url.decode()
+        (remote_name, remote_location) = get_remote_repo(r, remote_location)
 
         if refspecs is None:
             refspecs = [b"HEAD"]
@@ -963,6 +997,8 @@ def pull(repo, remote_location=None, refspecs=None,
         # Perform 'git checkout .' - syncs staged changes
         tree = r[b"HEAD"].tree
         r.reset_index(tree=tree)
+        if remote_name is not None:
+            _import_remote_refs(r.refs, remote_name, fetch_result.refs)
 
 
 def status(repo=".", ignored=False):
@@ -1251,21 +1287,40 @@ def get_branch_remote(repo):
         branch_name = active_branch(r.path)
         config = r.get_config()
         try:
-            remote_name = config.get((b'branch', branch_name), 'remote')
+            remote_name = config.get((b'branch', branch_name), b'remote')
         except KeyError:
             remote_name = b'origin'
     return remote_name
 
 
-def fetch(repo, remote_location, remote_name=b'origin', outstream=sys.stdout,
-          errstream=default_bytes_err_stream, message=None, depth=None,
-          prune=False, prune_tags=False, **kwargs):
+def _import_remote_refs(
+        refs_container: RefsContainer, remote_name: str,
+        refs: Dict[str, str], message: Optional[bytes] = None,
+        prune: bool = False, prune_tags: bool = False):
+    stripped_refs = strip_peeled_refs(refs)
+    branches = {
+        n[len(LOCAL_BRANCH_PREFIX):]: v for (n, v) in stripped_refs.items()
+        if n.startswith(LOCAL_BRANCH_PREFIX)}
+    refs_container.import_refs(
+        b'refs/remotes/' + remote_name.encode(), branches, message=message,
+        prune=prune)
+    tags = {
+        n[len(b'refs/tags/'):]: v for (n, v) in stripped_refs.items()
+        if n.startswith(b'refs/tags/') and
+        not n.endswith(ANNOTATED_TAG_SUFFIX)}
+    refs_container.import_refs(
+        b'refs/tags', tags, message=message,
+        prune=prune_tags)
+
+
+def fetch(repo, remote_location=None,
+          outstream=sys.stdout, errstream=default_bytes_err_stream,
+          message=None, depth=None, prune=False, prune_tags=False, **kwargs):
     """Fetch objects from a remote server.
 
     Args:
       repo: Path to the repository
       remote_location: String identifying a remote server
-      remote_name: Name for remote server
       outstream: Output stream (defaults to stdout)
       errstream: Error stream (defaults to stderr)
       message: Reflog message (defaults to b"fetch: from <remote_name>")
@@ -1275,28 +1330,19 @@ def fetch(repo, remote_location, remote_name=b'origin', outstream=sys.stdout,
     Returns:
       Dictionary with refs on the remote
     """
-    if message is None:
-        message = b'fetch: from ' + remote_location.encode("utf-8")
     with open_repo_closing(repo) as r:
+        (remote_name, remote_location) = get_remote_repo(r, remote_location)
+        if message is None:
+            message = b'fetch: from ' + remote_location.encode("utf-8")
         client, path = get_transport_and_path(
             remote_location, config=r.get_config_stack(), **kwargs)
         fetch_result = client.fetch(path, r, progress=errstream.write,
                                     depth=depth)
-        stripped_refs = strip_peeled_refs(fetch_result.refs)
-        branches = {
-            n[len(LOCAL_BRANCH_PREFIX):]: v for (n, v) in stripped_refs.items()
-            if n.startswith(LOCAL_BRANCH_PREFIX)}
-        r.refs.import_refs(
-            b'refs/remotes/' + remote_name, branches, message=message,
-            prune=prune)
-        tags = {
-            n[len(b'refs/tags/'):]: v for (n, v) in stripped_refs.items()
-            if n.startswith(b'refs/tags/') and
-            not n.endswith(ANNOTATED_TAG_SUFFIX)}
-        r.refs.import_refs(
-            b'refs/tags', tags, message=message,
-            prune=prune_tags)
-    return fetch_result.refs
+        if remote_name is not None:
+            _import_remote_refs(
+                r.refs, remote_name, fetch_result.refs, message, prune=prune,
+                prune_tags=prune_tags)
+    return fetch_result
 
 
 def ls_remote(remote, config=None, **kwargs):

+ 7 - 3
dulwich/refs.py

@@ -146,15 +146,19 @@ class RefsContainer(object):
         else:
             to_delete = set()
         for name, value in other.items():
-            self.set_if_equals(b'/'.join((base, name)), None, value,
-                               message=message)
+            if value is None:
+                to_delete.add(name)
+            else:
+                self.set_if_equals(b'/'.join((base, name)), None, value,
+                                   message=message)
             if to_delete:
                 try:
                     to_delete.remove(name)
                 except KeyError:
                     pass
         for ref in to_delete:
-            self.remove_if_equals(b'/'.join((base, ref)), None)
+            self.remove_if_equals(
+                b'/'.join((base, ref)), None, message=message)
 
     def allkeys(self):
         """All refs present in this container."""

+ 13 - 5
dulwich/tests/test_porcelain.py

@@ -245,7 +245,7 @@ class CloneTests(PorcelainTestCase):
         self.assertEqual(r.path, target_path)
         target_repo = Repo(target_path)
         self.assertEqual(0, len(target_repo.open_index()))
-        self.assertEqual(c3.id, target_repo.refs[b'refs/tags/foo'])
+        self.assertEqual(c3.id, target_repo.refs[b'refs/tags/foo'], target_repo.refs.as_dict())
         self.assertTrue(b'f1' not in os.listdir(target_path))
         self.assertTrue(b'f2' not in os.listdir(target_path))
         c = r.get_config()
@@ -900,9 +900,13 @@ class PushTests(PorcelainTestCase):
         self.repo.refs[refs_path] = new_id
 
         # Push to the remote
-        porcelain.push(clone_path, self.repo.path, b"HEAD:" + refs_path,
+        porcelain.push(clone_path, 'origin', b"HEAD:" + refs_path,
                        outstream=outstream, errstream=errstream)
 
+        self.assertEqual(
+            target_repo.refs[b'refs/remotes/origin/foo'],
+            target_repo.refs[b'HEAD'])
+
         # Check that the target and source
         with Repo(clone_path) as r_clone:
             self.assertEqual({
@@ -1378,7 +1382,7 @@ class FetchTests(PorcelainTestCase):
         target_repo.close()
 
         # Fetch changes into the cloned repo
-        porcelain.fetch(target_path, self.repo.path,
+        porcelain.fetch(target_path, 'origin',
                         outstream=outstream, errstream=errstream)
 
         # Assert that fetch updated the local image of the remote
@@ -1390,7 +1394,7 @@ class FetchTests(PorcelainTestCase):
             self.assertTrue(self.repo[b'HEAD'].id in r)
 
     def test_with_remote_name(self):
-        remote_name = b'origin'
+        remote_name = 'origin'
         outstream = BytesIO()
         errstream = BytesIO()
 
@@ -1420,10 +1424,14 @@ class FetchTests(PorcelainTestCase):
                          committer=b'test2 <email>')
 
         self.assertFalse(self.repo[b'HEAD'].id in target_repo)
+
+        target_config = target_repo.get_config()
+        target_config.set(
+            (b'remote', remote_name.encode()), b'url', self.repo.path.encode())
         target_repo.close()
 
         # Fetch changes into the cloned repo
-        porcelain.fetch(target_path, self.repo.path, remote_name=remote_name,
+        porcelain.fetch(target_path, remote_name,
                         outstream=outstream, errstream=errstream)
 
         # Assert that fetch updated the local image of the remote