Prechádzať zdrojové kódy

import pypubsub from puppet

Daniel Gruno 5 rokov pred
rodič
commit
71edc34ec2
1 zmenil súbory, kde vykonal 251 pridanie a 0 odobranie
  1. 251 0
      pypubsub.py

+ 251 - 0
pypubsub.py

@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" This is the ASF's simplified publisher/subscriber service """
+import asyncio
+import aiohttp.web
+import time
+import json
+import yaml
+import netaddr
+import binascii
+import base64
+
+# Some consts
+PUBSUB_VERSION = '0.3.0'
+PUBSUB_BAD_REQUEST = "I could not understand your request, sorry! Please see https://pubsub.apache.org/api.html \
+for usage documentation.\n"
+PUBSUB_PAYLOAD_RECEIVED = "Payload received, thank you very much!\n"
+PUBSUB_NOT_ALLOWED = "You are not authorized to deliver payloads!\n"
+PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects, {..}!\n"
+CONF = None
+ACL = None
+OLD_SCHOOLERS = ['svnwcsub', ]  # Old-school clients that use \0 terminators.
+
+# Internal score-keeping vars
+NUM_REQUESTS = 0   # Number of total requests served
+SUBSCRIBERS = []   # Current subscribers to everything
+PENDING_PUBS = []  # Payloads pending publication
+LAST_PING = 0      # Last time we did a global ping (every 5 seconds)
+PAYLOADERS = []    # IPs that can deliver payloads
+
+
+class Subscriber:
+    """ Basic subscriber (client) class.
+        Holds information about the connection and ACL
+    """
+    def __init__(self, connection, request):
+        self.connection = connection
+        self.request = request
+        self.acl = {}
+
+        # Set topics subscribed to
+        self.topics = [x for x in request.path.split('/') if x]
+
+        # Is the client old and expecting zero-terminators?
+        self.old_school = False
+        for ua in OLD_SCHOOLERS:
+            if ua in request.headers.get('User-Agent', ''):
+                self.old_school = True
+                break
+        # Is there a basic auth in this request? If so, set up ACL
+        auth = request.headers.get('Authorization')
+        if auth:
+            self.set_acl(auth)
+
+    def set_acl(self, basic):
+        """ Sets the ACL if possible, based on Basic Auth """
+        try:
+            decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'utf-8')), 'utf-8')
+            u, p = decoded.split(':', 1)
+            if u in ACL:
+                acl_pass = ACL[u].get('password')
+                if acl_pass and acl_pass == p:
+                    self.acl = ACL[u].get('acl', {})
+                    # Vet ACL for user
+                    if not isinstance(self.acl, dict):
+                        raise AssertionError(f"ACL for user {u} must be a dictionary of sub-IDs and topics, but is not.")
+                    # Make sure each ACL segment is a list of topics
+                    for k, v in self.acl.items():
+                        if not isinstance(v, list):
+                            raise AssertionError(f"ACL segment {k} for user {u} is not a list of topics!")
+                    print(f"Client {u} successfully authenticated (and ACL is valid).")
+        except binascii.Error as e:
+            self.acl = {}
+        except AssertionError as e:
+            print(e)
+            print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
+            self.acl = {}
+        except Exception as e:
+            print(f"Basic unknown exception occurred: {e}")
+
+    async def ping(self):
+        """ Generic ping-back to the client """
+        js = b"%s\n" % json.dumps({"stillalive": time.time()}).encode('utf-8')
+        if self.old_school:
+            js += b"\0"
+        await self.connection.write(js)
+
+
+class Payload:
+    """ A payload (event) object sent by a registered publisher. """
+    def __init__(self, path, data):
+        self.json = data
+        self.topics = [x for x in path.split('/') if x]
+        self.path = path
+        self.private = False
+
+        # Private payload?
+        if self.topics[0] == 'private':
+            self.private = True
+            self.topics = self.topics[1:]  # Remove the private bit from topics now.
+
+        self.json['pubsub_topics'] = self.topics
+        self.json['pubsub_path'] = self.path
+
+    async def publish(self, subscribers):
+        """ Publishes an object to all subscribers using those topics (or a sub-set thereof) """
+        js = b"%s\n" % json.dumps(self.json).encode('utf-8')
+        ojs = js + b"\0"
+        for sub in subscribers:
+            # If a private payload, check ACL and bail if not a match
+            if self.private:
+                can_see = False
+                for key, private_topics in sub.acl.items():
+                    if all(el in self.topics for el in private_topics):
+                        can_see = True
+                        break
+                if not can_see:
+                    continue
+            # If subscribed to all the topics, tell a subscriber about this
+            if all(el in self.topics for el in sub.topics):
+                try:
+                    if sub.old_school:
+                        await sub.connection.write(ojs)
+                    else:
+                        await sub.connection.write(js)
+                except ConnectionResetError:
+                    pass
+                except RuntimeError:
+                    pass
+                except AssertionError:  # drain helper throws these sometimes
+                    pass
+
+
+async def poll():
+    """ Polls for new stuff to publish, and if found, publishes to whomever wants it. """
+    global LAST_PING, PENDING_PUBS
+    while True:
+        for payload in PENDING_PUBS:
+            await payload.publish(SUBSCRIBERS)
+        PENDING_PUBS = []
+        await asyncio.sleep(0.5)
+
+
+async def handler(request):
+    """ Generic handler for all incoming HTTP requests """
+    global NUM_REQUESTS
+    NUM_REQUESTS += 1
+    # Define response headers first...
+    headers = {
+        'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
+        'X-Subscribers': str(len(SUBSCRIBERS)),
+        'X-Requests': str(NUM_REQUESTS),
+    }
+
+    # Are we handling a publisher payload request? (PUT/POST)
+    if request.method in ['PUT', 'POST']:
+        ip = netaddr.IPAddress(request.remote)
+        allowed = False
+        for network in PAYLOADERS:
+            if ip in network:
+                allowed = True
+                break
+        if not allowed:
+            resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
+            return resp
+        if request.can_read_body:
+            try:
+                body = await request.json()
+                assert isinstance(body, dict)  # Payload MUST be an dictionary object, {...}
+                PENDING_PUBS.append(Payload(request.path, body))
+                resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
+                return resp
+            except json.decoder.JSONDecodeError:
+                resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
+                return resp
+            except AssertionError:
+                resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)
+                return resp
+    # Is this a subscriber request? (GET)
+    elif request.method == 'GET':
+        resp = aiohttp.web.StreamResponse(headers=headers)
+        # We do not support HTTP 1.0 here...
+        if request.version.major == 1 and request.version.minor == 0:
+            return resp
+        subscriber = Subscriber(resp, request)
+        SUBSCRIBERS.append(subscriber)
+        # We'll change the content type once we're ready
+        # resp.content_type = 'application/vnd.apache-pubsub-stream'
+        resp.content_type = 'application/json'
+        try:
+            resp.enable_chunked_encoding()
+            await resp.prepare(request)
+            while True:
+                await subscriber.ping()
+                await asyncio.sleep(5)
+        # We may get exception types we don't have imported, so grab ANY exception and kick out the subscriber
+        except:
+            pass
+        SUBSCRIBERS.remove(subscriber)
+        return resp
+    elif request.method == 'HEAD':
+        resp = aiohttp.web.Response(headers=headers, status=204, text="")
+        return resp
+    #  I don't know this type of request :/ (DELETE, PATCH, etc)
+    else:
+        resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
+        return resp
+
+
+async def main():
+    """ Main loop... """
+    server = aiohttp.web.Server(handler)
+    runner = aiohttp.web.ServerRunner(server)
+    await runner.setup()
+    site = aiohttp.web.TCPSite(runner, CONF['server']['bind'], CONF['server']['port'])
+    await site.start()
+    print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
+    print("==== Serving up PubSub goodness at %s:%s ====" % (CONF['server']['bind'], CONF['server']['port']))
+
+    await poll()
+
+
+if __name__ == '__main__':
+    CONF = yaml.safe_load(open('pypubsub.yaml'))
+    try:
+        ACL = yaml.safe_load(open('pypubsub_acl.yaml'))
+    except:
+        ACL = {}
+    PAYLOADERS = [netaddr.IPNetwork(x) for x in CONF['clients']['payloaders']]
+    loop = asyncio.get_event_loop()
+    try:
+        loop.run_until_complete(main())
+    except KeyboardInterrupt:
+        pass
+    loop.close()