Skip to content

Commit

Permalink
✨ version 0.26.1
Browse files Browse the repository at this point in the history
add parse wrapper for ext
  • Loading branch information
RF-Tar-Railt committed Oct 5, 2023
1 parent 20ba409 commit a63e809
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 17 deletions.
10 changes: 3 additions & 7 deletions docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,9 @@ class LLMExtension(Extension):
def post_init(self, alc: Alconna) -> None:
self.llm.add_context(alc.command, alc.meta.description)

async def message_provider(
self, event, state, bot, use_origin: bool = False
):
if event.get_type() != "message":
return
resp = await self.llm.input(str(event.get_message()))
return event.get_message().__class__(resp.content)
async def receive_wrapper(self, bot, event, receive):
resp = await self.llm.input(str(receive))
return receive.__class__(resp.content)

matcher = on_alconna(Alconna(...), extensions=[DemoExtension(LLM)])
...
Expand Down
2 changes: 1 addition & 1 deletion src/nonebot_plugin_alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .consts import ALCONNA_EXEC_RESULT as ALCONNA_EXEC_RESULT
from .extension import add_global_extension as add_global_extension

__version__ = "0.26.0"
__version__ = "0.26.1"

_meta_source = {
"name": "Alconna 插件",
Expand Down
6 changes: 3 additions & 3 deletions src/nonebot_plugin_alconna/adapters/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ async def edit_response(
bot = current_bot.get()
if not isinstance(event, ApplicationCommandInteractionEvent) or not isinstance(bot, Bot):
raise ValueError("Invalid event or bot")
_message = await matcher.executor.send_hook(bot, event, matcher.convert(message))
_message = await matcher.executor.send_wrapper(bot, event, matcher.convert(message))
if isinstance(_message, UniMessage):
message_data = parse_message(await _message.export(bot, fallback))
else:
Expand Down Expand Up @@ -490,7 +490,7 @@ async def send_followup_msg(
bot = current_bot.get()
if not isinstance(event, ApplicationCommandInteractionEvent) or not isinstance(bot, Bot):
raise ValueError("Invalid event or bot")
_message = await matcher.executor.send_hook(bot, event, matcher.convert(message))
_message = await matcher.executor.send_wrapper(bot, event, matcher.convert(message))
if isinstance(_message, UniMessage):
message_data = parse_message(await _message.export(bot, fallback))
else:
Expand Down Expand Up @@ -526,7 +526,7 @@ async def edit_followup_msg(
bot = current_bot.get()
if not isinstance(event, ApplicationCommandInteractionEvent) or not isinstance(bot, Bot):
raise ValueError("Invalid event or bot")
_message = await matcher.executor.send_hook(bot, event, matcher.convert(message))
_message = await matcher.executor.send_wrapper(bot, event, matcher.convert(message))
if isinstance(_message, UniMessage):
message_data = parse_message(await _message.export(bot, fallback))
else:
Expand Down
45 changes: 41 additions & 4 deletions src/nonebot_plugin_alconna/extension.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import re
import asyncio
import functools
import importlib as imp
from typing_extensions import Self
from typing import Literal, TypeVar
from abc import ABCMeta, abstractmethod

from tarina import lang
from arclet.alconna import Alconna
from nonebot.typing import T_State
from arclet.alconna import Alconna, Arparma
from nonebot.adapters import Bot, Event, Message

from .uniseg import UniMessage, FallbackMessage
Expand All @@ -19,6 +20,15 @@


class Extension(metaclass=ABCMeta):
_overrides: dict[str, bool]

def __init_subclass__(cls, **kwargs):
cls._overrides = {
"send_wrapper": cls.send_wrapper == Extension.send_wrapper,
"receive_wrapper": cls.receive_wrapper == Extension.receive_wrapper,
"parse_wrapper": cls.parse_wrapper == Extension.parse_wrapper,
}

@property
@abstractmethod
def priority(self) -> int:
Expand Down Expand Up @@ -52,7 +62,15 @@ async def message_provider(
return None
return msg

async def send_hook(self, bot: Bot, event: Event, send: TM) -> TM:
async def receive_wrapper(self, bot: Bot, event: Event, receive: TM) -> TM:
"""接收消息后的钩子函数。"""
return receive

async def parse_wrapper(self, bot: Bot, state: T_State, event: Event, res: Arparma) -> None:
"""解析消息后的钩子函数。"""
pass

async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
"""发送消息前的钩子函数。"""
return send

Expand Down Expand Up @@ -131,10 +149,29 @@ async def message_provider(
if exc is not None:
raise exc

async def send_hook(self, bot: Bot, event: Event, send: TM) -> TM:
return None

async def receive_wrapper(self, bot: Bot, event: Event, receive: TM) -> TM:
res = receive
for ext in self.context:
if ext._overrides["receive_wrapper"]:
res = await ext.receive_wrapper(bot, event, res)
return res

async def parse_wrapper(self, bot: Bot, state: T_State, event: Event, res: Arparma) -> None:
await asyncio.gather(
*(
ext.parse_wrapper(bot, state, event, res)
for ext in self.context
if ext._overrides["parse_wrapper"]
)
)

async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
res = send
for ext in self.context:
res = await ext.send_hook(bot, event, res)
if ext._overrides["send_wrapper"]:
res = await ext.send_wrapper(bot, event, res)
return res

def post_init(self, alc: Alconna) -> None:
Expand Down
8 changes: 7 additions & 1 deletion src/nonebot_plugin_alconna/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nonebot.params import Depends
from nonebot.permission import Permission
from nonebot.dependencies import Dependent
from nonebot.message import run_postprocessor
from nepattern import STRING, AnyOne, AnyString
from nonebot.consts import ARG_KEY, RECEIVE_KEY
from tarina import lang, is_awaitable, run_always_await
Expand Down Expand Up @@ -356,7 +357,7 @@ async def send(
"""
bot = current_bot.get()
event = current_event.get()
_message = await cls.executor.send_hook(bot, event, cls.convert(message))
_message = await cls.executor.send_wrapper(bot, event, cls.convert(message))
if isinstance(_message, UniMessage):
res = await _message.export(bot, fallback)
else:
Expand Down Expand Up @@ -653,3 +654,8 @@ def build(
params.pop("__class__")
alc = super().build()
return on_alconna(alc, **params)


@run_postprocessor
def _exit_executor(matcher: AlconnaMatcher):
matcher.executor.context.clear()
4 changes: 3 additions & 1 deletion src/nonebot_plugin_alconna/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
self.executor.select(bot, event)
if not (msg := await self.executor.message_provider(event, state, bot, self.use_origin)):
return False
elif isinstance(msg, UniMessage):
msg = await self.executor.receive_wrapper(bot, event, msg)
if isinstance(msg, UniMessage):
msg = await msg.export(bot, fallback=True)
Arparma._additional.update(bot=lambda: bot, event=lambda: event, state=lambda: state)
with output_manager.capture(self.command.name) as cap:
Expand All @@ -214,6 +215,7 @@ async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
if self.auto_send and may_help_text:
await self._send(may_help_text, bot, event, arp)
return False
await self.executor.parse_wrapper(bot, state, event, arp)
state[ALCONNA_RESULT] = CommandResult(self.command, arp, may_help_text)
exec_result = self.command.exec_result
for key, value in exec_result.items():
Expand Down

0 comments on commit a63e809

Please sign in to comment.