Procházet zdrojové kódy

Turn server into a class of its own

- Turn server into its own class
- Fold configuration into server class, ditch globals
- Change set_acl to parse_acl
Daniel Gruno před 5 roky
rodič
revize
0cfd11130d
1 změnil soubory, kde provedl 136 přidání a 133 odebrání
  1. 136 133
      pypubsub.py

+ 136 - 133
pypubsub.py

@@ -28,53 +28,159 @@ import base64
 import pypubsub_ldap
 
 # Some consts
-PUBSUB_VERSION = '0.3.0'
+PUBSUB_VERSION = '0.4.0'
 PUBSUB_BAD_REQUEST = "I could not understand your request, sorry! Please see https://pubsub.apache.org/api.html \
 for usage documentation.\n"
 PUBSUB_PAYLOAD_RECEIVED = "Payload received, thank you very much!\n"
 PUBSUB_NOT_ALLOWED = "You are not authorized to deliver payloads!\n"
 PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects, {..}!\n"
-CONF = None
-LCONF = None
-ACL = None
-OLD_SCHOOLERS = ['svnwcsub', ]  # Old-school clients that use \0 terminators.
 
-# Internal score-keeping vars
-NUM_REQUESTS = 0   # Number of total requests served
-SUBSCRIBERS = []   # Current subscribers to everything
-PENDING_PUBS = []  # Payloads pending publication
-LAST_PING = 0      # Last time we did a global ping (every 5 seconds)
-PAYLOADERS = []    # IPs that can deliver payloads
+
+class Server:
+    def __init__(self):
+        self.config = yaml.safe_load(open('pypubsub.yaml'))
+        self.lconfig = None
+        self.no_requests = 0
+        self.subscribers = []
+        self.pending_events = []
+        self.last_ping = time.time()
+
+        if 'ldap' in self.config.get('clients', {}):
+            pypubsub_ldap.vet_settings(self.config['clients']['ldap'])
+            self.lconfig = self.config['clients']['ldap']
+        self.acl = {}
+        try:
+            self.acl = yaml.safe_load(open('pypubsub_acl.yaml'))
+        except FileNotFoundError:
+            print("No ACL configuration file found, private events will not be broadcast.")
+        self.payloaders = [netaddr.IPNetwork(x) for x in self.config['clients']['payloaders']]
+
+    async def poll(self):
+        """ Polls for new stuff to publish, and if found, publishes to whomever wants it. """
+        while True:
+            for payload in self.pending_events:
+                await payload.publish(self.subscribers)
+            self.pending_events = []
+            await asyncio.sleep(0.5)
+
+    async def handle_request(self, request):
+        """ Generic handler for all incoming HTTP requests """
+        self.no_requests += 1
+        # Define response headers first...
+        headers = {
+            'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
+            'X-Subscribers': str(len(self.subscribers)),
+            'X-Requests': str(self.no_requests),
+        }
+
+        # Are we handling a publisher payload request? (PUT/POST)
+        if request.method in ['PUT', 'POST']:
+            ip = netaddr.IPAddress(request.remote)
+            allowed = False
+            for network in self.payloaders:
+                if ip in network:
+                    allowed = True
+                    break
+            if not allowed:
+                resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
+                return resp
+            if request.can_read_body:
+                try:
+                    body = await request.json()
+                    assert isinstance(body, dict)  # Payload MUST be an dictionary object, {...}
+                    self.pending_events.append(Payload(request.path, body))
+                    resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
+                    return resp
+                except json.decoder.JSONDecodeError:
+                    resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
+                    return resp
+                except AssertionError:
+                    resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)
+                    return resp
+        # Is this a subscriber request? (GET)
+        elif request.method == 'GET':
+            resp = aiohttp.web.StreamResponse(headers=headers)
+            # We do not support HTTP 1.0 here...
+            if request.version.major == 1 and request.version.minor == 0:
+                return resp
+            subscriber = Subscriber(self, resp, request)
+
+            # Is there a basic auth in this request? If so, set up ACL
+            auth = request.headers.get('Authorization')
+            if auth:
+                subscriber.acl = await subscriber.parse_acl(auth)
+
+            self.subscribers.append(subscriber)
+            # We'll change the content type once we're ready
+            # resp.content_type = 'application/vnd.apache-pubsub-stream'
+            resp.content_type = 'application/json'
+            try:
+                resp.enable_chunked_encoding()
+                await resp.prepare(request)
+                while True:
+                    await subscriber.ping()
+                    await asyncio.sleep(5)
+            # We may get exception types we don't have imported, so grab ANY exception and kick out the subscriber
+            except:
+                pass
+            self.subscribers.remove(subscriber)
+            return resp
+        elif request.method == 'HEAD':
+            resp = aiohttp.web.Response(headers=headers, status=204, text="")
+            return resp
+        #  I don't know this type of request :/ (DELETE, PATCH, etc)
+        else:
+            resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
+            return resp
+
+    async def server_loop(self):
+        server = aiohttp.web.Server(self.handle_request)
+        runner = aiohttp.web.ServerRunner(server)
+        await runner.setup()
+        site = aiohttp.web.TCPSite(runner, self.config['server']['bind'], self.config['server']['port'])
+        await site.start()
+        print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
+        print("==== Serving up PubSub goodness at %s:%s ====" % (self.config['server']['bind'], self.config['server']['port']))
+        await self.poll()
+
+    def run(self):
+        loop = asyncio.get_event_loop()
+        try:
+            loop.run_until_complete(self.server_loop())
+        except KeyboardInterrupt:
+            pass
+        loop.close()
+
 
 
 class Subscriber:
     """ Basic subscriber (client) class.
         Holds information about the connection and ACL
     """
-    def __init__(self, connection, request):
+    def __init__(self, server, connection, request):
         self.connection = connection
-        self.request = request
         self.acl = {}
+        self.server = server
 
         # Set topics subscribed to
         self.topics = [x for x in request.path.split('/') if x]
 
         # Is the client old and expecting zero-terminators?
         self.old_school = False
-        for ua in OLD_SCHOOLERS:
+        for ua in self.server.config['clients'].get('oldscoolers', []):
             if ua in request.headers.get('User-Agent', ''):
                 self.old_school = True
                 break
 
-    async def set_acl(self, basic):
+    async def parse_acl(self, basic):
         """ Sets the ACL if possible, based on Basic Auth """
         try:
-            decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'utf-8')), 'utf-8')
+            decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'ascii')), 'utf-8')
             u, p = decoded.split(':', 1)
-            if u in ACL:
-                acl_pass = ACL[u].get('password')
+            if u in self.server.acl:
+                acl_pass = self.server.acl[u].get('password')
                 if acl_pass and acl_pass == p:
-                    self.acl = ACL[u].get('acl', {})
+                    acl = self.server.acl[u].get('acl', {})
                     # Vet ACL for user
                     if not isinstance(self.acl, dict):
                         raise AssertionError(f"ACL for user {u} must be a dictionary of sub-IDs and topics, but is not.")
@@ -83,23 +189,26 @@ class Subscriber:
                         if not isinstance(v, list):
                             raise AssertionError(f"ACL segment {k} for user {u} is not a list of topics!")
                     print(f"Client {u} successfully authenticated (and ACL is valid).")
-            elif LCONF:
-                groups = await pypubsub_ldap.get_groups(LCONF, u, p)
+                    return acl
+            elif self.server.lconfig:
+                acl = {}
+                groups = await pypubsub_ldap.get_groups(self.server.lconfig, u, p)
                 # Make sure each ACL segment is a list of topics
-                for k, v in CONF['clients']['ldap']['acl'].items():
+                for k, v in self.server.lconfig['acl'].items():
                     if not isinstance(v, list):
                         raise AssertionError(f"ACL segment {k} for user {u} is not a list of topics!")
                     if k in groups:
                         print(f"Enabling ACL segment {k} for user {u}")
-                        self.acl[k] = v
+                        acl[k] = v
+                return acl
         except binascii.Error as e:
-            self.acl = {}
+            pass  # Bad Basic Auth params, bail quietly
         except AssertionError as e:
             print(e)
             print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
-            self.acl = {}
         except Exception as e:
             print(f"Basic unknown exception occurred: {e}")
+        return {}
 
     async def ping(self):
         """ Generic ping-back to the client """
@@ -154,115 +263,9 @@ class Payload:
                     pass
 
 
-async def poll():
-    """ Polls for new stuff to publish, and if found, publishes to whomever wants it. """
-    global LAST_PING, PENDING_PUBS
-    while True:
-        for payload in PENDING_PUBS:
-            await payload.publish(SUBSCRIBERS)
-        PENDING_PUBS = []
-        await asyncio.sleep(0.5)
-
-
-async def handler(request):
-    """ Generic handler for all incoming HTTP requests """
-    global NUM_REQUESTS
-    NUM_REQUESTS += 1
-    # Define response headers first...
-    headers = {
-        'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
-        'X-Subscribers': str(len(SUBSCRIBERS)),
-        'X-Requests': str(NUM_REQUESTS),
-    }
-
-    # Are we handling a publisher payload request? (PUT/POST)
-    if request.method in ['PUT', 'POST']:
-        ip = netaddr.IPAddress(request.remote)
-        allowed = False
-        for network in PAYLOADERS:
-            if ip in network:
-                allowed = True
-                break
-        if not allowed:
-            resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
-            return resp
-        if request.can_read_body:
-            try:
-                body = await request.json()
-                assert isinstance(body, dict)  # Payload MUST be an dictionary object, {...}
-                PENDING_PUBS.append(Payload(request.path, body))
-                resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
-                return resp
-            except json.decoder.JSONDecodeError:
-                resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
-                return resp
-            except AssertionError:
-                resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)
-                return resp
-    # Is this a subscriber request? (GET)
-    elif request.method == 'GET':
-        resp = aiohttp.web.StreamResponse(headers=headers)
-        # We do not support HTTP 1.0 here...
-        if request.version.major == 1 and request.version.minor == 0:
-            return resp
-        subscriber = Subscriber(resp, request)
-        # Is there a basic auth in this request? If so, set up ACL
-        auth = request.headers.get('Authorization')
-        if auth:
-            await subscriber.set_acl(auth)
-
-        SUBSCRIBERS.append(subscriber)
-        # We'll change the content type once we're ready
-        # resp.content_type = 'application/vnd.apache-pubsub-stream'
-        resp.content_type = 'application/json'
-        try:
-            resp.enable_chunked_encoding()
-            await resp.prepare(request)
-            while True:
-                await subscriber.ping()
-                await asyncio.sleep(5)
-        # We may get exception types we don't have imported, so grab ANY exception and kick out the subscriber
-        except:
-            pass
-        SUBSCRIBERS.remove(subscriber)
-        return resp
-    elif request.method == 'HEAD':
-        resp = aiohttp.web.Response(headers=headers, status=204, text="")
-        return resp
-    #  I don't know this type of request :/ (DELETE, PATCH, etc)
-    else:
-        resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
-        return resp
-
-
-async def main():
-    """ Main loop... """
-    server = aiohttp.web.Server(handler)
-    runner = aiohttp.web.ServerRunner(server)
-    await runner.setup()
-    site = aiohttp.web.TCPSite(runner, CONF['server']['bind'], CONF['server']['port'])
-    await site.start()
-    print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
-    print("==== Serving up PubSub goodness at %s:%s ====" % (CONF['server']['bind'], CONF['server']['port']))
-
-    await poll()
 
 
 if __name__ == '__main__':
-    CONF = yaml.safe_load(open('pypubsub.yaml'))
-    if 'ldap' in CONF.get('clients', {}):
-        pypubsub_ldap.vet_settings(CONF['clients']['ldap'])
-        LCONF = CONF['clients']['ldap']
-    ACL = {}
-    try:
-        ACL = yaml.safe_load(open('pypubsub_acl.yaml'))
-    except FileNotFoundError:
-        print("No ACL configuration file found, private events will not be broadcast.")
-    PAYLOADERS = [netaddr.IPNetwork(x) for x in CONF['clients']['payloaders']]
-    loop = asyncio.get_event_loop()
-    try:
-        loop.run_until_complete(main())
-    except KeyboardInterrupt:
-        pass
-    loop.close()
+    pubsub_server = Server()
+    pubsub_server.run()