Move websocket bus to asyncio operation

This commit is contained in:
Derek 2025-02-17 23:25:50 -05:00
parent 7685170714
commit 7ebf0b48a4
2 changed files with 49 additions and 70 deletions

View file

@ -64,11 +64,6 @@ class MainProcess:
self.plugins = {}
self.event_queue = asyncio.Queue()
# Init websocket server (event bus)
# HACK: Must be done here to avoid shadowing its asyncio loop
self.server_process = WebsocketServerProcess(*self.bus_conf)
self.server_process.start()
# Save sys.path since some config will clobber it
self._initial_syspath = sys.path
@ -78,10 +73,6 @@ class MainProcess:
del self.plugins[plugin_name]
del plugin
def _get_event_from_pipe(self, pipe):
event = pipe.recv()
self.event_queue.put_nowait(event)
def _setup_webserver(self):
self.webserver = Quart(__name__, static_folder=None, template_folder=None)
listen = ':'.join(self.web_conf)
@ -137,10 +128,10 @@ class MainProcess:
if event is None:
break
else:
self.server_process.message_pipe.send(event)
await self.bus_server.send(event)
logger.debug(f'Event after plugin chain - {event}')
elif isinstance(event, Delete):
self.server_process.message_pipe.send(event)
await self.bus_server.send(event)
else:
logger.error(f'Unknown data in event loop - {event}')
@ -250,6 +241,15 @@ class MainProcess:
self._unload_plugin(plugin_name)
sys.path = self._initial_syspath
async def _discount_repl(self):
# REVIEW: Not a good UX at the moment (as new logs clobber the terminal entry)
async for line in reader:
line = line.strip()
if line == b'reload':
self.reload_ev.set()
elif line == b'quit':
self.shutdown_ev.set()
async def run(self):
self.shutdown_ev = asyncio.Event()
self.reload_ev = asyncio.Event()
@ -258,26 +258,13 @@ class MainProcess:
try:
# System setup
## Bridge websocket server pipe to asyncio loop
## REVIEW: This does not work on windows!!!! add_reader is not implemented
## in a way that supports pipes on either windows loop runners
ws_pipe = self.server_process.message_pipe
loop.add_reader(ws_pipe.fileno(), lambda: self._get_event_from_pipe(ws_pipe))
## Register stdin handler
## Make stdin handler
reader = asyncio.StreamReader()
await loop.connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)
async def discount_repl():
# REVIEW: Not a good UX at the moment (as new logs clobber the terminal entry)
async for line in reader:
line = line.strip()
if line == b'reload':
self.reload_ev.set()
elif line == b'quit':
self.shutdown_ev.set()
self.cli_task = loop.create_task(discount_repl())
## Scheduler for timed tasks
self._skehdule = TimedScheduler(max_tasks=1)
self._skehdule.start()
self.cli_task = loop.create_task(self._discount_repl())
## Init websocket server (external end of the event bus)
self.bus_server = WebsocketServerProcess(self.event_queue, *self.bus_conf)
self.bus_task = loop.create_task(self.bus_server.run())
## UI server
serve_coro = self._setup_webserver()
self.webserver_task = loop.create_task(serve_coro)
@ -294,7 +281,7 @@ class MainProcess:
logger.info(f'Ready to rumble! Press Ctrl+C to shut down')
reload_task = loop.create_task(self.reload_ev.wait())
done, pending = await asyncio.wait([*user_tasks, self.webserver_task, reload_task], return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([*user_tasks, self.webserver_task, self.bus_task, reload_task], return_when=asyncio.FIRST_COMPLETED)
if reload_task in done:
logger.warn('Reloading (some events may be missed!)')
@ -319,6 +306,7 @@ class MainProcess:
serve_coro = self._setup_webserver()
self.webserver_task = loop.create_task(serve_coro)
else:
logger.debug(f'Task {done} completed - assuming something went wrong!')
break
except KeyboardInterrupt:
pass
@ -336,4 +324,4 @@ class MainProcess:
task.cancel()
await self.user_shutdown()
self.webserver_task.cancel()
self.server_process.terminate()
self.bus_task.cancel()

View file

@ -1,34 +1,28 @@
import asyncio
import json
from multiprocessing import Process, Pipe
import logging
import websockets
from ovtk_audiencekit.events import Event
from ovtk_audiencekit.utils import get_subclasses
from ovtk_audiencekit.utils import get_subclasses, format_exception
logger = logging.getLogger(__name__)
class WebsocketServerProcess(Process):
def __init__(self, bind, port):
super().__init__()
class WebsocketServerProcess:
def __init__(self, event_queue, bind, port):
self._bind = bind
self._port = port
self._pipe, self._caller_pipe = Pipe()
self.event_queue = event_queue
self._send_queue = asyncio.Queue()
self.clients = set()
self._event_classes = get_subclasses(Event)
@property
def message_pipe(self):
return self._caller_pipe
# Data input (external application socket -> plugin/chat pipe)
async def handle_websocket(self, ws, path):
async def _handle_client(self, ws, path):
self.clients.add(ws)
try:
async for message in ws:
@ -39,7 +33,7 @@ class WebsocketServerProcess(Process):
type = type[0]
event_class = next(cls for cls in self._event_classes if cls.__name__ == type)
event = event_class.hydrate(**data.get('data', {}))
self._pipe.send(event)
self.event_queue.put_nowait(event)
else:
logger.warn('Unknown data recieved on websocket', message)
except json.decoder.JSONDecodeError as e:
@ -56,42 +50,39 @@ class WebsocketServerProcess(Process):
self.clients.discard(ws)
# Data output (plugin/chat pipe -> external application socket)
async def handle_pipe(self, pipe_ready):
async def send(self, event):
await self._send_queue.put(event)
async def _send_loop(self):
while True:
# Let other co-routines process until file descriptor is readable
await pipe_ready.wait()
pipe_ready.clear()
# Check if messages exist on the pipe before attempting to recv
# to avoid accidentally blocking the event loop when file
# descriptor does stuff we don't expect
if not self._pipe.poll():
continue
event = self._pipe.recv()
event = await self._send_queue.get()
# Serialize and send to registered clients
serialized = event.serialize()
if self.clients:
await asyncio.wait([self.safe_send(client, serialized) for client in self.clients])
await asyncio.wait([self._safe_send(client, serialized) for client in self.clients])
async def safe_send(self, client, serialized):
async def _safe_send(self, client, serialized):
try:
await client.send(serialized)
except (websockets.exceptions.ConnectionClosedError, websockets.exceptions.ConnectionClosedOK):
self.clients.discard(client)
def run(self):
# Setup asyncio websocket server
start_server = websockets.serve(self.handle_websocket, self._bind, self._port)
asyncio.get_event_loop().run_until_complete(start_server)
async def run(self):
loop = asyncio.get_event_loop()
tasks = set()
# Make an awaitable object that flips when the pipe's underlying file descriptor is readable
pipe_ready = asyncio.Event()
# REVIEW: This does not work on windows!!!!
asyncio.get_event_loop().add_reader(self._pipe.fileno(), pipe_ready.set)
# Make and start our infinite pipe listener task
asyncio.get_event_loop().create_task(self.handle_pipe(pipe_ready))
# Setup websocket server (input)
self.ws_server = await websockets.serve(self._handle_client, self._bind, self._port)
tasks.add(loop.create_task(self.ws_server.serve_forever()))
# Setup sending loop (output)
tasks.add(loop.create_task(self._send_loop()))
# Keep the asyncio code running in this thread until explicitly stopped
try:
asyncio.get_event_loop().run_forever()
except KeyboardInterrupt:
return 0
await asyncio.gather(*tasks)
except Exception as e:
logger.critical(f'Failure in bus process - {e}')
logger.info(format_exception(e))
raise e
finally:
for task in tasks:
task.cancel()