pypubsub.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. #!/usr/bin/env python3
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """PyPubSub - a simple publisher/subscriber service written in Python 3"""
import argparse
import asyncio
import base64
import binascii
import collections
import hmac
import json
import time

import aiohttp.web
import netaddr
import yaml

import plugins.ldap
import plugins.sqs
  31. # Some consts
  32. PUBSUB_VERSION = '0.6.0'
  33. PUBSUB_CONTENT_TYPE = 'application/vnd.pypubsub-stream'
  34. PUBSUB_DEFAULT_PORT = 2069
  35. PUBSUB_DEFAULT_IP = '0.0.0.0'
  36. PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE = 102400
  37. PUBSUB_DEFAULT_BACKLOG_SIZE = 0
  38. PUBSUB_DEFAULT_BACKLOG_AGE = 0
  39. PUBSUB_BAD_REQUEST = "I could not understand your request, sorry! Please see https://pubsub.apache.org/api.html \
  40. for usage documentation.\n"
  41. PUBSUB_PAYLOAD_RECEIVED = "Payload received, thank you very much!\n"
  42. PUBSUB_NOT_ALLOWED = "You are not authorized to deliver payloads!\n"
  43. PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects, {..}!\n"
  44. PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
  45. class Configuration:
  46. def __init__(self, yml):
  47. # LDAP Settings
  48. self.ldap = None
  49. lyml = yml.get('clients', {}).get('ldap')
  50. if isinstance(lyml, dict):
  51. self.ldap = plugins.ldap.LDAPConnection(lyml)
  52. # SQS?
  53. self.sqs = yml.get('sqs')
  54. # Main server config
  55. self.server = collections.namedtuple('serverConfig', 'ip port payload_limit')
  56. self.server.ip = yml['server'].get('bind', PUBSUB_DEFAULT_IP)
  57. self.server.port = int(yml['server'].get('port', PUBSUB_DEFAULT_PORT))
  58. self.server.payload_limit = int(yml['server'].get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE))
  59. # Backlog settings
  60. self.backlog = collections.namedtuple('backlogConfig', 'max_age queue_size')
  61. bma = yml['server'].get('backlog', {}).get('max_age', PUBSUB_DEFAULT_BACKLOG_AGE)
  62. if isinstance(bma, str):
  63. bma = bma.lower()
  64. if bma.endswith('s'):
  65. bma = int(bma.replace('s', ''))
  66. elif bma.endswith('m'):
  67. bma = int(bma.replace('m', '')) * 60
  68. elif bma.endswith('h'):
  69. bma = int(bma.replace('h', '')) * 3600
  70. elif bma.endswith('d'):
  71. bma = int(bma.replace('d', '')) * 86400
  72. self.backlog.max_age = bma
  73. self.backlog.queue_size = yml['server'].get('backlog', {}).get('size',
  74. PUBSUB_DEFAULT_BACKLOG_SIZE)
  75. # Payloaders - clients that can post payloads
  76. self.payloaders = [netaddr.IPNetwork(x) for x in yml['clients'].get('payloaders', [])]
  77. # Binary backwards compatibility
  78. self.oldschoolers = yml['clients'].get('oldschoolers', [])
  79. class Server:
  80. """Main server class, responsible for handling requests and publishing events """
  81. def __init__(self, args):
  82. self.yaml = yaml.safe_load(open(args.config))
  83. self.config = Configuration(self.yaml)
  84. self.subscribers = []
  85. self.pending_events = []
  86. self.backlog = []
  87. self.last_ping = time.time()
  88. self.server = None
  89. self.acl = {}
  90. try:
  91. self.acl = yaml.safe_load(open(args.acl))
  92. except FileNotFoundError:
  93. print(f"ACL configuration file {args.acl} not found, private events will not be broadcast.")
  94. async def poll(self):
  95. """Polls for new stuff to publish, and if found, publishes to whomever wants it."""
  96. while True:
  97. for payload in self.pending_events:
  98. bad_subs = await payload.publish(self.subscribers)
  99. # Cull subscribers we couldn't deliver payload to.
  100. for bad_sub in bad_subs:
  101. print("Culling %r due to connection errors" % bad_sub)
  102. self.subscribers.remove(bad_sub)
  103. self.pending_events = []
  104. await asyncio.sleep(0.5)
  105. async def handle_request(self, request):
  106. """Generic handler for all incoming HTTP requests"""
  107. # Define response headers first...
  108. headers = {
  109. 'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
  110. 'X-Subscribers': str(len(self.subscribers)),
  111. 'X-Requests': str(self.server.requests_count),
  112. }
  113. # Are we handling a publisher payload request? (PUT/POST)
  114. if request.method in ['PUT', 'POST']:
  115. ip = netaddr.IPAddress(request.remote)
  116. allowed = False
  117. for network in self.config.payloaders:
  118. if ip in network:
  119. allowed = True
  120. break
  121. if not allowed:
  122. resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
  123. return resp
  124. if request.can_read_body:
  125. try:
  126. if request.content_length > self.config.server.payload_limit:
  127. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
  128. return resp
  129. body = await request.text()
  130. as_json = json.loads(body)
  131. assert isinstance(as_json, dict) # Payload MUST be an dictionary object, {...}
  132. pl = Payload(request.path, as_json)
  133. self.pending_events.append(pl)
  134. # Add to backlog?
  135. if self.config.backlog.queue_size > 0:
  136. self.backlog.append(pl)
  137. # If backlog has grown too large, delete the first (oldest) item in it.
  138. if len(self.backlog) > self.config.backlog.queue_size:
  139. del self.backlog[0]
  140. resp = aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)
  141. return resp
  142. except json.decoder.JSONDecodeError:
  143. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
  144. return resp
  145. except AssertionError:
  146. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)
  147. return resp
  148. # Is this a subscriber request? (GET)
  149. elif request.method == 'GET':
  150. resp = aiohttp.web.StreamResponse(headers=headers)
  151. # We do not support HTTP 1.0 here...
  152. if request.version.major == 1 and request.version.minor == 0:
  153. return resp
  154. subscriber = Subscriber(self, resp, request)
  155. # Is there a basic auth in this request? If so, set up ACL
  156. auth = request.headers.get('Authorization')
  157. if auth:
  158. subscriber.acl = await subscriber.parse_acl(auth)
  159. # Subscribe the user before we deal with the potential backlog request and pings
  160. self.subscribers.append(subscriber)
  161. resp.content_type = PUBSUB_CONTENT_TYPE
  162. try:
  163. resp.enable_chunked_encoding()
  164. await resp.prepare(request)
  165. # Is the client requesting a backlog of items?
  166. backlog = request.headers.get('X-Fetch-Since')
  167. if backlog:
  168. try:
  169. backlog_ts = int(backlog)
  170. except ValueError: # Default to 0 if we can't parse the epoch
  171. backlog_ts = 0
  172. # If max age is specified, force the TS to minimum that age
  173. if self.config.backlog.max_age > 0:
  174. backlog_ts = max(backlog_ts, time.time() - self.config.backlog.max_age)
  175. # For each item, publish to client if new enough.
  176. for item in self.backlog:
  177. if item.timestamp >= backlog_ts:
  178. await item.publish([subscriber])
  179. while True:
  180. await subscriber.ping()
  181. if subscriber not in self.subscribers: # If we got dislodged somehow, end session
  182. break
  183. await asyncio.sleep(5)
  184. # We may get exception types we don't have imported, so grab ANY exception and kick out the subscriber
  185. except:
  186. pass
  187. if subscriber in self.subscribers:
  188. self.subscribers.remove(subscriber)
  189. return resp
  190. elif request.method == 'HEAD':
  191. resp = aiohttp.web.Response(headers=headers, status=204, text="")
  192. return resp
  193. # I don't know this type of request :/ (DELETE, PATCH, etc)
  194. else:
  195. resp = aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)
  196. return resp
  197. async def server_loop(self, loop):
  198. self.server = aiohttp.web.Server(self.handle_request)
  199. runner = aiohttp.web.ServerRunner(self.server)
  200. await runner.setup()
  201. site = aiohttp.web.TCPSite(runner, self.config.server.ip, self.config.server.port)
  202. await site.start()
  203. print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
  204. print("==== Serving up PubSub goodness at %s:%s ====" % (
  205. self.config.server.ip, self.config.server.port))
  206. if self.config.sqs:
  207. for key, config in self.config.sqs.items():
  208. loop.create_task(plugins.sqs.get_payloads(self, config))
  209. await self.poll()
  210. def run(self):
  211. loop = asyncio.get_event_loop()
  212. try:
  213. loop.run_until_complete(self.server_loop(loop))
  214. except KeyboardInterrupt:
  215. pass
  216. loop.close()
  217. class Subscriber:
  218. """Basic subscriber (client) class. Holds information about the connection and ACL"""
  219. def __init__(self, server, connection, request):
  220. self.connection = connection
  221. self.acl = {}
  222. self.server = server
  223. # Set topics subscribed to
  224. self.topics = [x for x in request.path.split('/') if x]
  225. # Is the client old and expecting zero-terminators?
  226. self.old_school = False
  227. for ua in self.server.config.oldschoolers:
  228. if ua in request.headers.get('User-Agent', ''):
  229. self.old_school = True
  230. break
  231. async def parse_acl(self, basic):
  232. """Sets the ACL if possible, based on Basic Auth"""
  233. try:
  234. decoded = str(base64.decodebytes(bytes(basic.replace('Basic ', ''), 'ascii')), 'utf-8')
  235. u, p = decoded.split(':', 1)
  236. if u in self.server.acl:
  237. acl_pass = self.server.acl[u].get('password')
  238. if acl_pass and acl_pass == p:
  239. acl = self.server.acl[u].get('acl', {})
  240. # Vet ACL for user
  241. assert isinstance(acl, dict), f"ACL for user {u} " \
  242. f"must be a dictionary of sub-IDs and topics, but is not."
  243. # Make sure each ACL segment is a list of topics
  244. for k, v in acl.items():
  245. assert isinstance(v, list), f"ACL segment {k} for user {u} is not a list of topics!"
  246. print(f"Client {u} successfully authenticated (and ACL is valid).")
  247. return acl
  248. elif self.server.config.ldap:
  249. acl = {}
  250. groups = await self.server.config.ldap.get_groups(u,p)
  251. # Make sure each ACL segment is a list of topics
  252. for k, v in self.server.config.ldap.acl.items():
  253. if k in groups:
  254. assert isinstance(v, dict), f"ACL segment {k} for user {u} is not a dictionary of segments!"
  255. for segment, topics in v.items():
  256. print(f"Enabling ACL segment {segment} for user {u}")
  257. assert isinstance(topics,
  258. list), f"ACL segment {segment} for user {u} is not a list of topics!"
  259. acl[segment] = topics
  260. return acl
  261. except binascii.Error as e:
  262. pass # Bad Basic Auth params, bail quietly
  263. except AssertionError as e:
  264. print(e)
  265. print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
  266. except Exception as e:
  267. print(f"Basic unknown exception occurred: {e}")
  268. return {}
  269. async def ping(self):
  270. """Generic ping-back to the client"""
  271. js = b"%s\n" % json.dumps({"stillalive": time.time()}).encode('utf-8')
  272. if self.old_school:
  273. js += b"\0"
  274. await self.connection.write(js)
  275. class Payload:
  276. """A payload (event) object sent by a registered publisher."""
  277. def __init__(self, path, data):
  278. self.json = data
  279. self.timestamp = time.time()
  280. self.topics = [x for x in path.split('/') if x]
  281. self.private = False
  282. # Private payload?
  283. if self.topics and self.topics[0] == 'private':
  284. self.private = True
  285. del self.topics[0] # Remove the private bit from topics now.
  286. self.json['pubsub_timestamp'] = self.timestamp
  287. self.json['pubsub_topics'] = self.topics
  288. self.json['pubsub_path'] = path
  289. async def publish(self, subscribers):
  290. """Publishes an object to all subscribers using those topics (or a sub-set thereof)"""
  291. js = b"%s\n" % json.dumps(self.json).encode('utf-8')
  292. ojs = js + b"\0"
  293. bad_subs = []
  294. for sub in subscribers:
  295. # If a private payload, check ACL and bail if not a match
  296. if self.private:
  297. can_see = False
  298. for key, private_topics in sub.acl.items():
  299. if all(el in self.topics for el in private_topics):
  300. can_see = True
  301. break
  302. if not can_see:
  303. continue
  304. # If subscribed to all the topics, tell a subscriber about this
  305. if all(el in self.topics for el in sub.topics):
  306. try:
  307. if sub.old_school:
  308. await sub.connection.write(ojs)
  309. else:
  310. await sub.connection.write(js)
  311. except Exception:
  312. bad_subs.append(sub)
  313. return bad_subs
  314. if __name__ == '__main__':
  315. parser = argparse.ArgumentParser()
  316. parser.add_argument("--config", help="Configuration file to load (default: pypubsub.yaml)", default="pypubsub.yaml")
  317. parser.add_argument("--acl", help="ACL Configuration file to load (default: pypubsub_acl.yaml)",
  318. default="pypubsub_acl.yaml")
  319. cliargs = parser.parse_args()
  320. pubsub_server = Server(cliargs)
  321. pubsub_server.run()