|
@@ -31,6 +31,7 @@ import argparse
|
|
|
import collections
|
|
|
import plugins.ldap
|
|
|
import plugins.sqs
|
|
|
+import typing
|
|
|
|
|
|
|
|
|
PUBSUB_VERSION = '0.7.1'
|
|
@@ -48,8 +49,26 @@ PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects
|
|
|
PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
|
|
|
PUBSUB_WRITE_TIMEOUT = 0.35
|
|
|
|
|
|
+
|
|
|
+class ServerConfig(typing.NamedTuple):
|
|
|
+ ip: str
|
|
|
+ port: int
|
|
|
+ payload_limit: int
|
|
|
+
|
|
|
+
|
|
|
+class BacklogConfig(typing.NamedTuple):
|
|
|
+ max_age: int
|
|
|
+ queue_size: int
|
|
|
+ storage: typing.Optional[str]
|
|
|
+
|
|
|
+
|
|
|
class Configuration:
|
|
|
- def __init__(self, yml):
|
|
|
+ server: ServerConfig
|
|
|
+ backlog: BacklogConfig
|
|
|
+ payloaders: typing.List[netaddr.ip.IPNetwork]
|
|
|
+ oldschoolers: typing.List[str]
|
|
|
+
|
|
|
+ def __init__(self, yml: dict):
|
|
|
|
|
|
|
|
|
self.ldap = None
|
|
@@ -61,13 +80,12 @@ class Configuration:
|
|
|
self.sqs = yml.get('sqs')
|
|
|
|
|
|
|
|
|
- self.server = collections.namedtuple('serverConfig', 'ip port payload_limit')
|
|
|
- self.server.ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
|
|
|
- self.server.port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
|
|
|
- self.server.payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
|
|
|
+ server_ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
|
|
|
+ server_port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
|
|
|
+ server_payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
|
|
|
+ self.server = ServerConfig(ip=server_ip, port=server_port, payload_limit=server_payload_limit)
|
|
|
|
|
|
|
|
|
- self.backlog = collections.namedtuple('backlogConfig', 'max_age queue_size storage')
|
|
|
bma = yml['server'].get('backlog', {}).get('max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
|
|
|
if isinstance(bma, str):
|
|
|
bma = bma.lower()
|
|
@@ -79,10 +97,10 @@ class Configuration:
|
|
|
bma = int(bma.replace('h', '')) * 3600
|
|
|
elif bma.endswith('d'):
|
|
|
bma = int(bma.replace('d', '')) * 86400
|
|
|
- self.backlog.max_age = bma
|
|
|
- self.backlog.queue_size = yml['server'].get('backlog', {}).get('size',
|
|
|
- PUBSUB_DEFAULT_BACKLOG_SIZE)
|
|
|
- self.backlog.storage = yml['server'].get('backlog', {}).get('storage')
|
|
|
+ bqs = yml['server'].get('backlog', {}).get('size',
|
|
|
+ PUBSUB_DEFAULT_BACKLOG_SIZE)
|
|
|
+ bst = yml['server'].get('backlog', {}).get('storage')
|
|
|
+ self.backlog = BacklogConfig(max_age=bma, queue_size=bqs, storage=bst)
|
|
|
|
|
|
|
|
|
self.payloaders = [netaddr.IPNetwork(x) for x in yml['clients'].get('payloaders', [])]
|
|
@@ -90,29 +108,35 @@ class Configuration:
|
|
|
|
|
|
self.oldschoolers = yml['clients'].get('oldschoolers', [])
|
|
|
|
|
|
+
|
|
|
class Server:
|
|
|
"""Main server class, responsible for handling requests and publishing events """
|
|
|
-
|
|
|
- def __init__(self, args):
|
|
|
+ yaml: dict
|
|
|
+ config: Configuration
|
|
|
+ subscribers: list
|
|
|
+ pending_events: asyncio.Queue
|
|
|
+ backlog: list
|
|
|
+ last_ping = typing.Type[float]
|
|
|
+ server: aiohttp.web.Server
|
|
|
+
|
|
|
+ def __init__(self, args: argparse.Namespace):
|
|
|
self.yaml = yaml.safe_load(open(args.config))
|
|
|
self.config = Configuration(self.yaml)
|
|
|
self.subscribers = []
|
|
|
self.pending_events = asyncio.Queue()
|
|
|
self.backlog = []
|
|
|
self.last_ping = time.time()
|
|
|
- self.server = None
|
|
|
self.acl = {}
|
|
|
try:
|
|
|
self.acl = yaml.safe_load(open(args.acl))
|
|
|
except FileNotFoundError:
|
|
|
print(f"ACL configuration file {args.acl} not found, private events will not be broadcast.")
|
|
|
|
|
|
-
|
|
|
async def poll(self):
|
|
|
"""Polls for new stuff to publish, and if found, publishes to whomever wants it."""
|
|
|
while True:
|
|
|
- payload = await self.pending_events.get()
|
|
|
- bad_subs = await payload.publish(self.subscribers)
|
|
|
+ payload: Payload = await self.pending_events.get()
|
|
|
+ bad_subs: list = await payload.publish(self.subscribers)
|
|
|
self.pending_events.task_done()
|
|
|
|
|
|
|
|
@@ -123,9 +147,10 @@ class Server:
|
|
|
except ValueError:
|
|
|
pass
|
|
|
|
|
|
-
|
|
|
- async def handle_request(self, request):
|
|
|
+ async def handle_request(self, request: aiohttp.web.BaseRequest):
|
|
|
"""Generic handler for all incoming HTTP requests"""
|
|
|
+ resp: typing.Union[aiohttp.web.Response, aiohttp.web.StreamResponse]
|
|
|
+
|
|
|
|
|
|
headers = {
|
|
|
'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
|
|
@@ -146,7 +171,7 @@ class Server:
|
|
|
return resp
|
|
|
if request.can_read_body:
|
|
|
try:
|
|
|
- if request.content_length > self.config.server.payload_limit:
|
|
|
+ if request.content_length and request.content_length > self.config.server.payload_limit:
|
|
|
resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
|
|
|
return resp
|
|
|
body = await request.text()
|
|
@@ -198,7 +223,7 @@ class Server:
|
|
|
backlog_ts = 0
|
|
|
|
|
|
if self.config.backlog.max_age > 0:
|
|
|
- backlog_ts = max(backlog_ts, time.time() - self.config.backlog.max_age)
|
|
|
+ backlog_ts = max(backlog_ts, int(time.time() - self.config.backlog.max_age))
|
|
|
|
|
|
for item in self.backlog:
|
|
|
if item.timestamp >= backlog_ts:
|
|
@@ -267,7 +292,7 @@ class Server:
|
|
|
|
|
|
print(f"Read {readlines} objects from {self.config.backlog.storage}, applied {len(self.backlog)} to backlog.")
|
|
|
|
|
|
- async def server_loop(self, loop):
|
|
|
+ async def server_loop(self, loop: asyncio.BaseEventLoop):
|
|
|
self.server = aiohttp.web.Server(self.handle_request)
|
|
|
runner = aiohttp.web.ServerRunner(self.server)
|
|
|
await runner.setup()
|
|
@@ -292,11 +317,12 @@ class Server:
|
|
|
loop.close()
|
|
|
|
|
|
|
|
|
-
|
|
|
class Subscriber:
|
|
|
"""Basic subscriber (client) class. Holds information about the connection and ACL"""
|
|
|
+ acl: dict
|
|
|
+ topics: typing.List[typing.List[str]]
|
|
|
|
|
|
- def __init__(self, server, connection, request):
|
|
|
+ def __init__(self, server: Server, connection: aiohttp.web.StreamResponse, request: aiohttp.web.BaseRequest):
|
|
|
self.connection = connection
|
|
|
self.acl = {}
|
|
|
self.server = server
|
|
@@ -315,7 +341,7 @@ class Subscriber:
|
|
|
self.old_school = True
|
|
|
break
|
|
|
|
|
|
- async def parse_acl(self, basic):
|
|
|
+ async def parse_acl(self, basic: str):
|
|
|
"""Sets the ACL if possible, based on Basic Auth"""
|
|
|
try:
|
|
|
decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'ascii')), 'utf-8')
|
|
@@ -366,7 +392,7 @@ class Subscriber:
|
|
|
class Payload:
|
|
|
"""A payload (event) object sent by a registered publisher."""
|
|
|
|
|
|
- def __init__(self, path, data, timestamp=None):
|
|
|
+ def __init__(self, path: str, data: dict, timestamp: typing.Optional[float] = None):
|
|
|
self.json = data
|
|
|
self.timestamp = timestamp or time.time()
|
|
|
self.topics = [x for x in path.split('/') if x]
|
|
@@ -381,7 +407,7 @@ class Payload:
|
|
|
self.json['pubsub_topics'] = self.topics
|
|
|
self.json['pubsub_path'] = path
|
|
|
|
|
|
- async def publish(self, subscribers):
|
|
|
+ async def publish(self, subscribers: typing.List[Subscriber]):
|
|
|
"""Publishes an object to all subscribers using those topics (or a sub-set thereof)"""
|
|
|
js = b"%s\n" % json.dumps(self.json).encode('utf-8')
|
|
|
ojs = js + b"\0"
|