Browse Source

type checks, remove unused imports

Daniel Gruno 4 years ago
parent
commit
db92365781
3 changed files with 58 additions and 33 deletions
  1. 3 3
      plugins/ldap.py
  2. 3 4
      plugins/sqs.py
  3. 52 26
      pypubsub.py

+ 3 - 3
plugins/ldap.py

@@ -23,7 +23,7 @@ import asyncio
 
 
 class LDAPConnection:
-    def __init__(self, yml):
+    def __init__(self, yml: dict):
         self.uri = yml.get('uri')
         assert isinstance(self.uri, str), "LDAP URI must be a string."
 
@@ -33,7 +33,7 @@ class LDAPConnection:
         self.base = yml.get('base_scope')
         assert isinstance(self.base, str), "LDAP Base scope must be a string"
 
-        self.patterns = yml.get('membership_patterns')
+        self.patterns: list = yml.get('membership_patterns', [])
         assert isinstance(self.patterns, list), "LDAP membership patterns must be a list of pattern strings"
 
         self.acl = yml.get('acl')
@@ -42,7 +42,7 @@ class LDAPConnection:
         assert ldap.initialize(self.uri)
         print("==== LDAP configuration looks kosher, enabling LDAP authentication as fallback ====")
 
-    async def get_groups(self, user, password):
+    async def get_groups(self, user: str, password: str):
         """Async fetching of groups an LDAP user belongs to"""
         bind_dn = self.user % user  # Interpolate user DN with username
 

+ 3 - 4
plugins/sqs.py

@@ -18,18 +18,17 @@
 
 """ This is the SQS component of PyPubSub """
 
-import asyncio
 import aiobotocore
 import botocore.exceptions
-import sys
 import json
 import pypubsub
+import typing
 
 # Global to hold ID of all items seem across all queues, to dedup things.
-ITEMS_SEEN = []
+ITEMS_SEEN: typing.List[str] = []
 
 
-async def get_payloads(server, config):
+async def get_payloads(server: pypubsub.Server, config: dict):
     # Assume everything is configured in the client's .aws config
     session = aiobotocore.get_session()
     queue_name = config.get('queue', '???')

+ 52 - 26
pypubsub.py

@@ -31,6 +31,7 @@ import argparse
 import collections
 import plugins.ldap
 import plugins.sqs
+import typing
 
 # Some consts
 PUBSUB_VERSION = '0.7.1'
@@ -48,8 +49,26 @@ PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects
 PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
 PUBSUB_WRITE_TIMEOUT = 0.35  # If we can't deliver to a pipe within N seconds, drop it.
 
+
+class ServerConfig(typing.NamedTuple):
+    ip: str
+    port: int
+    payload_limit: int
+
+
+class BacklogConfig(typing.NamedTuple):
+    max_age: int
+    queue_size: int
+    storage: typing.Optional[str]
+
+
 class Configuration:
-    def __init__(self, yml):
+    server: ServerConfig
+    backlog: BacklogConfig
+    payloaders: typing.List[netaddr.ip.IPNetwork]
+    oldschoolers: typing.List[str]
+
+    def __init__(self, yml: dict):
 
         # LDAP Settings
         self.ldap = None
@@ -61,13 +80,12 @@ class Configuration:
         self.sqs = yml.get('sqs')
 
         # Main server config
-        self.server = collections.namedtuple('serverConfig', 'ip port payload_limit')
-        self.server.ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
-        self.server.port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
-        self.server.payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
+        server_ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
+        server_port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
+        server_payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
+        self.server = ServerConfig(ip=server_ip, port=server_port, payload_limit=server_payload_limit)
 
         # Backlog settings
-        self.backlog = collections.namedtuple('backlogConfig', 'max_age queue_size storage')
         bma = yml['server'].get('backlog', {}).get('max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
         if isinstance(bma, str):
             bma = bma.lower()
@@ -79,10 +97,10 @@ class Configuration:
                 bma = int(bma.replace('h', '')) * 3600
             elif bma.endswith('d'):
                 bma = int(bma.replace('d', '')) * 86400
-        self.backlog.max_age = bma
-        self.backlog.queue_size = yml['server'].get('backlog', {}).get('size',
-                                                                       PUBSUB_DEFAULT_BACKLOG_SIZE)
-        self.backlog.storage = yml['server'].get('backlog', {}).get('storage')
+        bqs = yml['server'].get('backlog', {}).get('size',
+                                                   PUBSUB_DEFAULT_BACKLOG_SIZE)
+        bst = yml['server'].get('backlog', {}).get('storage')
+        self.backlog = BacklogConfig(max_age=bma, queue_size=bqs, storage=bst)
 
         # Payloaders - clients that can post payloads
         self.payloaders = [netaddr.IPNetwork(x) for x in yml['clients'].get('payloaders', [])]
@@ -90,29 +108,35 @@ class Configuration:
         # Binary backwards compatibility
         self.oldschoolers = yml['clients'].get('oldschoolers', [])
 
+
 class Server:
     """Main server class, responsible for handling requests and publishing events """
-
-    def __init__(self, args):
+    yaml: dict
+    config: Configuration
+    subscribers: list
+    pending_events: asyncio.Queue
+    backlog: list
+    last_ping = typing.Type[float]
+    server: aiohttp.web.Server
+
+    def __init__(self, args: argparse.Namespace):
         self.yaml = yaml.safe_load(open(args.config))
         self.config = Configuration(self.yaml)
         self.subscribers = []
         self.pending_events = asyncio.Queue()
         self.backlog = []
         self.last_ping = time.time()
-        self.server = None
         self.acl = {}
         try:
             self.acl = yaml.safe_load(open(args.acl))
         except FileNotFoundError:
             print(f"ACL configuration file {args.acl} not found, private events will not be broadcast.")
 
-
     async def poll(self):
         """Polls for new stuff to publish, and if found, publishes to whomever wants it."""
         while True:
-            payload = await self.pending_events.get()
-            bad_subs = await payload.publish(self.subscribers)
+            payload: Payload = await self.pending_events.get()
+            bad_subs: list = await payload.publish(self.subscribers)
             self.pending_events.task_done()
 
             # Cull subscribers we couldn't deliver payload to.
@@ -123,9 +147,10 @@ class Server:
                 except ValueError:  # Already removed elsewhere
                     pass
 
-
-    async def handle_request(self, request):
+    async def handle_request(self, request: aiohttp.web.BaseRequest):
         """Generic handler for all incoming HTTP requests"""
+        resp: typing.Union[aiohttp.web.Response, aiohttp.web.StreamResponse]
+
         # Define response headers first...
         headers = {
             'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
@@ -146,7 +171,7 @@ class Server:
                 return resp
             if request.can_read_body:
                 try:
-                    if request.content_length > self.config.server.payload_limit:
+                    if request.content_length and request.content_length > self.config.server.payload_limit:
                         resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
                         return resp
                     body = await request.text()
@@ -198,7 +223,7 @@ class Server:
                         backlog_ts = 0
                     # If max age is specified, force the TS to minimum that age
                     if self.config.backlog.max_age > 0:
-                        backlog_ts = max(backlog_ts, time.time() - self.config.backlog.max_age)
+                        backlog_ts = max(backlog_ts, int(time.time() - self.config.backlog.max_age))
                     # For each item, publish to client if new enough.
                     for item in self.backlog:
                         if item.timestamp >= backlog_ts:
@@ -267,7 +292,7 @@ class Server:
 
             print(f"Read {readlines} objects from {self.config.backlog.storage}, applied {len(self.backlog)} to backlog.")
 
-    async def server_loop(self, loop):
+    async def server_loop(self, loop: asyncio.BaseEventLoop):
         self.server = aiohttp.web.Server(self.handle_request)
         runner = aiohttp.web.ServerRunner(self.server)
         await runner.setup()
@@ -292,11 +317,12 @@ class Server:
         loop.close()
 
 
-
 class Subscriber:
     """Basic subscriber (client) class. Holds information about the connection and ACL"""
+    acl:    dict
+    topics: typing.List[typing.List[str]]
 
-    def __init__(self, server, connection, request):
+    def __init__(self, server: Server, connection: aiohttp.web.StreamResponse, request: aiohttp.web.BaseRequest):
         self.connection = connection
         self.acl = {}
         self.server = server
@@ -315,7 +341,7 @@ class Subscriber:
                 self.old_school = True
                 break
 
-    async def parse_acl(self, basic):
+    async def parse_acl(self, basic: str):
         """Sets the ACL if possible, based on Basic Auth"""
         try:
             decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'ascii')), 'utf-8')
@@ -366,7 +392,7 @@ class Subscriber:
 class Payload:
     """A payload (event) object sent by a registered publisher."""
 
-    def __init__(self, path, data, timestamp=None):
+    def __init__(self, path: str, data: dict, timestamp: typing.Optional[float] = None):
         self.json = data
         self.timestamp = timestamp or time.time()
         self.topics = [x for x in path.split('/') if x]
@@ -381,7 +407,7 @@ class Payload:
         self.json['pubsub_topics'] = self.topics
         self.json['pubsub_path'] = path
 
-    async def publish(self, subscribers):
+    async def publish(self, subscribers: typing.List[Subscriber]):
         """Publishes an object to all subscribers using those topics (or a sub-set thereof)"""
         js = b"%s\n" % json.dumps(self.json).encode('utf-8')
         ojs = js + b"\0"