Parcourir la source

Prepare for a world in which we check for diverged branches.

Jelmer Vernooij il y a 4 ans
Parent
commit
b93c20797b
1 fichiers modifiés avec 38 ajouts et 8 suppressions
  1. 38 8
      dulwich/porcelain.py

+ 38 - 8
dulwich/porcelain.py

@@ -218,6 +218,23 @@ def path_to_tree_path(repopath, path, tree_encoding=DEFAULT_ENCODING):
         return bytes(relpath)
 
 
+class DivergedBranches(Exception):
+    """Branches have diverged and fast-forward is not possible."""
+
+
+def check_diverged(store, current_sha, new_sha):
+    """Check if updating to a sha can be done with fast forwarding.
+
+    Args:
+      store: Object store
+      current_sha: Current head sha
+      new_sha: New head sha
+    """
+    return
+    # TODO(jelmer): check for diverged branches. See bug #666, #494
+    raise DivergedBranches(current_sha, new_sha)
+
+
 def archive(repo, committish=None, outstream=default_bytes_out_stream,
             errstream=default_bytes_err_stream):
     """Create an archive.
@@ -905,7 +922,8 @@ def get_remote_repo(
 
 def push(repo, remote_location=None, refspecs=None,
          outstream=default_bytes_out_stream,
-         errstream=default_bytes_err_stream, **kwargs):
+         errstream=default_bytes_err_stream,
+         force=False, **kwargs):
     """Remote push with dulwich via dulwich.client
 
     Args:
@@ -914,6 +932,7 @@ def push(repo, remote_location=None, refspecs=None,
       refspecs: Refs to push to remote
       outstream: A stream file to write output
       errstream: A stream file to write errors
+      force: Force overwriting refs
     """
 
     # Open the repo
@@ -928,14 +947,17 @@ def push(repo, remote_location=None, refspecs=None,
         remote_changed_refs = {}
 
         def update_refs(refs):
-            selected_refs.extend(parse_reftuples(r.refs, refs, refspecs))
+            selected_refs.extend(parse_reftuples(
+                r.refs, refs, refspecs, force=force))
             new_refs = {}
             # TODO: Handle selected_refs == {None: None}
-            for (lh, rh, force) in selected_refs:
+            for (lh, rh, force_ref) in selected_refs:
                 if lh is None:
                     new_refs[rh] = ZERO_SHA
                     remote_changed_refs[rh] = None
                 else:
+                    if not force_ref:
+                        check_diverged(r.object_store, refs[rh], r.refs[lh])
                     new_refs[rh] = r.refs[lh]
                     remote_changed_refs[rh] = r.refs[lh]
             return new_refs
@@ -963,7 +985,8 @@ def push(repo, remote_location=None, refspecs=None,
 
 def pull(repo, remote_location=None, refspecs=None,
          outstream=default_bytes_out_stream,
-         errstream=default_bytes_err_stream, **kwargs):
+         errstream=default_bytes_err_stream, fast_forward=True,
+         force=False, **kwargs):
     """Pull from remote via dulwich.client
 
     Args:
@@ -973,6 +996,9 @@ def pull(repo, remote_location=None, refspecs=None,
       outstream: A stream file to write to output
       errstream: A stream file to write to errors
     """
+    if not fast_forward:
+        raise NotImplementedError('no-ff pull is not yet supported')
+
     # Open the repo
     with open_repo_closing(repo) as r:
         (remote_name, remote_location) = get_remote_repo(r, remote_location)
@@ -983,15 +1009,18 @@ def pull(repo, remote_location=None, refspecs=None,
 
         def determine_wants(remote_refs):
             selected_refs.extend(
-                parse_reftuples(remote_refs, r.refs, refspecs))
+                parse_reftuples(remote_refs, r.refs, refspecs, force=force))
             return [
-                remote_refs[lh] for (lh, rh, force) in selected_refs
+                remote_refs[lh] for (lh, rh, force_ref) in selected_refs
                 if remote_refs[lh] not in r.object_store]
         client, path = get_transport_and_path(
                 remote_location, config=r.get_config_stack(), **kwargs)
         fetch_result = client.fetch(
             path, r, progress=errstream.write, determine_wants=determine_wants)
-        for (lh, rh, force) in selected_refs:
+        for (lh, rh, force_ref) in selected_refs:
+            if fast_forward:
+                check_diverged(
+                    r.object_store, r.refs[rh], fetch_result.refs[lh])
             r.refs[rh] = fetch_result.refs[lh]
         if selected_refs:
             r[b'HEAD'] = fetch_result.refs[selected_refs[0][1]]
@@ -1317,7 +1346,8 @@ def _import_remote_refs(
 
 def fetch(repo, remote_location=None,
           outstream=sys.stdout, errstream=default_bytes_err_stream,
-          message=None, depth=None, prune=False, prune_tags=False, **kwargs):
+          message=None, depth=None, prune=False, prune_tags=False, force=False,
+          **kwargs):
     """Fetch objects from a remote server.
 
     Args: