Compare commits

...

2 commits

Author SHA1 Message Date
4c42320c7c [plugins] Add pre-call hooks
This also greatly simplifies execution of the kdl script thing format. I 
have no idea why it was like that before
2025-02-02 23:51:01 -05:00
f593db112c Fix command parsing 2025-02-02 15:42:14 -05:00
5 changed files with 43 additions and 24 deletions

View file

@ -131,7 +131,6 @@ class MainProcess:
if e.fatal: if e.fatal:
self._unload_plugin(e.source) self._unload_plugin(e.source)
except Exception as e: except Exception as e:
self._plugin_error
logger.critical(f'Failure when processing {plugin_name} ({e}) - disabling...') logger.critical(f'Failure when processing {plugin_name} ({e}) - disabling...')
logger.debug(format_exception(e)) logger.debug(format_exception(e))
self._unload_plugin(plugin_name) self._unload_plugin(plugin_name)
@ -239,10 +238,7 @@ class MainProcess:
if plugin_module is None: if plugin_module is None:
logger.error(f'Unknown plugin: {node.name}') logger.error(f'Unknown plugin: {node.name}')
else: else:
res = await plugin_module._call(node.sub, node.tag, *node.args, **node.props, _ctx=global_ctx, _children=node.nodes) await plugin_module._kdl_call(node, global_ctx)
if node.alias:
global_ctx[node.alias] = res
async def user_shutdown(self): async def user_shutdown(self):
for process_name, process in list(reversed(self.chat_processes.items())): for process_name, process in list(reversed(self.chat_processes.items())):

View file

@ -9,6 +9,8 @@ import kdl
import quart import quart
from ovtk_audiencekit.core.Config import kdl_parse_config, compute_dynamic from ovtk_audiencekit.core.Config import kdl_parse_config, compute_dynamic
from ovtk_audiencekit.utils import format_exception
class PluginError(Exception): class PluginError(Exception):
@ -30,6 +32,7 @@ class OvtkBlueprint(quart.Blueprint):
class PluginBase(ABC): class PluginBase(ABC):
plugins = {} plugins = {}
hooks = {} # the hookerrrrrrrr
def __init__(self, chat_processes, event_queue, name, global_ctx, _children=None, **kwargs): def __init__(self, chat_processes, event_queue, name, global_ctx, _children=None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -56,24 +59,41 @@ class PluginBase(ABC):
def __del__(self): def __del__(self):
if self.plugins.get(self._name) == self: if self.plugins.get(self._name) == self:
del self.plugins[self._name] del self.plugins[self._name]
if self._name in self.hooks:
del self.hooks[self._name]
async def _kdl_call(self, node, _ctx):
args, props = compute_dynamic(node, _ctx=_ctx)
subroutine = node.sub
async def _call(self, subroutine, tag, *args, **kwargs):
try:
if subroutine: if subroutine:
func = self func = self
for accessor in subroutine: for accessor in subroutine:
func = getattr(func, accessor) func = getattr(func, accessor)
else: else:
func = self.run func = self.run
res = func(*args, **kwargs)
for hook in self.hooks.values():
try:
res = hook(self._name, node, _ctx)
if asyncio.iscoroutinefunction(hook):
await res
except Exception as e:
self.logger.warning(f'Failed to run plugin hook: {e}')
self.logger.debug(format_exception(e))
try:
result = func(*args, _children=node.nodes, _ctx=_ctx, **props)
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
res = await res result = await result
return res
except Exception as e: except Exception as e:
if isinstance(e, KeyboardInterrupt): if isinstance(e, KeyboardInterrupt):
raise e raise e
raise PluginError(self._name, str(e)) from e raise PluginError(self._name, str(e)) from e
if node.alias:
_ctx[node.alias] = result
async def _tick(self, *args, **kwargs): async def _tick(self, *args, **kwargs):
try: try:
res = self.tick(*args, **kwargs) res = self.tick(*args, **kwargs)
@ -98,6 +118,7 @@ class PluginBase(ABC):
raise e raise e
raise PluginError(self._name, str(e)) from e raise PluginError(self._name, str(e)) from e
# Base class helpers # Base class helpers
def broadcast(self, event): def broadcast(self, event):
"""Send event to every active chat""" """Send event to every active chat"""
@ -106,7 +127,10 @@ class PluginBase(ABC):
continue continue
proc.control_pipe.send(event) proc.control_pipe.send(event)
async def execute_kdl(self, nodes, *py_args, _ctx={}, **py_props): def register_hook(self, hook):
self.hooks[self._name] = hook
async def execute_kdl(self, nodes, _ctx={}):
""" """
Run other plugins as configured by the passed KDL nodes collection Run other plugins as configured by the passed KDL nodes collection
If this was done in response to an event, pass it as 'event' in _ctx! If this was done in response to an event, pass it as 'event' in _ctx!
@ -114,16 +138,14 @@ class PluginBase(ABC):
_ctx = copy.deepcopy({**self._global_ctx, **_ctx}) _ctx = copy.deepcopy({**self._global_ctx, **_ctx})
for node in nodes: for node in nodes:
try: try:
args, props = compute_dynamic(node, _ctx=_ctx)
target = self.plugins.get(node.name) target = self.plugins.get(node.name)
if target is None: if target is None:
self.logger.warning(f'Could not find plugin or builtin with name {node.name}') self.logger.warning(f'Could not find plugin or builtin with name {node.name}')
break break
result = await target._call(node.sub, node.tag, *args, *py_args, **props, _ctx=_ctx, **py_props, _children=node.nodes) await target._kdl_call(node, _ctx)
if node.alias:
_ctx[node.alias] = result
except Exception as e: except Exception as e:
self.logger.warning(f'Failed to execute defered KDL: {e}') self.logger.warning(f'Failed to execute defered KDL: {e}')
self.logger.debug(format_exception(e))
break break
@ -134,6 +156,7 @@ class PluginBase(ABC):
""" """
self._event_queue.put_nowait(event) self._event_queue.put_nowait(event)
# User-defined # User-defined
async def setup(self, *args, **kwargs): async def setup(self, *args, **kwargs):
"""Called when plugin is being loaded.""" """Called when plugin is being loaded."""

View file

@ -87,7 +87,7 @@ class JailPlugin(PluginBase):
if isinstance(event, Message): if isinstance(event, Message):
if self.jail_command.invoked(event): if self.jail_command.invoked(event):
try: try:
args = self.jail_command.parse(event.text) args, _ = self.jail_command.parse(event.text)
end_date = maya.when(args['length'], prefer_dates_from='future') end_date = maya.when(args['length'], prefer_dates_from='future')
deets = self.chats[event.via].shared.api.get_user_details(args['username']) deets = self.chats[event.via].shared.api.get_user_details(args['username'])
if deets is None: if deets is None:
@ -117,7 +117,7 @@ class JailPlugin(PluginBase):
self.send_to_bus(weewoo) self.send_to_bus(weewoo)
elif self.unjail_command.invoked(event): elif self.unjail_command.invoked(event):
try: try:
args = self.jail_command.parse(event.text) args, _ = self.jail_command.parse(event.text)
deets = self.chats[event.via].shared.api.get_user_details(args['username']) deets = self.chats[event.via].shared.api.get_user_details(args['username'])
if deets is None: if deets is None:
raise ValueError() raise ValueError()

View file

@ -47,7 +47,7 @@ class ShoutoutPlugin(PluginBase):
if isinstance(event, Message): if isinstance(event, Message):
if self.command and self.command.invoked(event): if self.command and self.command.invoked(event):
try: try:
args = self.command.parse(event.text) args, _ = self.command.parse(event.text)
except ArgumentError as e: except ArgumentError as e:
msg = SysMessage(self._name, str(e), replies_to=event) msg = SysMessage(self._name, str(e), replies_to=event)
self.chats[event.via].send(msg) self.chats[event.via].send(msg)

View file

@ -139,7 +139,7 @@ class CommandPlugin(PluginBase):
if self.help_cmd.invoked(event): if self.help_cmd.invoked(event):
try: try:
args = self.help_cmd.parse(event.text) args, _ = self.help_cmd.parse(event.text)
except argparse.ArgumentError as e: except argparse.ArgumentError as e:
msg = SysMessage(self._name, f"{e}. See !help {self.help_cmd.name}", replies_to=event) msg = SysMessage(self._name, f"{e}. See !help {self.help_cmd.name}", replies_to=event)
self.chats[event.via].send(msg) self.chats[event.via].send(msg)