|
@@ -137,6 +137,31 @@ class ReportStatusParser(object):
|
|
|
self._ref_status_ok = False
|
|
|
|
|
|
|
|
|
+def read_pkt_refs(proto):
|
|
|
+ server_capabilities = None
|
|
|
+ refs = {}
|
|
|
+ # Receive refs from server
|
|
|
+ for pkt in proto.read_pkt_seq():
|
|
|
+ (sha, ref) = pkt.rstrip('\n').split(None, 1)
|
|
|
+ if sha == 'ERR':
|
|
|
+ raise GitProtocolError(ref)
|
|
|
+ if server_capabilities is None:
|
|
|
+ (ref, server_capabilities) = extract_capabilities(ref)
|
|
|
+ refs[ref] = sha
|
|
|
+
|
|
|
+ if len(refs) == 0:
|
|
|
+ return None, set([])
|
|
|
+ return refs, set(server_capabilities)
|
|
|
+
|
|
|
+
|
|
|
+def read_info_refs(f):
|
|
|
+ ret = {}
|
|
|
+ for l in f.readlines():
|
|
|
+ (sha, name) = l.rstrip("\r\n").split("\t", 1)
|
|
|
+ ret[name] = sha
|
|
|
+ return ret
|
|
|
+
|
|
|
+
|
|
|
# TODO(durin42): this doesn't correctly degrade if the server doesn't
|
|
|
# support some capabilities. This should work properly with servers
|
|
|
# that don't support multi_ack.
|
|
@@ -153,27 +178,12 @@ class GitClient(object):
|
|
|
activity.
|
|
|
"""
|
|
|
self._report_activity = report_activity
|
|
|
+ self._report_status_parser = None
|
|
|
self._fetch_capabilities = set(FETCH_CAPABILITIES)
|
|
|
self._send_capabilities = set(SEND_CAPABILITIES)
|
|
|
if not thin_packs:
|
|
|
self._fetch_capabilities.remove('thin-pack')
|
|
|
|
|
|
- def _read_refs(self, proto):
|
|
|
- server_capabilities = None
|
|
|
- refs = {}
|
|
|
- # Receive refs from server
|
|
|
- for pkt in proto.read_pkt_seq():
|
|
|
- (sha, ref) = pkt.rstrip('\n').split(' ', 1)
|
|
|
- if sha == 'ERR':
|
|
|
- raise GitProtocolError(ref)
|
|
|
- if server_capabilities is None:
|
|
|
- (ref, server_capabilities) = extract_capabilities(ref)
|
|
|
- refs[ref] = sha
|
|
|
-
|
|
|
- if len(refs) == 0:
|
|
|
- return None, set([])
|
|
|
- return refs, set(server_capabilities)
|
|
|
-
|
|
|
def send_pack(self, path, determine_wants, generate_pack_contents,
|
|
|
progress=None):
|
|
|
"""Upload a pack to a remote repository.
|
|
@@ -201,10 +211,15 @@ class GitClient(object):
|
|
|
"""
|
|
|
if determine_wants is None:
|
|
|
determine_wants = target.object_store.determine_wants_all
|
|
|
- f, commit = target.object_store.add_pack()
|
|
|
- result = self.fetch_pack(path, determine_wants,
|
|
|
- target.get_graph_walker(), f.write, progress)
|
|
|
- commit()
|
|
|
+ f, commit, abort = target.object_store.add_pack()
|
|
|
+ try:
|
|
|
+ result = self.fetch_pack(path, determine_wants,
|
|
|
+ target.get_graph_walker(), f.write, progress)
|
|
|
+ except:
|
|
|
+ abort()
|
|
|
+ raise
|
|
|
+ else:
|
|
|
+ commit()
|
|
|
return result
|
|
|
|
|
|
def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
|
|
@@ -288,9 +303,11 @@ class GitClient(object):
|
|
|
want = []
|
|
|
have = [x for x in old_refs.values() if not x == ZERO_SHA]
|
|
|
sent_capabilities = False
|
|
|
+
|
|
|
for refname in set(new_refs.keys() + old_refs.keys()):
|
|
|
old_sha1 = old_refs.get(refname, ZERO_SHA)
|
|
|
new_sha1 = new_refs.get(refname, ZERO_SHA)
|
|
|
+
|
|
|
if old_sha1 != new_sha1:
|
|
|
if sent_capabilities:
|
|
|
proto.write_pkt_line('%s %s %s' % (old_sha1, new_sha1,
|
|
@@ -312,24 +329,20 @@ class GitClient(object):
|
|
|
:param capabilities: List of negotiated capabilities
|
|
|
:param progress: Optional progress reporting function
|
|
|
"""
|
|
|
- if 'report-status' in capabilities:
|
|
|
- report_status_parser = ReportStatusParser()
|
|
|
- else:
|
|
|
- report_status_parser = None
|
|
|
if "side-band-64k" in capabilities:
|
|
|
if progress is None:
|
|
|
progress = lambda x: None
|
|
|
channel_callbacks = { 2: progress }
|
|
|
if 'report-status' in capabilities:
|
|
|
channel_callbacks[1] = PktLineParser(
|
|
|
- report_status_parser.handle_packet).parse
|
|
|
+ self._report_status_parser.handle_packet).parse
|
|
|
self._read_side_band64k_data(proto, channel_callbacks)
|
|
|
else:
|
|
|
if 'report-status' in capabilities:
|
|
|
for pkt in proto.read_pkt_seq():
|
|
|
- report_status_parser.handle_packet(pkt)
|
|
|
- if report_status_parser is not None:
|
|
|
- report_status_parser.check()
|
|
|
+ self._report_status_parser.handle_packet(pkt)
|
|
|
+ if self._report_status_parser is not None:
|
|
|
+ self._report_status_parser.check()
|
|
|
# wait for EOF before returning
|
|
|
data = proto.read()
|
|
|
if data:
|
|
@@ -402,7 +415,7 @@ class GitClient(object):
|
|
|
raise Exception('Unexpected response %r' % data)
|
|
|
else:
|
|
|
while True:
|
|
|
- data = self.read(rbufsize)
|
|
|
+ data = proto.read(rbufsize)
|
|
|
if data == "":
|
|
|
break
|
|
|
pack_data(data)
|
|
@@ -439,16 +452,48 @@ class TraditionalGitClient(GitClient):
|
|
|
and rejects ref updates
|
|
|
"""
|
|
|
proto, unused_can_read = self._connect('receive-pack', path)
|
|
|
- old_refs, server_capabilities = self._read_refs(proto)
|
|
|
+ old_refs, server_capabilities = read_pkt_refs(proto)
|
|
|
negotiated_capabilities = self._send_capabilities & server_capabilities
|
|
|
+
|
|
|
+ if 'report-status' in negotiated_capabilities:
|
|
|
+ self._report_status_parser = ReportStatusParser()
|
|
|
+ report_status_parser = self._report_status_parser
|
|
|
+
|
|
|
try:
|
|
|
- new_refs = determine_wants(dict(old_refs))
|
|
|
+ new_refs = orig_new_refs = determine_wants(dict(old_refs))
|
|
|
except:
|
|
|
proto.write_pkt_line(None)
|
|
|
raise
|
|
|
+
|
|
|
+ if not 'delete-refs' in server_capabilities:
|
|
|
+ # Server does not support deletions. Fail later.
|
|
|
+ def remove_del(pair):
|
|
|
+ if pair[1] == ZERO_SHA:
|
|
|
+ if 'report-status' in negotiated_capabilities:
|
|
|
+ report_status_parser._ref_statuses.append(
|
|
|
+ 'ng %s remote does not support deleting refs'
|
|
|
+ % pair[1])
|
|
|
+ report_status_parser._ref_status_ok = False
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ return True
|
|
|
+
|
|
|
+ new_refs = dict(
|
|
|
+ filter(
|
|
|
+ remove_del,
|
|
|
+ [(ref, sha) for ref, sha in new_refs.iteritems()]))
|
|
|
+
|
|
|
if new_refs is None:
|
|
|
proto.write_pkt_line(None)
|
|
|
return old_refs
|
|
|
+
|
|
|
+ if len(new_refs) == 0 and len(orig_new_refs):
|
|
|
+ # NOOP - Original new refs filtered out by policy
|
|
|
+ proto.write_pkt_line(None)
|
|
|
+ if self._report_status_parser is not None:
|
|
|
+ self._report_status_parser.check()
|
|
|
+ return old_refs
|
|
|
+
|
|
|
(have, want) = self._handle_receive_pack_head(proto,
|
|
|
negotiated_capabilities, old_refs, new_refs)
|
|
|
if not want and old_refs == new_refs:
|
|
@@ -456,6 +501,15 @@ class TraditionalGitClient(GitClient):
|
|
|
objects = generate_pack_contents(have, want)
|
|
|
if len(objects) > 0:
|
|
|
entries, sha = write_pack_objects(proto.write_file(), objects)
|
|
|
+ elif len(set(new_refs.values()) - set([ZERO_SHA])) > 0:
|
|
|
+ # Check for valid create/update refs
|
|
|
+ filtered_new_refs = \
|
|
|
+ dict([(ref, sha) for ref, sha in new_refs.iteritems()
|
|
|
+ if sha != ZERO_SHA])
|
|
|
+ if len(set(filtered_new_refs.iteritems()) -
|
|
|
+ set(old_refs.iteritems())) > 0:
|
|
|
+ entries, sha = write_pack_objects(proto.write_file(), objects)
|
|
|
+
|
|
|
self._handle_receive_pack_tail(proto, negotiated_capabilities,
|
|
|
progress)
|
|
|
return new_refs
|
|
@@ -470,7 +524,7 @@ class TraditionalGitClient(GitClient):
|
|
|
:param progress: Callback for progress reports (strings)
|
|
|
"""
|
|
|
proto, can_read = self._connect('upload-pack', path)
|
|
|
- refs, server_capabilities = self._read_refs(proto)
|
|
|
+ refs, server_capabilities = read_pkt_refs(proto)
|
|
|
negotiated_capabilities = self._fetch_capabilities & server_capabilities
|
|
|
|
|
|
if refs is None:
|
|
@@ -597,8 +651,33 @@ class SubprocessGitClient(TraditionalGitClient):
|
|
|
|
|
|
|
|
|
class SSHVendor(object):
|
|
|
+ """A client side SSH implementation."""
|
|
|
|
|
|
def connect_ssh(self, host, command, username=None, port=None):
|
|
|
+ import warnings
|
|
|
+ warnings.warn(
|
|
|
+ "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command",
|
|
|
+ DeprecationWarning)
|
|
|
+ return self.run_command(host, command, username=username, port=port)
|
|
|
+
|
|
|
+ def run_command(self, host, command, username=None, port=None):
|
|
|
+ """Connect to an SSH server.
|
|
|
+
|
|
|
+ Run a command remotely and return a file-like object for interaction
|
|
|
+ with the remote command.
|
|
|
+
|
|
|
+ :param host: Host name
|
|
|
+ :param command: Command to run
|
|
|
+ :param username: Optional ame of user to log in as
|
|
|
+ :param port: Optional SSH port to use
|
|
|
+ """
|
|
|
+ raise NotImplementedError(self.run_command)
|
|
|
+
|
|
|
+
|
|
|
+class SubprocessSSHVendor(SSHVendor):
|
|
|
+ """SSH vendor that shells out to the local 'ssh' command."""
|
|
|
+
|
|
|
+ def run_command(self, host, command, username=None, port=None):
|
|
|
import subprocess
|
|
|
#FIXME: This has no way to deal with passwords..
|
|
|
args = ['ssh', '-x']
|
|
@@ -612,8 +691,112 @@ class SSHVendor(object):
|
|
|
stdout=subprocess.PIPE)
|
|
|
return SubprocessWrapper(proc)
|
|
|
|
|
|
+
|
|
|
+try:
|
|
|
+ import paramiko
|
|
|
+except ImportError:
|
|
|
+ pass
|
|
|
+else:
|
|
|
+ import threading
|
|
|
+
|
|
|
+ class ParamikoWrapper(object):
|
|
|
+ STDERR_READ_N = 2048 # 2k
|
|
|
+
|
|
|
+ def __init__(self, client, channel, progress_stderr=None):
|
|
|
+ self.client = client
|
|
|
+ self.channel = channel
|
|
|
+ self.progress_stderr = progress_stderr
|
|
|
+ self.should_monitor = bool(progress_stderr) or True
|
|
|
+ self.monitor_thread = None
|
|
|
+ self.stderr = ''
|
|
|
+
|
|
|
+ # Channel must block
|
|
|
+ self.channel.setblocking(True)
|
|
|
+
|
|
|
+ # Start
|
|
|
+ if self.should_monitor:
|
|
|
+ self.monitor_thread = threading.Thread(target=self.monitor_stderr)
|
|
|
+ self.monitor_thread.start()
|
|
|
+
|
|
|
+ def monitor_stderr(self):
|
|
|
+ while self.should_monitor:
|
|
|
+ # Block and read
|
|
|
+ data = self.read_stderr(self.STDERR_READ_N)
|
|
|
+
|
|
|
+ # Socket closed
|
|
|
+ if not data:
|
|
|
+ self.should_monitor = False
|
|
|
+ break
|
|
|
+
|
|
|
+ # Emit data
|
|
|
+ if self.progress_stderr:
|
|
|
+ self.progress_stderr(data)
|
|
|
+
|
|
|
+ # Append to buffer
|
|
|
+ self.stderr += data
|
|
|
+
|
|
|
+ def stop_monitoring(self):
|
|
|
+ # Stop StdErr thread
|
|
|
+ if self.should_monitor:
|
|
|
+ self.should_monitor = False
|
|
|
+ self.monitor_thread.join()
|
|
|
+
|
|
|
+ # Get left over data
|
|
|
+ data = self.channel.in_stderr_buffer.empty()
|
|
|
+ self.stderr += data
|
|
|
+
|
|
|
+ def can_read(self):
|
|
|
+ return self.channel.recv_ready()
|
|
|
+
|
|
|
+ def write(self, data):
|
|
|
+ return self.channel.sendall(data)
|
|
|
+
|
|
|
+ def read_stderr(self, n):
|
|
|
+ return self.channel.recv_stderr(n)
|
|
|
+
|
|
|
+ def read(self, n=None):
|
|
|
+ data = self.channel.recv(n)
|
|
|
+ data_len = len(data)
|
|
|
+
|
|
|
+ # Closed socket
|
|
|
+ if not data:
|
|
|
+ return
|
|
|
+
|
|
|
+ # Read more if needed
|
|
|
+ if n and data_len < n:
|
|
|
+ diff_len = n - data_len
|
|
|
+ return data + self.read(diff_len)
|
|
|
+ return data
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ self.channel.close()
|
|
|
+ self.stop_monitoring()
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ self.close()
|
|
|
+
|
|
|
+ class ParamikoSSHVendor(object):
|
|
|
+
|
|
|
+ def run_command(self, host, command, username=None, port=None,
|
|
|
+ progress_stderr=None, **kwargs):
|
|
|
+ client = paramiko.SSHClient()
|
|
|
+
|
|
|
+ policy = paramiko.client.MissingHostKeyPolicy()
|
|
|
+ client.set_missing_host_key_policy(policy)
|
|
|
+ client.connect(host, username=username, port=port, **kwargs)
|
|
|
+
|
|
|
+ # Open SSH session
|
|
|
+ channel = client.get_transport().open_session()
|
|
|
+
|
|
|
+ # Run commands
|
|
|
+ apply(channel.exec_command, command)
|
|
|
+
|
|
|
+ return ParamikoWrapper(client, channel,
|
|
|
+ progress_stderr=progress_stderr)
|
|
|
+
|
|
|
+
|
|
|
# Can be overridden by users
|
|
|
-get_ssh_vendor = SSHVendor
|
|
|
+get_ssh_vendor = SubprocessSSHVendor
|
|
|
|
|
|
|
|
|
class SSHGitClient(TraditionalGitClient):
|
|
@@ -631,7 +814,7 @@ class SSHGitClient(TraditionalGitClient):
|
|
|
def _connect(self, cmd, path):
|
|
|
if path.startswith("/~"):
|
|
|
path = path[1:]
|
|
|
- con = get_ssh_vendor().connect_ssh(
|
|
|
+ con = get_ssh_vendor().run_command(
|
|
|
self.host, ["%s '%s'" % (self._get_cmd_path(cmd), path)],
|
|
|
port=self.port, username=self.username)
|
|
|
return (Protocol(con.read, con.write, report_activity=self._report_activity),
|
|
@@ -678,14 +861,16 @@ class HttpGitClient(GitClient):
|
|
|
headers["Content-Type"] = "application/x-%s-request" % service
|
|
|
resp = self._http_request(url, headers)
|
|
|
self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
|
|
|
- proto = Protocol(resp.read, None)
|
|
|
if not self.dumb:
|
|
|
+ proto = Protocol(resp.read, None)
|
|
|
# The first line should mention the service
|
|
|
pkts = list(proto.read_pkt_seq())
|
|
|
if pkts != [('# service=%s\n' % service)]:
|
|
|
raise GitProtocolError(
|
|
|
"unexpected first line %r from smart server" % pkts)
|
|
|
- return self._read_refs(proto)
|
|
|
+ return read_pkt_refs(proto)
|
|
|
+ else:
|
|
|
+ return read_info_refs(resp), set()
|
|
|
|
|
|
def _smart_request(self, service, url, data):
|
|
|
assert url[-1] == "/"
|
|
@@ -714,6 +899,10 @@ class HttpGitClient(GitClient):
|
|
|
old_refs, server_capabilities = self._discover_references(
|
|
|
"git-receive-pack", url)
|
|
|
negotiated_capabilities = self._send_capabilities & server_capabilities
|
|
|
+
|
|
|
+ if 'report-status' in negotiated_capabilities:
|
|
|
+ self._report_status_parser = ReportStatusParser()
|
|
|
+
|
|
|
new_refs = determine_wants(dict(old_refs))
|
|
|
if new_refs is None:
|
|
|
return old_refs
|
|
@@ -748,7 +937,7 @@ class HttpGitClient(GitClient):
|
|
|
url = self._get_url(path)
|
|
|
refs, server_capabilities = self._discover_references(
|
|
|
"git-upload-pack", url)
|
|
|
- negotiated_capabilities = server_capabilities
|
|
|
+ negotiated_capabilities = self._fetch_capabilities & server_capabilities
|
|
|
wants = determine_wants(refs)
|
|
|
if wants is not None:
|
|
|
wants = [cid for cid in wants if cid != ZERO_SHA]
|