123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468 |
- # server.py -- Implementation of the server side git protocols
- # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
- #
- # This program is free software; you can redistribute it and/or
- # modify it under the terms of the GNU General Public License
- # as published by the Free Software Foundation; version 2
- # or (at your option) any later version of the License.
- #
- # This program is distributed in the hope that it will be useful,
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- # GNU General Public License for more details.
- #
- # You should have received a copy of the GNU General Public License
- # along with this program; if not, write to the Free Software
- # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
- # MA 02110-1301, USA.
- """Git smart network protocol server implementation.
- For more detailed information on the network protocol, see the
- Documentation/technical directory in the C Git distribution, and in particular:
- Documentation/technical/protocol-capabilities.txt
- Documentation/technical/pack-protocol.txt
- """
- import collections
- import SocketServer
- import tempfile
- from dulwich.errors import (
- GitProtocolError,
- )
- from dulwich.objects import (
- hex_to_sha,
- )
- from dulwich.protocol import (
- Protocol,
- ProtocolFile,
- TCP_GIT_PORT,
- extract_capabilities,
- extract_want_line_capabilities,
- SINGLE_ACK,
- MULTI_ACK,
- ack_type,
- )
- from dulwich.repo import (
- Repo,
- )
- from dulwich.pack import (
- write_pack_data,
- )
class Backend(object):
    """Interface to the repository storage that protocol handlers talk to.

    Concrete backends override every method below; the base versions only
    raise NotImplementedError.
    """

    def get_refs(self):
        """Get all the refs in the repository.

        :return: dict of name -> sha
        """
        raise NotImplementedError

    def apply_pack(self, refs, read):
        """Import a set of changes into a repository and update the refs.

        :param refs: sequence of (oldsha, newsha, refname) triples, one per
            ref update line sent by the client
        :param read: callback to read from the incoming pack
        """
        raise NotImplementedError

    def fetch_objects(self, determine_wants, graph_walker, progress):
        """Yield the objects required for a list of commits.

        :param determine_wants: callback that receives a dict of
            refname -> SHA1 and returns the list of SHA1s the client wants
        :param graph_walker: graph walker (next()/ack()) used to find out
            which commits the client already has
        :param progress: callback to send progress messages to the client
        """
        raise NotImplementedError
class GitBackend(Backend):
    """Backend that serves a dulwich Repo.

    :param repo: repository to serve; if omitted, a Repo is opened on a newly
        created temporary directory.
    """

    def __init__(self, repo=None):
        if repo is None:
            # NOTE(review): Repo() is given a freshly created, empty directory
            # here; confirm Repo accepts a path that was never initialised.
            repo = Repo(tempfile.mkdtemp())
        self.repo = repo
        self.object_store = self.repo.object_store
        # Reads are delegated directly to the repository object.
        self.fetch_objects = self.repo.fetch_objects
        self.get_refs = self.repo.get_refs

    def apply_pack(self, refs, read):
        """Store an incoming pack and apply the requested ref updates.

        :param refs: iterable of (oldsha, newsha, refname) triples
        :param read: callable that returns the pack data when called with
            no arguments
        """
        f, commit = self.repo.object_store.add_thin_pack()
        try:
            f.write(read())
        finally:
            # commit() runs even if read() or write() raised.
            commit()

        for oldsha, sha, ref in refs:
            # An all-zero *new* sha is how a client asks for the ref to be
            # deleted (the "delete-refs" capability).
            if sha == "0" * 40:
                del self.repo.refs[ref]
            else:
                self.repo.refs[ref] = sha

        print("pack applied")
class Handler(object):
    """Smart protocol command handler base class.

    Wraps the read/write callables in a Protocol and keeps a reference to
    the backend. Subclasses provide default_capabilities(), a sequence of
    capability names that capabilities() turns into the advertised string.
    """

    def __init__(self, backend, read, write):
        self.backend = backend
        self.proto = Protocol(read, write)

    def capabilities(self):
        """Return the advertised capabilities as one space-separated string."""
        advertised = self.default_capabilities()
        separator = " "
        return separator.join(advertised)
class UploadPackHandler(Handler):
    """Protocol handler for uploading a pack to the client.

    Serves git-upload-pack: refs are advertised through a
    ProtocolGraphWalker, the client's wants and haves are negotiated, and
    the resulting objects are written back to the client as a pack.
    """

    def __init__(self, backend, read, write):
        Handler.__init__(self, backend, read, write)
        self._client_capabilities = None
        self._graph_walker = None

    def default_capabilities(self):
        """Return the capabilities advertised to upload-pack clients."""
        return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")

    def set_client_capabilities(self, caps):
        """Record the capabilities the client requested.

        :param caps: capability names parsed from the client's first want line
        :raise GitProtocolError: if the client requests a capability whose
            name contains '_ack' that was not advertised
        """
        my_caps = self.default_capabilities()
        for cap in caps:
            # Only ack-related capabilities are validated here.
            if '_ack' in cap and cap not in my_caps:
                raise GitProtocolError('Client asked for capability %s that '
                                       'was not advertised.' % cap)
        self._client_capabilities = caps

    def get_client_capabilities(self):
        """Return the capabilities recorded by set_client_capabilities()."""
        return self._client_capabilities

    client_capabilities = property(get_client_capabilities,
                                   set_client_capabilities)

    def handle(self):
        """Run one upload-pack exchange and stream the pack to the client.

        Progress messages go out on side-band channel 2 and pack data on
        channel 1; the exchange ends with a "0000" flush packet.
        """
        # NOTE(review): side-band framing is used even if the client did not
        # request side-band-64k; confirm all supported clients ask for it.
        progress = lambda x: self.proto.write_sideband(2, x)
        write = lambda x: self.proto.write_sideband(1, x)

        graph_walker = ProtocolGraphWalker(self)
        objects_iter = self.backend.fetch_objects(
            graph_walker.determine_wants, graph_walker, progress)

        # Do they want any objects?
        num_objects = len(objects_iter)
        if num_objects == 0:
            return

        progress("dul-daemon says what\n")
        progress("counting objects: %d, done.\n" % num_objects)
        write_pack_data(ProtocolFile(None, write), objects_iter, num_objects)
        progress("how was that, then?\n")
        # we are done
        self.proto.write("0000")
class ProtocolGraphWalker(object):
    """A graph walker that knows the git protocol.

    As a graph walker, this class implements ack(), next(), and reset(). It
    also contains some base methods for interacting with the wire and walking
    the commit tree.

    The work of determining which acks to send is passed on to the
    implementation instance stored in _impl. The reason for this is that we do
    not know at object creation time what ack level the protocol requires. A
    call to set_ack_type() is required to set up the implementation, before
    any calls to next() or ack() are made.
    """

    def __init__(self, handler):
        self.handler = handler
        self.store = handler.backend.object_store
        self.proto = handler.proto
        self._wants = []
        # Every have returned by next() is recorded in _cache so that
        # reset() can make next() replay them from the beginning.
        self._cached = False
        self._cache = []
        self._cache_index = 0
        self._impl = None

    def determine_wants(self, heads):
        """Determine the wants for a set of heads.

        The given heads are advertised to the client, which then specifies
        which refs it wants using 'want' lines. This portion of the protocol
        is the same regardless of ack type, and in fact is used to set the ack
        type of the ProtocolGraphWalker.

        :param heads: a dict of refname->SHA1 to advertise
        :return: a list of SHA1s requested by the client
        :raise GitProtocolError: if there are no heads, if the client sends a
            command other than 'want', or if it wants an object that was not
            advertised
        """
        if not heads:
            raise GitProtocolError('No heads found')
        values = set(heads.values())
        for i, (ref, sha) in enumerate(heads.items()):
            line = "%s %s" % (sha, ref)
            if not i:
                # Capabilities ride on the first ref line, after a NUL byte.
                line = "%s\x00%s" % (line, self.handler.capabilities())
            self.proto.write_pkt_line("%s\n" % line)
            # TODO: include peeled value of any tags

        # A None pkt-line ends the advertisement.
        self.proto.write_pkt_line(None)

        # Now the client sends a series of 'want' lines.
        want = self.proto.read_pkt_line()
        if not want:
            return []
        line, caps = extract_want_line_capabilities(want)
        self.handler.client_capabilities = caps
        self.set_ack_type(ack_type(caps))
        command, sha = self._split_proto_line(line)

        want_revs = []
        while command is not None:
            if command != 'want':
                raise GitProtocolError(
                    'Protocol got unexpected command %s' % command)
            if sha not in values:
                raise GitProtocolError(
                    'Client wants invalid object %s' % sha)
            want_revs.append(sha)
            command, sha = self.read_proto_line()

        self.set_wants(want_revs)
        return want_revs

    def ack(self, have_ref):
        """Forward the acknowledgement of have_ref to the ack implementation."""
        return self._impl.ack(have_ref)

    def reset(self):
        """Make subsequent next() calls replay the recorded haves in order."""
        self._cached = True
        self._cache_index = 0

    def next(self):
        """Return the next 'have' SHA1, or None when there are no more.

        Before reset(), haves come from the ack implementation and are
        recorded. After reset(), the recorded haves are returned again in
        their original order.
        """
        if not self._cached:
            if not self._impl:
                return None
            sha = self._impl.next()
            if sha is not None:
                self._cache.append(sha)
            return sha
        # Read at the current position before advancing, so replay starts at
        # the first recorded have and stops cleanly at the end.
        if self._cache_index >= len(self._cache):
            return None
        sha = self._cache[self._cache_index]
        self._cache_index += 1
        return sha

    def _split_proto_line(self, line):
        """Parse a 'want <sha>', 'have <sha>' or 'done' line.

        :raise GitProtocolError: if the line is malformed or the sha is not
            valid hex
        """
        fields = line.rstrip('\n').split(' ', 1)
        if len(fields) == 1 and fields[0] == 'done':
            return ('done', None)
        elif len(fields) == 2 and fields[0] in ('want', 'have'):
            try:
                hex_to_sha(fields[1])
            except (TypeError, AssertionError):
                # sys.exc_info() fetches the active exception without the
                # version-specific 'except X, e' / 'except X as e' syntax.
                import sys
                raise GitProtocolError(sys.exc_info()[1])
            return tuple(fields)
        raise GitProtocolError('Received invalid line from client:\n%s' % line)

    def read_proto_line(self):
        """Read a line from the wire.

        :return: a tuple having one of the following forms:
            ('want', obj_id)
            ('have', obj_id)
            ('done', None)
            (None, None) (for a flush-pkt)
        """
        line = self.proto.read_pkt_line()
        if not line:
            return (None, None)
        return self._split_proto_line(line)

    def send_ack(self, sha, ack_type=''):
        """Write 'ACK <sha>' with an optional ack type suffix."""
        if ack_type:
            ack_type = ' %s' % ack_type
        self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))

    def send_nak(self):
        """Write a 'NAK' line."""
        self.proto.write_pkt_line('NAK\n')

    def set_wants(self, wants):
        """Record the wants checked by all_wants_satisfied()."""
        self._wants = wants

    def _is_satisfied(self, haves, want, earliest):
        """Check whether a want is satisfied by a set of haves.

        A want, typically a branch tip, is "satisfied" only if there exists a
        path back from that want to one of the haves.

        :param haves: A set of commits we know the client has.
        :param want: The want to check satisfaction for.
        :param earliest: A timestamp beyond which the search for haves will be
            terminated, presumably because we're searching too far down the
            wrong branch.
        """
        o = self.store[want]
        pending = collections.deque([o])
        # Shared ancestors (e.g. below merges) are only queued once.
        seen = set([want])
        while pending:
            commit = pending.popleft()
            if commit.id in haves:
                return True
            if not getattr(commit, 'get_parents', None):
                # Objects without get_parents() have no history to follow,
                # so this path ends here.
                # NOTE(review): an older comment claimed non-commit wants are
                # "assumed to be satisfied", but they only count as satisfied
                # if they are themselves in haves. Confirm which is intended.
                continue
            for parent in commit.get_parents():
                if parent in seen:
                    continue
                seen.add(parent)
                parent_obj = self.store[parent]
                # TODO: handle parents with later commit times than children
                if parent_obj.commit_time >= earliest:
                    pending.append(parent_obj)
        return False

    def all_wants_satisfied(self, haves):
        """Check whether all the current wants are satisfied by a set of haves.

        :param haves: A set of commits we know the client has.
        :return: True if every want can reach one of the haves (trivially
            True when there are no wants)
        :note: Wants are specified with set_wants rather than passed in since
            in the current interface they are determined outside this class.
        """
        haves = set(haves)
        if not haves:
            # Nothing is in common yet, so no want can be reached. Handling
            # this here also avoids calling min() on an empty sequence.
            return not self._wants
        earliest = min([self.store[h].commit_time for h in haves])
        for want in self._wants:
            if not self._is_satisfied(haves, want, earliest):
                return False
        return True

    def set_ack_type(self, ack_type):
        """Install the ack implementation matching the negotiated ack type."""
        impl_classes = {
            MULTI_ACK: MultiAckGraphWalkerImpl,
            SINGLE_ACK: SingleAckGraphWalkerImpl,
        }
        self._impl = impl_classes[ack_type](self)
class SingleAckGraphWalkerImpl(object):
    """Ack strategy used when the client negotiated plain single-ack mode.

    At most one ACK is sent, for the first have that is acknowledged. If the
    negotiation ends (flush or 'done') without any ACK, a NAK is sent.
    """

    def __init__(self, walker):
        self.walker = walker
        self._sent_ack = False

    def ack(self, have_ref):
        """Send an ACK for have_ref unless one has already been sent."""
        if self._sent_ack:
            return
        self.walker.send_ack(have_ref)
        self._sent_ack = True

    def next(self):
        """Return the next have sha, or None once negotiation is over."""
        command, sha = self.walker.read_proto_line()
        if command == 'have':
            return sha
        finished = command is None or command == 'done'
        if finished and not self._sent_ack:
            self.walker.send_nak()
        # Finished negotiation and unexpected commands both yield None.
        return None
class MultiAckGraphWalkerImpl(object):
    """Ack strategy used when the client negotiated multi-ack mode.

    Each acknowledged have gets an 'ACK <sha> continue' until every want is
    satisfied by the common commits found so far. After that, incoming haves
    are acked blindly. On 'done', the last common commit is ACKed, or a NAK
    is sent if none was found.
    """

    def __init__(self, walker):
        self.walker = walker
        self._found_base = False
        self._common = []

    def ack(self, have_ref):
        """Record have_ref as common and ACK it while still searching."""
        self._common.append(have_ref)
        if self._found_base:
            # Once a base is found, haves are acked blindly in next().
            return
        self.walker.send_ack(have_ref, 'continue')
        if self.walker.all_wants_satisfied(self._common):
            self._found_base = True

    def next(self):
        """Return the next have sha, or None after the client sends 'done'."""
        while True:
            command, sha = self.walker.read_proto_line()
            if command == 'have':
                if self._found_base:
                    # blind ack
                    self.walker.send_ack(sha, 'continue')
                return sha
            if command == 'done':
                # Only NAK when no common commit was found at all, even if
                # not every want is satisfied.
                if not self._common:
                    self.walker.send_nak()
                else:
                    self.walker.send_ack(self._common[-1])
                return None
            if command is None:
                # A flush-pkt in multi-ack mode gets a NAK, but more have
                # lines may still follow, so keep reading.
                self.walker.send_nak()
class ReceivePackHandler(Handler):
    """Protocol handler for downloading a pack from the client.

    Serves git-receive-pack: the server advertises its refs, reads the
    client's ref update lines, and hands them to the backend together with a
    callback for reading the pack data that follows.
    """

    def default_capabilities(self):
        """Return the capabilities advertised to receive-pack clients."""
        return ("report-status", "delete-refs")

    def handle(self):
        """Run one receive-pack exchange."""
        refs = list(self.backend.get_refs().items())

        if refs:
            # Capabilities ride on the first ref line, after a NUL byte.
            first_name, first_sha = refs[0]
            self.proto.write_pkt_line(
                "%s %s\x00%s\n" % (first_sha, first_name, self.capabilities()))
            for name, sha in refs[1:]:
                self.proto.write_pkt_line("%s %s\n" % (sha, name))
        else:
            # With no refs, capabilities are advertised on a placeholder
            # "capabilities^{}" ref at the all-zero sha. It uses the same NUL
            # separator and trailing newline as a real first ref line.
            self.proto.write_pkt_line(
                "%s capabilities^{}\x00%s\n" % ("0" * 40, self.capabilities()))

        self.proto.write("0000")

        client_refs = []
        ref = self.proto.read_pkt_line()

        # if ref is None then the client doesn't want to send us anything
        if ref is None:
            return

        # client_capabilities is currently unused.
        ref, client_capabilities = extract_capabilities(ref)

        # The client now sends (oldsha, newsha, ref) lines until an empty or
        # None pkt-line.
        while ref:
            client_refs.append(ref.split())
            ref = self.proto.read_pkt_line()

        # The backend applies these ref updates and reads the pack data
        # through self.proto.read.
        self.backend.apply_pack(client_refs, self.proto.read)

        # When we have read all the pack from the client, it assumes
        # everything worked OK.
        # There is NO ack from the server before it reports victory.
        # NOTE(review): "report-status" is advertised, but no status report
        # ("unpack ok" / "ok <ref>") is ever written back. Clients that
        # request it may wait for one; implement it or stop advertising it.
class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
    """Dispatch a single TCP connection to the handler for its git command."""

    def handle(self):
        """Read the requested command and run the matching protocol handler.

        Commands other than git-upload-pack and git-receive-pack are ignored.
        """
        proto = Protocol(self.rfile.read, self.wfile.write)
        command, args = proto.read_cmd()

        # Map git service names to the handler classes implementing them.
        handler_classes = {
            'git-upload-pack': UploadPackHandler,
            'git-receive-pack': ReceivePackHandler,
        }
        handler_cls = handler_classes.get(command)
        if handler_cls is None:
            return

        handler = handler_cls(self.server.backend, self.rfile.read,
                              self.wfile.write)
        handler.handle()
class TCPGitServer(SocketServer.TCPServer):
    """TCP server that answers git smart-protocol requests.

    Each connection is handled by TCPGitRequestHandler, which reaches the
    repository through the backend stored on this server. Listens on
    TCP_GIT_PORT unless another port is given.
    """

    allow_reuse_address = True
    # Convenience alias: server.serve() runs the accept loop forever.
    serve = SocketServer.TCPServer.serve_forever

    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
        self.backend = backend
        server_address = (listen_addr, port)
        SocketServer.TCPServer.__init__(self, server_address,
                                        TCPGitRequestHandler)
|