pypubsub.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. #!/usr/bin/env python3
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """PyPubSub - a simple publisher/subscriber service written in Python 3"""
  19. import asyncio
  20. import aiohttp.web
  21. import aiofile
  22. import os
  23. import time
  24. import json
  25. import yaml
  26. import netaddr
  27. import binascii
  28. import base64
  29. import argparse
  30. import collections
  31. import plugins.ldap
  32. import plugins.sqs
  33. import typing
  34. import signal
  35. import uuid
  36. # Some consts
  37. PUBSUB_VERSION = '0.7.4'
  38. PUBSUB_CONTENT_TYPE = 'application/vnd.pypubsub-stream'
  39. PUBSUB_DEFAULT_PORT = 2069
  40. PUBSUB_DEFAULT_IP = '0.0.0.0'
  41. PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE = 102400
  42. PUBSUB_DEFAULT_BACKLOG_SIZE = 0
  43. PUBSUB_DEFAULT_BACKLOG_AGE = 0
  44. PUBSUB_BAD_REQUEST = "I could not understand your request, sorry! Please see https://pubsub.apache.org/api.html \
  45. for usage documentation.\n"
  46. PUBSUB_PAYLOAD_RECEIVED = "Payload received, thank you very much!\n"
  47. PUBSUB_NOT_ALLOWED = "You are not authorized to deliver payloads!\n"
  48. PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects, {..}!\n"
  49. PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
  50. PUBSUB_WRITE_TIMEOUT = 0.35 # If we can't deliver to a pipe within N seconds, drop it.
  51. class ServerConfig(typing.NamedTuple):
  52. ip: str
  53. port: int
  54. payload_limit: int
  55. tls_port: int
  56. tls_ctx: typing.Any
  57. class BacklogConfig(typing.NamedTuple):
  58. max_age: int
  59. queue_size: int
  60. storage: typing.Optional[str]
  61. class Configuration:
  62. server: ServerConfig
  63. backlog: BacklogConfig
  64. payloaders: typing.List[netaddr.ip.IPNetwork]
  65. oldschoolers: typing.List[str]
  66. secure_topics: typing.Optional[typing.List[str]]
  67. def __init__(self, yml: dict):
  68. # LDAP Settings
  69. self.ldap = None
  70. lyml = yml.get('clients', {}).get('ldap')
  71. if isinstance(lyml, dict):
  72. self.ldap = plugins.ldap.LDAPConnection(lyml)
  73. # SQS?
  74. self.sqs = yml.get('sqs')
  75. # Main server config
  76. server_ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
  77. server_port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
  78. server_payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
  79. tls_port = 0
  80. tls_ctx = None
  81. # TLS support, if configured
  82. if 'tls' in yml['server'] and isinstance(yml['server']['tls'], dict):
  83. for required_element in ("port", "cert", "key", ):
  84. assert yml['server']['tls'].get(required_element), f"TLS: configuration option '{required_element}' is missing or invalid, cannot enable TLS!"
  85. import ssl
  86. tls_port = int(yml['server']['tls']['port'])
  87. # Create TLS context and load cert+key
  88. tls_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  89. assert os.path.isfile(yml['server']['tls']['cert']), f"Could not locate domain certificate file {yml['server']['tls']['cert']}"
  90. assert os.path.isfile(yml['server']['tls']['key']), f"Could not locate domain certificate key {yml['server']['tls']['key']}"
  91. tls_ctx.load_cert_chain(yml['server']['tls']['cert'], yml['server']['tls']['key'])
  92. # Add intermediate cert chain, if available
  93. if 'chain' in yml['server']['tls']:
  94. assert os.path.isfile(yml['server']['tls']['chain']), f"Could not locate domain certificate chain {yml['server']['tls']['chain']}"
  95. tls_ctx.load_verify_locations(yml['server']['tls']['chain'])
  96. self.server = ServerConfig(ip=server_ip, port=server_port, payload_limit=server_payload_limit, tls_port=tls_port, tls_ctx=tls_ctx)
  97. # Backlog settings
  98. bma = yml['server'].get('backlog', {}).get('max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
  99. if isinstance(bma, str):
  100. bma = bma.lower()
  101. if bma.endswith('s'):
  102. bma = int(bma.replace('s', ''))
  103. elif bma.endswith('m'):
  104. bma = int(bma.replace('m', '')) * 60
  105. elif bma.endswith('h'):
  106. bma = int(bma.replace('h', '')) * 3600
  107. elif bma.endswith('d'):
  108. bma = int(bma.replace('d', '')) * 86400
  109. bqs = yml['server'].get('backlog', {}).get('size',
  110. PUBSUB_DEFAULT_BACKLOG_SIZE)
  111. bst = yml['server'].get('backlog', {}).get('storage')
  112. self.backlog = BacklogConfig(max_age=bma, queue_size=bqs, storage=bst)
  113. # Payloaders - clients that can post payloads
  114. self.payloaders = [netaddr.IPNetwork(x) for x in yml['clients'].get('payloaders', [])]
  115. # Binary backwards compatibility
  116. self.oldschoolers = yml['clients'].get('oldschoolers', [])
  117. # Secure topics, if any
  118. self.secure_topics = set(yml['clients'].get('secure_topics', []) or [])
  119. class Server:
  120. """Main server class, responsible for handling requests and publishing events """
  121. yaml: dict
  122. config: Configuration
  123. subscribers: list
  124. pending_events: asyncio.Queue
  125. backlog: list
  126. last_ping = typing.Type[float]
  127. server: aiohttp.web.Server
  128. def __init__(self, args: argparse.Namespace):
  129. self.yaml = yaml.safe_load(open(args.config))
  130. self.config = Configuration(self.yaml)
  131. self.subscribers = []
  132. self.pending_events = asyncio.Queue()
  133. self.backlog = []
  134. self.last_ping = time.time()
  135. self.acl_file = args.acl
  136. self.acl = {}
  137. self.load_acl()
  138. def load_acl(self):
  139. """Loads ACL from file"""
  140. try:
  141. self.acl = yaml.safe_load(open(self.acl_file))
  142. print(f"Loaded ACL from {self.acl_file}")
  143. except FileNotFoundError:
  144. print(f"ACL configuration file {self.acl_file} not found, private events will not be broadcast.")
  145. async def poll(self):
  146. """Polls for new stuff to publish, and if found, publishes to whomever wants it."""
  147. while True:
  148. payload: Payload = await self.pending_events.get()
  149. bad_subs: list = await payload.publish(self.subscribers)
  150. self.pending_events.task_done()
  151. # Cull subscribers we couldn't deliver payload to.
  152. for bad_sub in bad_subs:
  153. print("Culling %r due to connection errors" % bad_sub)
  154. try:
  155. self.subscribers.remove(bad_sub)
  156. except ValueError: # Already removed elsewhere
  157. pass
  158. async def handle_request(self, request: aiohttp.web.BaseRequest):
  159. """Generic handler for all incoming HTTP requests"""
  160. resp: typing.Union[aiohttp.web.Response, aiohttp.web.StreamResponse]
  161. # Define response headers first...
  162. headers = {
  163. 'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
  164. 'X-Subscribers': str(len(self.subscribers)),
  165. 'X-Requests': str(self.server.requests_count),
  166. }
  167. subscriber = Subscriber(self, request)
  168. # Is there a basic auth in this request? If so, set up ACL
  169. auth = request.headers.get('Authorization')
  170. if auth:
  171. await subscriber.parse_acl(auth)
  172. # Are we handling a publisher payload request? (PUT/POST)
  173. if request.method in ['PUT', 'POST']:
  174. ip = netaddr.IPAddress(request.remote)
  175. allowed = False
  176. for network in self.config.payloaders:
  177. if ip in network:
  178. allowed = True
  179. break
  180. # Check for secure topics
  181. payload_topics = set(request.path.split("/"))
  182. if any(x in self.config.secure_topics for x in payload_topics):
  183. allowed = False
  184. # Figure out which secure topics we need permission for:
  185. which_secure = [x for x in self.config.secure_topics if x in payload_topics]
  186. # Is the user allowed to post to all of these secure topics?
  187. if subscriber.secure_topics and all(x in subscriber.secure_topics for x in which_secure):
  188. allowed = True
  189. if not allowed:
  190. resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
  191. return resp
  192. if request.can_read_body:
  193. try:
  194. if request.content_length and request.content_length > self.config.server.payload_limit:
  195. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
  196. return resp
  197. body = await request.text()
  198. as_json = json.loads(body)
  199. assert isinstance(as_json, dict) # Payload MUST be an dictionary object, {...}
  200. pl = Payload(request.path, as_json)
  201. self.pending_events.put_nowait(pl)
  202. # Add to backlog?
  203. if self.config.backlog.queue_size > 0:
  204. self.backlog.append(pl)
  205. # If backlog has grown too large, delete the first (oldest) item in it.
  206. while len(self.backlog) > self.config.backlog.queue_size:
  207. del self.backlog[0]
  208. resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
  209. return resp
  210. except json.decoder.JSONDecodeError:
  211. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
  212. return resp
  213. except AssertionError:
  214. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)
  215. return resp
  216. # Is this a subscriber request? (GET)
  217. elif request.method == 'GET':
  218. resp = aiohttp.web.StreamResponse(headers=headers)
  219. # We do not support HTTP 1.0 here...
  220. if request.version.major == 1 and request.version.minor == 0:
  221. return resp
  222. # Subscribe the user before we deal with the potential backlog request and pings
  223. subscriber.connection = resp
  224. self.subscribers.append(subscriber)
  225. resp.content_type = PUBSUB_CONTENT_TYPE
  226. try:
  227. resp.enable_chunked_encoding()
  228. await resp.prepare(request)
  229. # Is the client requesting a backlog of items?
  230. epoch_based_backlog = request.headers.get('X-Fetch-Since')
  231. cursor_based_backlog = request.headers.get('X-Fetch-Since-Cursor')
  232. if epoch_based_backlog: # epoch-based backlog search
  233. try:
  234. backlog_ts = int(backlog)
  235. except ValueError: # Default to 0 if we can't parse the epoch
  236. backlog_ts = 0
  237. # If max age is specified, force the TS to minimum that age
  238. if self.config.backlog.max_age > 0:
  239. backlog_ts = max(backlog_ts, int(time.time() - self.config.backlog.max_age))
  240. # For each item, publish to client if new enough.
  241. for item in self.backlog:
  242. if item.timestamp >= backlog_ts:
  243. await item.publish([subscriber])
  244. if cursor_based_backlog and len(cursor_based_backlog) == 36: # UUID4 cursor-based backlog search
  245. # For each item, publish to client if it was published after this cursor
  246. is_after_cursor = False
  247. for item in self.backlog:
  248. if item.cursor == cursor_based_backlog: # Found cursor, mark it!
  249. is_after_cursor = True
  250. elif is_after_cursor: # This is after the cursor, stream it
  251. await item.publish([subscriber])
  252. while True:
  253. await subscriber.ping()
  254. if subscriber not in self.subscribers: # If we got dislodged somehow, end session
  255. break
  256. await asyncio.sleep(5)
  257. # We may get exception types we don't have imported, so grab ANY exception and kick out the subscriber
  258. except:
  259. pass
  260. if subscriber in self.subscribers:
  261. self.subscribers.remove(subscriber)
  262. return resp
  263. elif request.method == 'HEAD':
  264. resp = aiohttp.web.Response(headers=headers, status=204, text="")
  265. return resp
  266. # I don't know this type of request :/ (DELETE, PATCH, etc)
  267. else:
  268. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
  269. return resp
  270. async def write_backlog_storage(self):
  271. previous_backlog = []
  272. while True:
  273. if self.config.backlog.storage:
  274. try:
  275. backlog_list = self.backlog.copy()
  276. if backlog_list != previous_backlog:
  277. previous_backlog = backlog_list
  278. async with aiofile.AIOFile(self.config.backlog.storage, 'w+') as afp:
  279. offset = 0
  280. for item in backlog_list:
  281. js =json.dumps({
  282. 'timestamp': item.timestamp,
  283. 'topics': item.topics,
  284. 'json': item.json,
  285. 'private': item.private
  286. }) + '\n'
  287. await afp.write(js, offset=offset)
  288. offset += len(js)
  289. await afp.fsync()
  290. except Exception as e:
  291. print(f"Could not write to backlog file {self.config.backlog.storage}: {e}")
  292. await asyncio.sleep(10)
  293. def read_backlog_storage(self):
  294. if self.config.backlog.storage and os.path.exists(self.config.backlog.storage):
  295. try:
  296. readlines = 0
  297. with open(self.config.backlog.storage, 'r') as fp:
  298. for line in fp.readlines():
  299. js = json.loads(line)
  300. readlines += 1
  301. ppath = "/".join(js['topics'])
  302. if js['private']:
  303. ppath = '/private/' + ppath
  304. payload = Payload(ppath, js['json'], js['timestamp'])
  305. self.backlog.append(payload)
  306. if self.config.backlog.queue_size < len(self.backlog):
  307. self.backlog.pop(0)
  308. except Exception as e:
  309. print(f"Error while reading backlog: {e}")
  310. print(f"Read {readlines} objects from {self.config.backlog.storage}, applied {len(self.backlog)} to backlog.")
  311. async def server_loop(self, loop: asyncio.BaseEventLoop):
  312. self.server = aiohttp.web.Server(self.handle_request)
  313. runner = aiohttp.web.ServerRunner(self.server)
  314. await runner.setup()
  315. site = aiohttp.web.TCPSite(runner, self.config.server.ip, self.config.server.port)
  316. await site.start()
  317. print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
  318. print("==== Serving up PubSub goodness at %s:%s ====" % (
  319. self.config.server.ip, self.config.server.port))
  320. if self.config.server.tls_ctx:
  321. site_tls = aiohttp.web.TCPSite(runner, self.config.server.ip, self.config.server.tls_port, ssl_context=self.config.server.tls_ctx)
  322. await site_tls.start()
  323. print("==== Serving up PubSub TLS goodness at %s:%s ====" % (
  324. self.config.server.ip, self.config.server.tls_port))
  325. if self.config.sqs:
  326. for key, config in self.config.sqs.items():
  327. loop.create_task(plugins.sqs.get_payloads(self, config))
  328. self.read_backlog_storage()
  329. loop.create_task(self.write_backlog_storage())
  330. await self.poll()
  331. def run(self):
  332. loop = asyncio.get_event_loop()
  333. # add a signal handler for SIGUSR2 to reload the ACL from disk
  334. try:
  335. loop.add_signal_handler(signal.SIGUSR2, self.load_acl)
  336. except ValueError:
  337. pass
  338. try:
  339. loop.run_until_complete(self.server_loop(loop))
  340. except KeyboardInterrupt:
  341. pass
  342. loop.close()
  343. class Subscriber:
  344. """Basic subscriber (client) class. Holds information about the connection and ACL"""
  345. acl: dict
  346. topics: typing.List[typing.List[str]]
  347. def __init__(self, server: Server, request: aiohttp.web.BaseRequest):
  348. self.connection: typing.Optional[aiohttp.web.StreamResponse] = None
  349. self.acl = {}
  350. self.server = server
  351. self.lock = asyncio.Lock()
  352. self.secure_topics = []
  353. # Set topics subscribed to
  354. self.topics = []
  355. for topic_batch in request.path.split(','):
  356. sub_to = [x for x in topic_batch.split('/') if x]
  357. self.topics.append(sub_to)
  358. # Is the client old and expecting zero-terminators?
  359. self.old_school = False
  360. for ua in self.server.config.oldschoolers:
  361. if ua in request.headers.get('User-Agent', ''):
  362. self.old_school = True
  363. break
  364. async def parse_acl(self, basic: str):
  365. """Sets the ACL if possible, based on Basic Auth"""
  366. try:
  367. decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'ascii')), 'utf-8')
  368. u, p = decoded.split(':', 1)
  369. if u in self.server.acl:
  370. acl_pass = self.server.acl[u].get('password')
  371. if acl_pass and acl_pass == p:
  372. acl = self.server.acl[u].get('acl', {})
  373. # Vet ACL for user
  374. assert isinstance(acl, dict), f"ACL for user {u} " \
  375. f"must be a dictionary of sub-IDs and topics, but is not."
  376. # Make sure each ACL segment is a list of topics
  377. for k, v in acl.items():
  378. assert isinstance(v, list), f"ACL segment {k} for user {u} is not a list of topics!"
  379. print(f"Client {u} successfully authenticated (and ACL is valid).")
  380. self.acl = acl
  381. self.secure_topics = set(self.server.acl[u].get('topics', []) or [])
  382. elif self.server.config.ldap:
  383. acl = {}
  384. groups = await self.server.config.ldap.get_groups(u,p)
  385. # Make sure each ACL segment is a list of topics
  386. for k, v in self.server.config.ldap.acl.items():
  387. if k in groups:
  388. assert isinstance(v, dict), f"ACL segment {k} for user {u} is not a dictionary of segments!"
  389. for segment, topics in v.items():
  390. print(f"Enabling ACL segment {segment} for user {u}")
  391. assert isinstance(topics,
  392. list), f"ACL segment {segment} for user {u} is not a list of topics!"
  393. acl[segment] = topics
  394. self.acl = acl
  395. except binascii.Error as e:
  396. pass # Bad Basic Auth params, bail quietly
  397. except AssertionError as e:
  398. print(e)
  399. print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
  400. except Exception as e:
  401. print(f"Basic unknown exception occurred: {e}")
  402. async def ping(self):
  403. """Generic ping-back to the client"""
  404. js = b"%s\n" % json.dumps({"stillalive": time.time()}).encode('utf-8')
  405. if self.old_school:
  406. js += b"\0"
  407. async with self.lock:
  408. await asyncio.wait_for(self.connection.write(js), timeout=PUBSUB_WRITE_TIMEOUT)
  409. class Payload:
  410. """A payload (event) object sent by a registered publisher."""
  411. def __init__(self, path: str, data: dict, timestamp: typing.Optional[float] = None):
  412. self.json = data
  413. self.timestamp = timestamp or time.time()
  414. self.topics = [x for x in path.split('/') if x]
  415. self.private = False
  416. self.cursor = str(uuid.uuid4()) # Event cursor for playback - UUID4 style
  417. # Private payload?
  418. if self.topics and self.topics[0] == 'private':
  419. self.private = True
  420. del self.topics[0] # Remove the private bit from topics now.
  421. # Set standard pubsub meta data in the payload
  422. self.json['pubsub_timestamp'] = self.timestamp
  423. self.json['pubsub_topics'] = self.topics
  424. self.json['pubsub_path'] = path
  425. self.json['pubsub_cursor'] = self.cursor
  426. async def publish(self, subscribers: typing.List[Subscriber]):
  427. """Publishes an object to all subscribers using those topics (or a sub-set thereof)"""
  428. js = b"%s\n" % json.dumps(self.json).encode('utf-8')
  429. ojs = js + b"\0"
  430. bad_subs = []
  431. for sub in subscribers:
  432. # If a private payload, check ACL and bail if not a match
  433. if self.private:
  434. can_see = False
  435. for key, private_topics in sub.acl.items():
  436. if all(el in self.topics for el in private_topics):
  437. can_see = True
  438. break
  439. if not can_see:
  440. continue
  441. # If subscribed to all the topics, tell a subscriber about this
  442. for topic_batch in sub.topics:
  443. if all(el in self.topics for el in topic_batch):
  444. try:
  445. if sub.old_school:
  446. async with sub.lock:
  447. await asyncio.wait_for(sub.connection.write(ojs), timeout=PUBSUB_WRITE_TIMEOUT)
  448. else:
  449. async with sub.lock:
  450. await asyncio.wait_for(sub.connection.write(js), timeout=PUBSUB_WRITE_TIMEOUT)
  451. except Exception:
  452. bad_subs.append(sub)
  453. break
  454. return bad_subs
  455. if __name__ == '__main__':
  456. parser = argparse.ArgumentParser()
  457. parser.add_argument("--config", help="Configuration file to load (default: pypubsub.yaml)", default="pypubsub.yaml")
  458. parser.add_argument("--acl", help="ACL Configuration file to load (default: pypubsub_acl.yaml)",
  459. default="pypubsub_acl.yaml")
  460. cliargs = parser.parse_args()
  461. pubsub_server = Server(cliargs)
  462. pubsub_server.run()