浏览代码

Classify and refactor configuration file and parsing

Daniel Gruno 5 年之前
父节点
当前提交
ea1e457bcc
共有 3 个文件被更改,包括 110 次插入87 次删除
  1. 47 44
      plugins/ldap.py
  2. 59 40
      pypubsub.py
  3. 4 3
      pypubsub.yaml

+ 47 - 44
plugins/ldap.py

@@ -21,57 +21,60 @@
 import ldap
 import asyncio
 
-LDAP_SCHEME = {
-    'uri': str,
-    'user_dn': str,
-    'base_scope': str,
-    'membership_patterns': list,
-    'acl': dict
-}
 
+class LDAPConnection:
+    def __init__(self, yml):
+        self.uri = yml.get('uri')
+        assert isinstance(self.uri, str), "LDAP URI must be a string."
 
-def vet_settings(lconf):
-    """Simple test to vet LDAP settings if present"""
-    if lconf:
-        for k, v in LDAP_SCHEME.items():
-            assert isinstance(lconf.get(k), v), f"LDAP configuration item {k} must be of type {v.__name__}!"
-        assert ldap.initialize(lconf['uri'])
-    print("==== LDAP configuration looks kosher, enabling LDAP authentication as fallback ====")
+        self.user = yml.get('user_dn')
+        assert isinstance(self.user, str) or self.user is None, "LDAP user DN must be a string or absent."
 
+        self.base = yml.get('base_scope')
+        assert isinstance(self.base, str), "LDAP Base scope must be a string"
 
-async def get_groups(lconf, user, password):
-    """Async fetching of groups an LDAP user belongs to"""
-    bind_dn = lconf['user_dn'] % user
+        self.patterns = yml.get('membership_patterns')
+        assert isinstance(self.patterns, list), "LDAP membership patterns must be a list of pattern strings"
 
-    try:
-        client = ldap.initialize(lconf['uri'])
-        client.set_option(ldap.OPT_REFERRALS, 0)
-        client.set_option(ldap.OPT_TIMEOUT, 0)
-        rv = client.simple_bind(bind_dn, password)
-        while True:
-            res = client.result(rv, timeout=0)
-            if res and res != (None, None):
-                break
-            await asyncio.sleep(0.25)
+        self.acl = yml.get('acl')
+        assert isinstance(self.acl, dict), "LDAP ACL must be a dictionary (hash) of ACLs"
 
-        groups = []
-        for role in lconf['membership_patterns']:
-            rv = client.search(lconf['base_scope'], ldap.SCOPE_SUBTREE, role % user, ['dn'])
+        assert ldap.initialize(self.uri)
+        print("==== LDAP configuration looks kosher, enabling LDAP authentication as fallback ====")
+
+    async def get_groups(self, user, password):
+        """Async fetching of groups an LDAP user belongs to"""
+        bind_dn = self.user % user  # Interpolate user DN with username
+
+        try:
+            client = ldap.initialize(self.uri)
+            client.set_option(ldap.OPT_REFERRALS, 0)
+            client.set_option(ldap.OPT_TIMEOUT, 0)
+            rv = client.simple_bind(bind_dn, password)
             while True:
-                res = client.result(rv, all=0, timeout=0)
-                if res:
-                    if res == (None, None):
-                        await asyncio.sleep(0.25)
-                    else:
-                        if not res[1]:
-                            break
-                        for tuples in res[1]:
-                            groups.append(tuples[0])
-                else:
+                res = client.result(rv, timeout=0)
+                if res and res != (None, None):
                     break
-        return groups
+                await asyncio.sleep(0.25)
+
+            groups = []
+            for role in self.patterns:
+                rv = client.search(self.base, ldap.SCOPE_SUBTREE, role % user, ['dn'])
+                while True:
+                    res = client.result(rv, all=0, timeout=0)
+                    if res:
+                        if res == (None, None):
+                            await asyncio.sleep(0.25)
+                        else:
+                            if not res[1]:
+                                break
+                            for tuples in res[1]:
+                                groups.append(tuples[0])
+                    else:
+                        break
+            return groups
 
-    except Exception as e:
-        print(f"LDAP Exception for user {user}: {e}")
-        return []
+        except Exception as e:
+            print(f"LDAP Exception for user {user}: {e}")
+            return []
 

+ 59 - 40
pypubsub.py

@@ -26,12 +26,15 @@ import netaddr
 import binascii
 import base64
 import argparse
+import collections
 import plugins.ldap
 import plugins.sqs
 
 # Some consts
-PUBSUB_VERSION = '0.5.1'
+PUBSUB_VERSION = '0.6.0'
 PUBSUB_CONTENT_TYPE = 'application/vnd.pypubsub-stream'
+PUBSUB_DEFAULT_PORT = 2069
+PUBSUB_DEFAULT_IP = '0.0.0.0'
 PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE = 102400
 PUBSUB_DEFAULT_BACKLOG_SIZE = 0
 PUBSUB_DEFAULT_BACKLOG_AGE = 0
@@ -43,21 +46,27 @@ PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects
 PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
 
 
-class Server:
-    """Main server class, responsible for handling requests and publishing events """
+class Configuration:
+    def __init__(self, yml):
 
-    def __init__(self, args):
-        self.config = yaml.safe_load(open(args.config))
-        self.lconfig = None
-        self.sqsconfig = None
-        self.subscribers = []
-        self.pending_events = []
-        self.backlog = []
-        self.last_ping = time.time()
-        self.server = None
+        # LDAP Settings
+        self.ldap = None
+        lyml = yml.get('clients', {}).get('ldap')
+        if isinstance(lyml, dict):
+            self.ldap = plugins.ldap.LDAPConnection(lyml)
+
+        # SQS?
+        self.sqs = yml.get('sqs')
 
-        # Backlog age calcs
-        bma = self.config['clients'].get('payload_backlog_max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
+        # Main server config
+        self.server = collections.namedtuple('serverConfig', 'ip port payload_limit')
+        self.server.ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
+        self.server.port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
+        self.server.payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
+
+        # Backlog settings
+        self.backlog = collections.namedtuple('backlogConfig', 'max_age queue_size')
+        bma = yml['server'].get('backlog', {}).get('max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
         if isinstance(bma, str):
             bma = bma.lower()
             if bma.endswith('s'):
@@ -68,22 +77,33 @@ class Server:
                 bma = int(bma.replace('h', '')) * 3600
             elif bma.endswith('d'):
                 bma = int(bma.replace('d', '')) * 86400
-        self.backlog_max_age = bma
+        self.backlog.max_age = bma
+        self.backlog.queue_size = yml['server'].get('backlog', {}).get('size',
+                                                                       PUBSUB_DEFAULT_BACKLOG_SIZE)
 
-        # LDAP configuration present?
-        if 'ldap' in self.config.get('clients', {}):
-            self.lconfig = self.config['clients']['ldap']
-            plugins.ldap.vet_settings(self.lconfig)
-        self.acl = {}
+        # Payloaders - clients that can post payloads
+        self.payloaders = [netaddr.IPNetwork(x) for x in yml['clients'].get('payloaders', [])]
+
+        # Binary backwards compatibility
+        self.oldschoolers = yml['clients'].get('oldschoolers', [])
 
-        # SQS configuration present?
-        if 'sqs' in self.config:
-            self.sqsconfig = self.config.get('sqs')
+class Server:
+    """Main server class, responsible for handling requests and publishing events """
+
+    def __init__(self, args):
+        self.yaml = yaml.safe_load(open(args.config))
+        self.config = Configuration(self.yaml)
+        self.subscribers = []
+        self.pending_events = []
+        self.backlog = []
+        self.last_ping = time.time()
+        self.server = None
+        self.acl = {}
         try:
             self.acl = yaml.safe_load(open(args.acl))
         except FileNotFoundError:
             print(f"ACL configuration file {args.acl} not found, private events will not be broadcast.")
-        self.payloaders = [netaddr.IPNetwork(x) for x in self.config['clients']['payloaders']]
+
 
     async def poll(self):
         """Polls for new stuff to publish, and if found, publishes to whomever wants it."""
@@ -110,7 +130,7 @@ class Server:
         if request.method in ['PUT', 'POST']:
             ip = netaddr.IPAddress(request.remote)
             allowed = False
-            for network in self.payloaders:
+            for network in self.config.payloaders:
                 if ip in network:
                     allowed = True
                     break
@@ -119,8 +139,7 @@ class Server:
                 return resp
             if request.can_read_body:
                 try:
-                    if request.content_length > self.config['clients'].get('max_payload_size',
-                                                                           PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE):
+                    if request.content_length > self.config.server.payload_limit:
                         resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
                         return resp
                     body = await request.text()
@@ -128,11 +147,11 @@ class Server:
                     assert isinstance(as_json, dict)  # Payload MUST be an dictionary object, {...}
                     pl = Payload(request.path, as_json)
                     self.pending_events.append(pl)
-                    backlog_size = self.config['clients'].get('payload_backlog_size', PUBSUB_DEFAULT_BACKLOG_SIZE)
-                    if backlog_size > 0:
+                    # Add to backlog?
+                    if self.config.backlog.queue_size > 0:
                         self.backlog.append(pl)
                         # If backlog has grown too large, delete the first (oldest) item in it.
-                        if len(self.backlog) > backlog_size:
+                        if len(self.backlog) > self.config.backlog.queue_size:
                             del self.backlog[0]
 
                     resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
@@ -171,8 +190,8 @@ class Server:
                     except ValueError:  # Default to 0 if we can't parse the epoch
                         backlog_ts = 0
                     # If max age is specified, force the TS to minimum that age
-                    if self.backlog_max_age and self.backlog_max_age > 0:
-                        backlog_ts = max(backlog_ts, time.time() - self.backlog_max_age)
+                    if self.config.backlog.max_age > 0:
+                        backlog_ts = max(backlog_ts, time.time() - self.config.backlog.max_age)
                     # For each item, publish to client if new enough.
                     for item in self.backlog:
                         if item.timestamp >= backlog_ts:
@@ -201,13 +220,13 @@ class Server:
         self.server = aiohttp.web.Server(self.handle_request)
         runner = aiohttp.web.ServerRunner(self.server)
         await runner.setup()
-        site = aiohttp.web.TCPSite(runner, self.config['server']['bind'], self.config['server']['port'])
+        site = aiohttp.web.TCPSite(runner, self.config.server.ip, self.config.server.port)
         await site.start()
         print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
         print("==== Serving up PubSub goodness at %s:%s ====" % (
-        self.config['server']['bind'], self.config['server']['port']))
-        if self.sqsconfig:
-            for key, config in self.sqsconfig.items():
+            self.config.server.ip, self.config.server.port))
+        if self.config.sqs:
+            for key, config in self.config.sqs.items():
                 loop.create_task(plugins.sqs.get_payloads(self, config))
         await self.poll()
 
@@ -233,7 +252,7 @@ class Subscriber:
 
         # Is the client old and expecting zero-terminators?
         self.old_school = False
-        for ua in self.server.config['clients'].get('oldschoolers', []):
+        for ua in self.server.config.oldschoolers:
             if ua in request.headers.get('User-Agent', ''):
                 self.old_school = True
                 break
@@ -255,11 +274,11 @@ class Subscriber:
                         assert isinstance(v, list), f"ACL segment {k} for user {u} is not a list of topics!"
                     print(f"Client {u} successfully authenticated (and ACL is valid).")
                     return acl
-            elif self.server.lconfig:
+            elif self.server.config.ldap:
                 acl = {}
-                groups = await plugins.ldap.get_groups(self.server.lconfig, u, p)
+                groups = await self.server.config.ldap.get_groups(u,p)
                 # Make sure each ACL segment is a list of topics
-                for k, v in self.server.lconfig['acl'].items():
+                for k, v in self.server.config.ldap.acl.items():
                     if k in groups:
                         assert isinstance(v, dict), f"ACL segment {k} for user {u} is not a dictionary of segments!"
                         for segment, topics in v.items():

+ 4 - 3
pypubsub.yaml

@@ -2,6 +2,10 @@
 server:
   port: 2069
   bind: 0.0.0.0
+  max_payload_size:          102400   # Max size of each JSON payload
+  backlog:
+    size:           0   # Max number of payloads to keep in backlog cache (set to 0 to disable)
+    max_age:      48h   # Maximum age of a backlog item before culling it (set to 0 to never prune on age)
 
 # Client settings
 clients:
@@ -9,9 +13,6 @@ clients:
   payloaders:
     - 127.0.0.1/24
     - 10.0.0.1/24
-  max_payload_size:          102400   # Max size of each JSON payload
-  payload_backlog_size:           0   # Max number of payloads to keep in backlog cache (set to 0 to disable)
-  payload_backlog_max_age:      48h   # Maximum age of a backlog item before culling it (set to 0 to never prune on age)
   # Oldschoolers denotes clients expecting binary events, such as svnwcsub
   oldschoolers:
     - svnwcsub