diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md new file mode 100644 index 0000000..1d0b478 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -0,0 +1,34 @@ +--- +name: Bug 报告 +about: 有关 bug 的报告 +title: "[Bug]" +labels: bug, triage +assignees: "" +--- + +## 请确认: + +* [ ] 问题的标题明确 +* [ ] 我翻阅过其他的 issue 并且找不到类似的问题 +* [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并仍然认为这是一个Bug + +# Bug + +## 问题 + + +## 如何复现 + + +## 预期行为 + + +## 使用环境: +- 操作系统 (Windows/Linux/Mac): +- Python 版本: +- Nonebot2 版本: +- 适配器版本: +- 使用的 Satori 服务端 (例如 Chronocat): + +## 日志/截图 + diff --git a/.github/ISSUE_TEMPLATE/feature.md b/.github/ISSUE_TEMPLATE/feature.md new file mode 100644 index 0000000..9be8fcc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature.md @@ -0,0 +1,27 @@ +--- +name: Feature 特性请求 +about: 为适配器加份菜 +title: "[Feature] " +labels: enhancement, triage +assignees: "" +--- + +## 请确认: + +* [ ] 新特性的目的明确 +* [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并且找不到类似特性 + + +## Feature +### 概要 + + + +### 是否已有相关实现 + +暂无 + + +### 其他内容 + +暂无 diff --git a/.github/actions/setup-python/action.yml b/.github/actions/setup-python/action.yml new file mode 100644 index 0000000..ebf8f36 --- /dev/null +++ b/.github/actions/setup-python/action.yml @@ -0,0 +1,21 @@ +name: Setup Python +description: Setup Python + +inputs: + python-version: + description: Python version + required: false + default: "3.10" + +runs: + using: "composite" + steps: + - uses: pdm-project/setup-pdm@v3 + name: Setup PDM + with: + python-version: ${{ inputs.python-version }} + architecture: "x64" + cache: true + + - run: pdm sync -G:all + shell: bash diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..a341b06 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,36 @@ +name: Release + +on: + push: + tags: + - v* + +jobs: + release: + runs-on: ubuntu-latest + permissions: + id-token: write + contents: write + steps: + - uses: actions/checkout@v3 + + - name: Setup Python environment + uses: ./.github/actions/setup-python + + - name: Get Version + id: version + run: | + echo "VERSION=$(pdm show --version)" >> $GITHUB_OUTPUT + echo "TAG_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT + echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + + - name: Check Version + if: steps.version.outputs.VERSION != steps.version.outputs.TAG_VERSION + run: exit 1 + + - name: Publish Package + run: | + pdm publish + gh release upload --clobber ${{ steps.version.outputs.TAG_NAME }} dist/*.tar.gz dist/*.whl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..ec3f9c2 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,17 @@ +name: Ruff Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + ruff: + name: Ruff Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Run Ruff Lint + uses: chartboost/ruff-action@v1 diff --git a/.gitignore b/.gitignore index 68bc17f..4b3ab30 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,8 @@ ipython_config.py # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml +.pdm-python +.pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..73fa8ca --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +default_install_hook_types: [pre-commit, prepare-commit-msg] +ci: + autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks" + autofix_prs: true + autoupdate_branch: master + autoupdate_schedule: monthly + autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks" +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.276 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + stages: [commit] + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + stages: [commit] + + - repo: https://ghproxy.com/github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + stages: [commit] diff --git a/README.md b/README.md index cab78d5..db1a12c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,71 @@ -# adapter-satori -NoneBot2 Satori 适配器 / Satori Protocol adapter for nonebot2 +
+ +# NoneBot-Adapter-Satori + +_✨ NoneBot2 Satori Protocol适配器 / Satori Protocol Adapter for NoneBot2 ✨_ + +
+ +## 协议介绍 + +[Satori Protocol](https://satori.js.org/zh-CN/) + +## 协议端 + +目前提供了 `satori` 协议实现的有: +- [Chronocat](https://chronocat.vercel.app) +- Satori +- Koishi (搭配 `@koishijs/plugin-server`) + +## 配置 + +修改 NoneBot 配置文件 `.env` 或者 `.env.*`。 + +### Driver + +参考 [driver](https://nonebot.dev/docs/appendices/config#driver) 配置项,添加 `ForwardDriver` 支持。 + +如: + +```dotenv +DRIVER=~httpx+~websockets +DRIVER=~aiohttp +``` + +### SATORI_CLIENTS + +配置连接配置,如: + +```dotenv +SATORI_CLIENTS=' +[ + { + "port": "localhost", + "host": "5500", + "token": "xxx" + } +] +' +``` + +`host` 与 `port` 为 Satori 服务端的监听地址与端口, + +`token` 由 Satori 服务端决定是否需要。 + + +## 示例 + +```python +from nonebot import on_command +from nonebot.adapters.satori import Bot +from nonebot.adapters.satori.event import MessageEvent +from nonebot.adapters.satori.message import MessageSegment + + +matcher = on_command("test") + +@matcher.handle() +async def handle_receive(bot: Bot, event: MessageEvent): + if event.is_private: + await bot.send(event, MessageSegment.text("Hello, world!")) +``` diff --git a/nonebot/adapters/satori/__init__.py b/nonebot/adapters/satori/__init__.py new file mode 100644 index 0000000..2bdb6f9 --- /dev/null +++ b/nonebot/adapters/satori/__init__.py @@ -0,0 +1,5 @@ +from .bot import Bot as Bot +from .adapter import Adapter as Adapter +from .message import Message as Message +from .event import MessageEvent as MessageEvent +from .message import MessageSegment as MessageSegment diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py new file mode 100644 index 0000000..01fbb44 --- /dev/null +++ b/nonebot/adapters/satori/adapter.py @@ -0,0 +1,282 @@ +import asyncio +from typing_extensions import override +from typing import Any, Dict, List, Literal, Optional + +from pydantic import parse_raw_as +from nonebot.utils import escape_tag +from nonebot.exception import WebSocketClosed +from nonebot.drivers import ( + Driver, + Request, + WebSocket, + HTTPClientMixin, + WebSocketClientMixin, +) + +from nonebot.adapters import Adapter as BaseAdapter + +from .bot import Bot +from .utils import API, log +from .config import Config, ClientInfo +from .exception import ApiNotAvailable +from .models import Event as SatoriEvent +from .event import ( + EVENT_CLASSES, + Event, + MessageEvent, + LoginAddedEvent, + LoginRemovedEvent, + LoginUpdatedEvent, +) +from .models import ( + Payload, + PayloadType, + PingPayload, + PongPayload, + EventPayload, + ReadyPayload, + IdentifyPayload, +) + + +class Adapter(BaseAdapter): + @override + def __init__(self, driver: Driver, **kwargs: Any): + super().__init__(driver, **kwargs) + # 读取适配器所需的配置项 + self.satori_config: Config = Config.parse_obj(self.config) + self.tasks: List[asyncio.Task] = [] # 存储 ws 任务 + self.sequences: Dict[str, int] = {} # 存储 连接序列号 + self.setup() + + @classmethod + @override + def get_name(cls) -> str: + """适配器名称""" + return "Satori" + + def setup(self) -> None: + if not isinstance(self.driver, HTTPClientMixin): + # 判断用户配置的Driver类型是否符合适配器要求,不符合时应抛出异常 + raise RuntimeError( + f"Current driver {self.config.driver} " + f"doesn't support http client requests!" + f"{self.get_name()} Adapter need a HTTPClient Driver to work." + ) + if not isinstance(self.driver, WebSocketClientMixin): + raise RuntimeError( + f"Current driver {self.config.driver} does not support " + "websocket client! " + f"{self.get_name()} Adapter need a WebSocketClient Driver to work." + ) + # 在 NoneBot 启动和关闭时进行相关操作 + self.driver.on_startup(self.startup) + self.driver.on_shutdown(self.shutdown) + + async def startup(self) -> None: + """定义启动时的操作,例如和平台建立连接""" + for client in self.satori_config.satori_clients: + self.tasks.append(asyncio.create_task(self.ws(client))) + + async def shutdown(self) -> None: + for task in self.tasks: + if not task.done(): + task.cancel() + + await asyncio.gather( + *(asyncio.wait_for(task, timeout=10) for task in self.tasks), + return_exceptions=True, + ) + + @staticmethod + def payload_to_json(payload: Payload) -> str: + return payload.__config__.json_dumps( + payload.dict(), default=payload.__json_encoder__ + ) + + async def receive_payload(self, info: ClientInfo, ws: WebSocket) -> Payload: + payload = parse_raw_as(PayloadType, await ws.receive()) + if isinstance(payload, EventPayload): + self.sequences[info.identity] = payload.body.id + return payload + + async def _authenticate( + self, info: ClientInfo, ws: WebSocket + ) -> Optional[Literal[True]]: + """鉴权连接""" + payload = IdentifyPayload.parse_obj( + { + "body": { + "token": info.token, + }, + } + ) + if info.identity in self.sequences: + payload.body.sequence = self.sequences[info.identity] + + try: + await ws.send(self.payload_to_json(payload)) + except Exception as e: + log( + "ERROR", + "Error while sending " + + "Identify event", + e, + ) + return + + resp = await self.receive_payload(info, ws) + if not isinstance(resp, ReadyPayload): + log( + "ERROR", + "Received unexpected payload while authenticating: " + f"{escape_tag(repr(resp))}", + ) + return + for login in resp.body.logins: + if login.self_id not in self.bots: + bot = Bot(self, login.self_id, login.platform, info) + self.bot_connect(bot) + log( + "INFO", + f"Bot {escape_tag(bot.self_id)} connected", + ) + else: + bot = self.bots[login.self_id] + bot.on_ready(login.user) + + return True + + async def _heartbeat(self, info: ClientInfo, ws: WebSocket): + """心跳""" + while True: + log("TRACE", f"Heartbeat {self.sequences[info.identity]}") + payload = PingPayload.parse_obj({}) + try: + await ws.send(self.payload_to_json(payload)) + except Exception as e: + log("WARNING", "Error while sending heartbeat, Ignored!", e) + await asyncio.sleep(9) + + async def ws(self, info: ClientInfo) -> None: + ws_url = f"ws://{info.host}:{info.port}/v1/events" + req = Request("GET", ws_url, timeout=60.0) + heartbeat_task: Optional["asyncio.Task"] = None + while True: + try: + async with self.websocket(req) as ws: + log( + "DEBUG", + f"WebSocket Connection to " + f"{escape_tag(str(ws_url))} established", + ) + try: + if not await self._authenticate(info, ws): + await asyncio.sleep(3) + continue + heartbeat_task = asyncio.create_task(self._heartbeat(info, ws)) + await self._loop(info, ws) + except WebSocketClosed as e: + log( + "ERROR", + "WebSocket Closed", + e, + ) + except Exception as e: + log( + "ERROR", + "Error while process data from websocket " + f"{escape_tag(str(ws_url))}. " + f"Trying to reconnect...", + e, + ) + finally: + if heartbeat_task: + heartbeat_task.cancel() + heartbeat_task = None + bots = list(self.bots.values()) + for bot in bots: + self.bot_disconnect(bot) + bots.clear() + except Exception as e: + log( + "ERROR", + ( + "" + "Error while setup websocket to " + f"{escape_tag(str(ws_url))}. Trying to reconnect..." + "" + ), + e, + ) + await asyncio.sleep(3) # 重连间隔 + + async def _loop(self, info: ClientInfo, ws: WebSocket): + while True: + payload = await self.receive_payload(info, ws) + log( + "TRACE", + f"Received payload: {escape_tag(repr(payload))}", + ) + if isinstance(payload, EventPayload): + try: + event = self.payload_to_event(payload.body) + except Exception as e: + log( + "WARNING", + f"Failed to parse event {escape_tag(repr(payload))}", + e, + ) + else: + if isinstance(event, LoginAddedEvent): + bot = Bot(self, event.self_id, event.platform, info) + bot.on_ready(event.user) + self.bot_connect(bot) + log( + "INFO", + f"Bot {escape_tag(bot.self_id)} connected", + ) + elif isinstance(event, LoginRemovedEvent): + self.bot_disconnect(self.bots[event.self_id]) + log( + "INFO", + f"Bot {escape_tag(event.self_id)} disconnected", + ) + continue + elif isinstance(event, LoginUpdatedEvent): + self.bots[event.self_id].on_ready(event.user) + if not (bot := self.bots.get(event.self_id)): + log( + "WARNING", + f"Received event for unknown bot " + f"{escape_tag(event.self_id)}", + ) + continue + if isinstance(event, MessageEvent): + event = event.convert() + asyncio.create_task(bot.handle_event(event)) + elif isinstance(payload, PongPayload): + log("TRACE", "Pong") + continue + else: + log( + "WARNING", + f"Unknown payload from server: {escape_tag(repr(payload))}", + ) + + @staticmethod + def payload_to_event(payload: SatoriEvent) -> Event: + EventClass = EVENT_CLASSES.get(payload.type, None) + if EventClass is None: + log("WARNING", f"Unknown payload type: {payload.type}") + event = Event.parse_obj(payload) + return event + return EventClass.parse_obj(payload) + + @override + async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any: + log("DEBUG", f"Bot {bot.self_id} calling API {api}") + api_handler: Optional[API] = getattr(bot.__class__, api, None) + if api_handler is None: + raise ApiNotAvailable + return await api_handler(bot, **data) diff --git a/nonebot/adapters/satori/bot.py b/nonebot/adapters/satori/bot.py new file mode 100644 index 0000000..dd665aa --- /dev/null +++ b/nonebot/adapters/satori/bot.py @@ -0,0 +1,544 @@ +import re +import json +from typing_extensions import override +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional + +from nonebot.message import handle_event +from nonebot.drivers import Request, Response + +from nonebot.adapters import Bot as BaseBot + +from .utils import API, log +from .config import ClientInfo +from .event import Event, MessageEvent +from .message import Message, MessageSegment +from .models import InnerMessage as SatoriMessage +from .models import Role, User, Guild, Channel, OuterLogin, OuterMember +from .exception import ( + ActionFailed, + NetworkError, + NotFoundException, + ForbiddenException, + BadRequestException, + UnauthorizedException, + MethodNotAllowedException, +) + +if TYPE_CHECKING: + from .adapter import Adapter + + +def _check_reply( + bot: "Bot", + event: MessageEvent, +) -> None: + """检查消息中存在的回复,赋值 `event.reply`, `event.to_me`。 + + 参数: + bot: Bot 对象 + event: MessageEvent 对象 + """ + message = event.get_message() + try: + index = message.index("quote") + except ValueError: + return + + msg_seg = message[index] + + event.reply = msg_seg # type: ignore + + del message[index] + if ( + len(message) > index + and message[index].type == "at" + and message[index].data.get("id") == str(bot.self_info.id) + ): + del message[index] + if len(message) > index and message[index].type == "text": + message[index].data["text"] = message[index].data["text"].lstrip() + if not message[index].data["text"]: + del message[index] + if not message: + message.append(MessageSegment.text("")) + + +def _check_at_me( + bot: "Bot", + event: MessageEvent, +): + def _is_at_me_seg(segment: MessageSegment) -> bool: + return segment.type == "at" and segment.data.get("id") == str(bot.self_info.id) + + message = event.get_message() + + # ensure message is not empty + if not message: + message.append(MessageSegment.text("")) + + deleted = False + if _is_at_me_seg(message[0]): + message.pop(0) + deleted = True + if message and message[0].type == "text": + message[0].data["text"] = message[0].data["text"].lstrip("\xa0").lstrip() + if not message[0].data["text"]: + del message[0] + + if not deleted: + # check the last segment + i = -1 + last_msg_seg = message[i] + if ( + last_msg_seg.type == "text" + and not last_msg_seg.data["text"].strip() + and len(message) >= 2 + ): + i -= 1 + last_msg_seg = message[i] + + if _is_at_me_seg(last_msg_seg): + deleted = True + del message[i:] + + if not message: + message.append(MessageSegment.text("")) + + +def _check_nickname(bot: "Bot", event: MessageEvent) -> None: + """检查消息开头是否存在昵称,去除并赋值 `event.to_me`。 + + 参数: + bot: Bot 对象 + event: MessageEvent 对象 + """ + message = event.get_message() + first_msg_seg = message[0] + if first_msg_seg.type != "text": + return + + nicknames = {re.escape(n) for n in bot.config.nickname} + if not nicknames: + return + + # check if the user is calling me with my nickname + nickname_regex = "|".join(nicknames) + first_text = first_msg_seg.data["text"] + if m := re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, re.IGNORECASE): + log("DEBUG", f"User is calling me {m[1]}") + event.to_me = True + first_msg_seg.data["text"] = first_text[m.end() :] + + +class Bot(BaseBot): + adapter: "Adapter" + + @override + def __init__( + self, adapter: "Adapter", self_id: str, platform: str, info: ClientInfo + ): + super().__init__(adapter, self_id) + + # Bot 配置信息 + self.info: ClientInfo = info + # Bot 自身所属平台 + self.platform: str = platform + # Bot 自身信息 + self._self_info: Optional[User] = None + + def __getattr__(self, item): + raise AttributeError(f"'Bot' object has no attribute '{item}'") + + @property + def ready(self) -> bool: + """Bot 是否已连接""" + return self._self_info is not None + + @property + def self_info(self) -> User: + """Bot 自身信息,仅当 Bot 连接鉴权完成后可用""" + if self._self_info is None: + raise RuntimeError( + f"Bot {self.self_id} of {self.platform} is not connected!" + ) + return self._self_info + + def on_ready(self, user: User) -> None: + self._self_info = user + + def get_authorization_header(self) -> Dict[str, str]: + """获取当前 Bot 的鉴权信息""" + return { + "Authorization": f"Bearer {self.info.token}", + "X-Self-ID": self.self_id, + "X-Platform": self.platform, + } + + async def handle_event(self, event: Event) -> None: + if isinstance(event, MessageEvent): + _check_reply(self, event) + _check_at_me(self, event) + _check_nickname(self, event) + await handle_event(self, event) + + def _handle_response(self, response: Response) -> Any: + if 200 <= response.status_code < 300: + return response.content and json.loads(response.content) + elif response.status_code == 400: + raise BadRequestException(response) + elif response.status_code == 401: + raise UnauthorizedException(response) + elif response.status_code == 403: + raise ForbiddenException(response) + elif response.status_code == 404: + raise NotFoundException(response) + elif response.status_code == 405: + raise MethodNotAllowedException(response) + else: + raise ActionFailed(response) + + async def _request(self, request: Request) -> Any: + request.headers.update(self.get_authorization_header()) + + try: + response = await self.adapter.request(request) + except Exception as e: + raise NetworkError("API request failed") from e + + return self._handle_response(response) + + @override + async def send( + self, + event: Event, + message: Union[str, Message, MessageSegment], + **kwargs, + ) -> List[SatoriMessage]: + if not event.channel: + raise RuntimeError("Event cannot be replied to!") + return await self.send_message(event.channel.id, message) + + async def send_message( + self, + channel_id: str, + message: Union[str, Message, MessageSegment], + ): + """发送消息 + + 参数: + channel_id: 要发送的频道 ID + message: 要发送的消息 + """ + return await self.message_create(channel_id=channel_id, content=str(message)) + + async def update_message( + self, + channel_id: str, + message_id: str, + message: Union[str, Message, MessageSegment], + ): + """更新消息 + + 参数: + channel_id: 要更新的频道 ID + message_id: 要更新的消息 ID + message: 要更新的消息 + """ + return await self.message_update( + channel_id=channel_id, message_id=message_id, content=str(message) + ) + + @API + async def message_create( + self, + *, + channel_id: str, + content: str, + ): + request = Request( + "POST", + self.info.api_base / "message.create", + json={"channel_id": channel_id, "content": content}, + ) + res = await self._request(request) + return [SatoriMessage.parse_obj(i) for i in res] + + @API + async def message_get(self, *, channel_id: str, message_id: str): + request = Request( + "POST", + self.info.api_base / "message.get", + json={"channel_id": channel_id, "message_id": message_id}, + ) + res = await self._request(request) + return SatoriMessage.parse_obj(res) + + @API + async def message_delete(self, *, channel_id: str, message_id: str): + request = Request( + "POST", + self.info.api_base / "message.delete", + json={"channel_id": channel_id, "message_id": message_id}, + ) + await self._request(request) + + @API + async def message_update( + self, + *, + channel_id: str, + message_id: str, + content: str, + ): + request = Request( + "POST", + self.info.api_base / "message.update", + json={ + "channel_id": channel_id, + "message_id": message_id, + "content": content, + }, + ) + await self._request(request) + + @API + async def message_list(self, *, channel_id: str, next_token: Optional[str] = None): + request = Request( + "POST", + self.info.api_base / "message.list", + json={"channel_id": channel_id, "next": next_token}, + ) + res = await self._request(request) + return [SatoriMessage.parse_obj(i) for i in res] + + @API + async def channel_get(self, *, channel_id: str): + request = Request( + "POST", + self.info.api_base / "channel.get", + json={"channel_id": channel_id}, + ) + res = await self._request(request) + return Channel.parse_obj(res) + + @API + async def channel_list(self, *, guild_id: str, next_token: Optional[str] = None): + request = Request( + "POST", + self.info.api_base / "channel.list", + json={"guild_id": guild_id, "next": next_token}, + ) + res = await self._request(request) + return [Channel.parse_obj(i) for i in res] + + @API + async def channel_create(self, *, guild_id: str, data: Channel): + request = Request( + "POST", + self.info.api_base / "channel.create", + json={"guild_id": guild_id, "data": data.dict()}, + ) + return Channel.parse_obj(await self._request(request)) + + @API + async def channel_update( + self, + *, + channel_id: str, + data: Channel, + ): + request = Request( + "POST", + self.info.api_base / "channel.update", + json={"channel_id": channel_id, "data": data.dict()}, + ) + await self._request(request) + + @API + async def channel_delete(self, *, channel_id: str): + request = Request( + "POST", + self.info.api_base / "channel.delete", + json={"channel_id": channel_id}, + ) + await self._request(request) + + @API + async def user_channel_create(self, *, user_id: str): + request = Request( + "POST", + self.info.api_base / "user.channel.create", + json={"user_id": user_id}, + ) + return Channel.parse_obj(await self._request(request)) + + @API + async def guild_get(self, *, guild_id: str): + request = Request( + "POST", + self.info.api_base / "guild.get", + json={"guild_id": guild_id}, + ) + return Guild.parse_obj(await self._request(request)) + + @API + async def guild_list(self, *, next_token: Optional[str] = None): + request = Request( + "POST", + self.info.api_base / "guild.list", + json={"next": next_token}, + ) + return [Guild.parse_obj(i) for i in await self._request(request)] + + @API + async def guild_approve(self, *, request_id: str, approve: bool, comment: str): + request = Request( + "POST", + self.info.api_base / "guild.approve", + json={"message_id": request_id, "approve": approve, "comment": comment}, + ) + await self._request(request) + + @API + async def guild_member_list( + self, *, guild_id: str, next_token: Optional[str] = None + ): + request = Request( + "POST", + self.info.api_base / "guild.member.list", + json={"guild_id": guild_id, "next": next_token}, + ) + return [OuterMember.parse_obj(i) for i in await self._request(request)] + + @API + async def guild_member_get(self, *, guild_id: str, user_id: str): + request = Request( + "POST", + self.info.api_base / "guild.member.get", + json={"guild_id": guild_id, "user_id": user_id}, + ) + return OuterMember.parse_obj(await self._request(request)) + + @API + async def guild_member_kick( + self, *, guild_id: str, user_id: str, permanent: bool = False + ): + request = Request( + "POST", + self.info.api_base / "guild.member.kick", + json={"guild_id": guild_id, "user_id": user_id, "permanent": permanent}, + ) + await self._request(request) + + @API + async def guild_member_approve( + self, *, request_id: str, approve: bool, comment: str + ): + request = Request( + "POST", + self.info.api_base / "guild.member.approve", + json={"message_id": request_id, "approve": approve, "comment": comment}, + ) + await self._request(request) + + @API + async def guild_member_role_set(self, *, guild_id: str, user_id: str, role_id: str): + request = Request( + "POST", + self.info.api_base / "guild.member.role.set", + json={"guild_id": guild_id, "user_id": user_id, "role_id": role_id}, + ) + await self._request(request) + + @API + async def guild_member_role_unset( + self, *, guild_id: str, user_id: str, role_id: str + ): + request = Request( + "POST", + self.info.api_base / "guild.member.role.unset", + json={"guild_id": guild_id, "user_id": user_id, "role_id": role_id}, + ) + await self._request(request) + + @API + async def guild_role_list(self, guild_id: str, next_token: Optional[str] = None): + request = Request( + "POST", + self.info.api_base / "guild.role.list", + json={"guild_id": guild_id, "next": next_token}, + ) + return [Role.parse_obj(i) for i in await self._request(request)] + + @API + async def guild_role_create( + self, + *, + guild_id: str, + role: Role, + ): + request = Request( + "POST", + self.info.api_base / "guild.role.create", + json={"guild_id": guild_id, "role": role.dict()}, + ) + return Role.parse_obj(await self._request(request)) + + @API + async def guild_role_update( + self, + *, + guild_id: str, + role_id: str, + role: Role, + ): + request = Request( + "POST", + self.info.api_base / "guild.role.update", + json={"guild_id": guild_id, "role_id": role_id, "role": role.dict()}, + ) + await self._request(request) + + @API + async def guild_role_delete(self, *, guild_id: str, role_id: str): + request = Request( + "POST", + self.info.api_base / "guild.role.delete", + json={"guild_id": guild_id, "role_id": role_id}, + ) + await self._request(request) + + @API + async def login_get(self): + request = Request( + "POST", + self.info.api_base / "login.get", + ) + return OuterLogin.parse_obj(await self._request(request)) + + @API + async def user_get(self, *, user_id: str): + request = Request( + "POST", + self.info.api_base / "user.get", + json={"user_id": user_id}, + ) + return User.parse_obj(await self._request(request)) + + @API + async def friend_list(self, *, next_token: Optional[str] = None): + request = Request( + "POST", + self.info.api_base / "friend.list", + json={"next": next_token}, + ) + return [User.parse_obj(i) for i in await self._request(request)] + + @API + async def friend_approve(self, *, request_id: str, approve: bool, comment: str): + request = Request( + "POST", + self.info.api_base / "friend.approve", + json={"message_id": request_id, "approve": approve, "comment": comment}, + ) + await self._request(request) diff --git a/nonebot/adapters/satori/config.py b/nonebot/adapters/satori/config.py new file mode 100644 index 0000000..56038da --- /dev/null +++ b/nonebot/adapters/satori/config.py @@ -0,0 +1,27 @@ +from typing import List, Optional + +from yarl import URL +from pydantic import Extra, Field, BaseModel + + +class ClientInfo(BaseModel): + host: str = "localhost" + port: int + token: Optional[str] = None + + @property + def identity(self): + return f"{self.host}:{self.port}#{self.token}" + + @property + def api_base(self): + return URL(f"http://{self.host}:{self.port}") / "v1" + + @property + def ws_base(self): + return URL(f"ws://{self.host}:{self.port}") / "v1" + + +class Config(BaseModel, extra=Extra.ignore): + satori_clients: List[ClientInfo] = Field(default_factory=list) + """client 配置""" diff --git a/nonebot/adapters/satori/event.py b/nonebot/adapters/satori/event.py new file mode 100644 index 0000000..754dd4f --- /dev/null +++ b/nonebot/adapters/satori/event.py @@ -0,0 +1,421 @@ +from enum import Enum +from copy import deepcopy +from typing_extensions import override +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Optional + +from pydantic import root_validator +from nonebot.utils import escape_tag + +from nonebot.adapters import Event as BaseEvent + +from .models import Role, User +from .models import Event as SatoriEvent +from .message import Message, RenderMessage +from .models import InnerMessage as SatoriMessage +from .models import Guild, Channel, InnerLogin, ChannelType, InnerMember + +E = TypeVar("E", bound="Event") + + +class EventType(str, Enum): + FRIEND_REQUEST = "friend-request" + GUILD_ADDED = "guild-added" + GUILD_MEMBER_ADDED = "guild-member-added" + GUILD_MEMBER_REMOVED = "guild-member-removed" + GUILD_MEMBER_REQUEST = "guild-member-request" + GUILD_MEMBER_UPDATED = "guild-member-updated" + GUILD_REMOVED = "guild-removed" + GUILD_REQUEST = "guild-request" + GUILD_ROLE_CREATED = "guild-role-created" + GUILD_ROLE_DELETED = "guild-role-deleted" + GUILD_ROLE_UPDATED = "guild-role-updated" + GUILD_UPDATED = "guild-updated" + LOGIN_ADDED = "login-added" + LOGIN_REMOVED = "login-removed" + LOGIN_UPDATED = "login-updated" + MESSAGE_CREATED = "message-created" + MESSAGE_DELETED = "message-deleted" + MESSAGE_UPDATED = "message-updated" + REACTION_ADDED = "reaction-added" + REACTION_REMOVED = "reaction-removed" + + +class Event(BaseEvent, SatoriEvent): + __type__: EventType + + @override + def get_type(self) -> str: + return "" + + @override + def get_event_name(self) -> str: + return self.type + + @override + def get_event_description(self) -> str: + return escape_tag(str(self.dict())) + + @override + def get_message(self) -> Message: + raise ValueError("Event has no message!") + + @override + def get_user_id(self) -> str: + raise ValueError("Event has no context!") + + @override + def get_session_id(self) -> str: + raise ValueError("Event has no context!") + + @override + def is_tome(self) -> bool: + return False + + +EVENT_CLASSES: Dict[str, Type[Event]] = {} + + +def register_event_class(event_class: Type[E]) -> Type[E]: + EVENT_CLASSES[event_class.__type__.value] = event_class + return event_class + + +class NoticeEvent(Event): + @override + def get_type(self) -> str: + return "notice" + + +class FriendEvent(NoticeEvent): + channel: Channel + user: User + + @override + def get_user_id(self) -> str: + return self.user.id + + @override + def get_session_id(self) -> str: + return self.user.id + + +@register_event_class +class FriendRequestEvent(FriendEvent): + __type__ = EventType.FRIEND_REQUEST + + +class GuildEvent(NoticeEvent): + channel: Channel + guild: Guild + + @override + def get_session_id(self) -> str: + return f"{self.guild.id}/{self.channel.id}" + + +@register_event_class +class GuildAddedEvent(GuildEvent): + __type__ = EventType.GUILD_ADDED + + +@register_event_class +class GuildRemovedEvent(GuildEvent): + __type__ = EventType.GUILD_REMOVED + + +@register_event_class +class GuildRequestEvent(GuildEvent): + __type__ = EventType.GUILD_REQUEST + + +@register_event_class +class GuildUpdatedEvent(GuildEvent): + __type__ = EventType.GUILD_UPDATED + + +class GuildInnerMemberEvent(GuildEvent): + user: User + + @override + def get_user_id(self) -> str: + return self.member.user.id if self.member else self.user.id + + @override + def get_session_id(self) -> str: + return f"{self.guild.id}/{self.channel.id}/{self.get_user_id()}" + + +@register_event_class +class GuildInnerMemberAddedEvent(GuildInnerMemberEvent): + __type__ = EventType.GUILD_MEMBER_ADDED + + +@register_event_class +class GuildInnerMemberRemovedEvent(GuildInnerMemberEvent): + member: InnerMember + __type__ = EventType.GUILD_MEMBER_REMOVED + + +@register_event_class +class GuildInnerMemberRequestEvent(GuildInnerMemberEvent): + __type__ = EventType.GUILD_MEMBER_REQUEST + + +@register_event_class +class GuildInnerMemberUpdatedEvent(GuildInnerMemberEvent): + member: InnerMember + __type__ = EventType.GUILD_MEMBER_UPDATED + + +class GuildRoleEvent(GuildEvent): + role: Role + + @override + def get_session_id(self) -> str: + return f"{self.guild.id}/{self.channel.id}/{self.role.id}" + + +@register_event_class +class GuildRoleCreatedEvent(GuildRoleEvent): + __type__ = EventType.GUILD_ROLE_CREATED + + +@register_event_class +class GuildRoleDeletedEvent(GuildRoleEvent): + __type__ = EventType.GUILD_ROLE_DELETED + + +@register_event_class +class GuildRoleUpdatedEvent(GuildRoleEvent): + __type__ = EventType.GUILD_ROLE_UPDATED + + +class LoginEvent(NoticeEvent): + login: InnerLogin + user: User + + @override + def get_user_id(self) -> str: + return self.user.id + + @override + def get_session_id(self) -> str: + return self.user.id + + +@register_event_class +class LoginAddedEvent(LoginEvent): + __type__ = EventType.LOGIN_ADDED + + +@register_event_class +class LoginRemovedEvent(LoginEvent): + __type__ = EventType.LOGIN_REMOVED + + +@register_event_class +class LoginUpdatedEvent(LoginEvent): + __type__ = EventType.LOGIN_UPDATED + + +class MessageEvent(Event): + channel: Channel + user: User + message: SatoriMessage + to_me: bool = False + reply: Optional[RenderMessage] = None + + if TYPE_CHECKING: + _message: Message + original_message: Message + + @override + def get_type(self) -> str: + return "message" + + @override + def is_tome(self) -> bool: + return self.to_me + + @override + def get_message(self) -> Message: + return self._message + + @root_validator + def generate_message(cls, values: Dict[str, Any]) -> Dict[str, Any]: + values["_message"] = Message.from_satori_element(values["message"].content) + values["original_message"] = deepcopy(values["_message"]) + return values + + @property + def msg_id(self) -> str: + return self.message.id + + def convert(self) -> "MessageEvent": + raise NotImplementedError + + +@register_event_class +class MessageCreatedEvent(MessageEvent): + __type__ = EventType.MESSAGE_CREATED + + def convert(self): + if self.channel.type == ChannelType.DIRECT: + return PrivateMessageCreatedEvent.parse_obj(self) + else: + return PublicMessageCreatedEvent.parse_obj(self) + + +@register_event_class +class MessageDeletedEvent(MessageEvent): + __type__ = EventType.MESSAGE_DELETED + + def convert(self): + if self.channel.type == ChannelType.DIRECT: + return PrivateMessageDeletedEvent.parse_obj(self) + else: + return PublicMessageDeletedEvent.parse_obj(self) + + +@register_event_class +class MessageUpdatedEvent(MessageEvent): + __type__ = EventType.MESSAGE_UPDATED + + def convert(self): + if self.channel.type == ChannelType.DIRECT: + return PrivateMessageUpdatedEvent.parse_obj(self) + else: + return PublicMessageUpdatedEvent.parse_obj(self) + + +class PrivateMessageEvent(MessageEvent): + @override + def is_tome(self) -> bool: + return True + + @override + def get_session_id(self) -> str: + return self.channel.id + + @override + def get_user_id(self) -> str: + return self.user.id + + +class PublicMessageEvent(MessageEvent): + guild: Guild + member: InnerMember + + @override + def get_session_id(self) -> str: + return f"{self.guild.id}/{self.channel.id}/{self.user.id}" + + @override + def get_user_id(self) -> str: + return self.user.id + + +class PrivateMessageCreatedEvent(MessageCreatedEvent, PrivateMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name}({self.channel.id}): {self.get_message()!r}" + ) + + +class PublicMessageCreatedEvent(MessageCreatedEvent, PublicMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name or ''}({self.channel.id})" + f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}]" + f": {self.get_message()!r}" + ) + + +class PrivateMessageDeletedEvent(MessageDeletedEvent, PrivateMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name}({self.channel.id}) deleted" + ) + + +class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name or ''}({self.channel.id})" + f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}] deleted" + ) + + +class PrivateMessageUpdatedEvent(MessageUpdatedEvent, PrivateMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name}({self.channel.id}) updated" + f": {self.get_message()!r}" + ) + + +class PublicMessageUpdatedEvent(MessageUpdatedEvent, PublicMessageEvent): + @override + def get_event_description(self) -> str: + return escape_tag( + f"Message {self.msg_id} from " + f"{self.user.name or ''}({self.channel.id})" + f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}] updated" + f": {self.get_message()!r}" + ) + + +class ReactionEvent(NoticeEvent): + channel: Channel + user: User + message: SatoriMessage + + if TYPE_CHECKING: + _message: Message + + @override + def get_user_id(self) -> str: + return self.user.id + + @override + def get_session_id(self) -> str: + return f"{self.channel.id}/{self.user.id}" + + @root_validator + def generate_message(cls, values: Dict[str, Any]) -> Dict[str, Any]: + values["_message"] = Message.from_satori_element(values["message"]["content"]) + return values + + @property + def msg_id(self) -> str: + return self.message.id + + +@register_event_class +class ReactionAddedEvent(ReactionEvent): + __type__ = EventType.REACTION_ADDED + + @override + def get_event_description(self) -> str: + return escape_tag( + f"Reaction added to {self.msg_id} by {self.user.name}({self.channel.id})" + ) + + +@register_event_class +class ReactionRemovedEvent(ReactionEvent): + __type__ = EventType.REACTION_REMOVED + + @override + def get_event_description(self) -> str: + return escape_tag(f"Reaction removed from {self.msg_id}") diff --git a/nonebot/adapters/satori/exception.py b/nonebot/adapters/satori/exception.py new file mode 100644 index 0000000..5d6e9aa --- /dev/null +++ b/nonebot/adapters/satori/exception.py @@ -0,0 +1,75 @@ +import json +from typing import Optional + +from nonebot.drivers import Response +from nonebot.exception import AdapterException +from nonebot.exception import ActionFailed as BaseActionFailed +from nonebot.exception import NetworkError as BaseNetworkError +from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable + + +class SatoriAdapterException(AdapterException): + def __init__(self): + super().__init__("satori") + + +class ActionFailed(BaseActionFailed, SatoriAdapterException): + def __init__(self, response: Response): + self.status_code: int = response.status_code + self.code: Optional[int] = None + self.message: Optional[str] = None + self.data: Optional[dict] = None + if response.content: + body = json.loads(response.content) + self._prepare_body(body) + + def __repr__(self) -> str: + return ( + f"" + ) + + def __str__(self): + return self.__repr__() + + def _prepare_body(self, body: dict): + self.code = body.get("code", None) + self.message = body.get("message", None) + self.data = body.get("data", None) + + +class BadRequestException(ActionFailed): + pass + + +class UnauthorizedException(ActionFailed): + pass + + +class ForbiddenException(ActionFailed): + pass + + +class NotFoundException(ActionFailed): + pass + + +class MethodNotAllowedException(ActionFailed): + pass + + +class NetworkError(BaseNetworkError, SatoriAdapterException): + def __init__(self, msg: Optional[str] = None): + super().__init__() + self.msg: Optional[str] = msg + """错误原因""" + + def __repr__(self): + return f"" + + def __str__(self): + return self.__repr__() + + +class ApiNotAvailable(BaseApiNotAvailable, SatoriAdapterException): + pass diff --git a/nonebot/adapters/satori/message.py b/nonebot/adapters/satori/message.py new file mode 100644 index 0000000..2d95730 --- /dev/null +++ b/nonebot/adapters/satori/message.py @@ -0,0 +1,422 @@ +from dataclasses import field, dataclass +from typing_extensions import NotRequired, override +from typing import Any, List, Type, Union, Iterable, Optional, TypedDict + +from nonebot.adapters import Message as BaseMessage +from nonebot.adapters import MessageSegment as BaseMessageSegment + +from .utils import Element, parse, escape + + +class MessageSegment(BaseMessageSegment["Message"]): + def __str__(self) -> str: + def _attr(key: str, value: Any): + if value is True: + return key + if value is False: + return f"no-{key}" + if isinstance(value, (int, float)): + return f"{key}={value}" + return f'{key}="{escape(str(value))}"' + + attrs = " ".join(_attr(k, v) for k, v in self.data.items()) + return f"<{self.type} {attrs} />" + + @classmethod + @override + def get_message_class(cls) -> Type["Message"]: + return Message + + @staticmethod + def text(content: str) -> "Text": + return Text("text", {"text": content}) + + @staticmethod + def entity(content: str, style: str) -> "Entity": + return Entity("entity", {"text": content, "style": style}) + + @staticmethod + def at( + user_id: str, + name: Optional[str] = None, + ) -> "At": + data = {"id": user_id} + if name: + data["name"] = name + return At("at", data) + + @staticmethod + def at_role( + role: str, + name: Optional[str] = None, + ) -> "At": + data = {"role": role} + if name: + data["name"] = name + return At("at_role", data) + + @staticmethod + def at_all(here: bool = False) -> "At": + return At("at", {"type": "here" if here else "all"}) + + @staticmethod + def sharp(channel_id: str, name: Optional[str] = None) -> "Sharp": + data = {"id": channel_id} + if name: + data["name"] = name + return Sharp("sharp", data) + + @staticmethod + def link(href: str) -> "Link": + return Link("link", {"href": href}) + + @staticmethod + def image( + src: str, cache: Optional[bool] = None, timeout: Optional[str] = None + ) -> "Image": + data = {"src": src} + if cache is not None: + data["cache"] = cache + if timeout is not None: + data["timeout"] = timeout + return Image("img", data) + + @staticmethod + def audio( + src: str, cache: Optional[bool] = None, timeout: Optional[str] = None + ) -> "Audio": + data = {"src": src} + if cache is not None: + data["cache"] = cache + if timeout is not None: + data["timeout"] = timeout + return Audio("audio", data) + + @staticmethod + def video( + src: str, cache: Optional[bool] = None, timeout: Optional[str] = None + ) -> "Video": + data = {"src": src} + if cache is not None: + data["cache"] = cache + if timeout is not None: + data["timeout"] = timeout + return Video("video", data) + + @staticmethod + def file( + src: str, cache: Optional[bool] = None, timeout: Optional[str] = None + ) -> "File": + data = {"src": src} + if cache is not None: + data["cache"] = cache + if timeout is not None: + data["timeout"] = timeout + return File("file", data) + + @staticmethod + def br() -> "Br": + return Br("br", {}) + + @staticmethod + def paragraph(text: str) -> "Paragraph": + return Paragraph("paragraph", {"text": text}) + + @staticmethod + def message( + mid: Optional[str] = None, + forward: Optional[bool] = None, + content: Optional["Message"] = None, + ) -> "RenderMessage": + data = {} + if mid: + data["id"] = mid + if forward is not None: + data["forward"] = forward + if content: + data["content"] = content + return RenderMessage("message", data) + + @staticmethod + def quote( + mid: str, + forward: Optional[bool] = None, + content: Optional["Message"] = None, + ) -> "RenderMessage": + data = {"id": mid} + if forward is not None: + data["forward"] = forward + if content: + data["content"] = content + return RenderMessage("quote", data) + + @staticmethod + def author( + user_id: str, + nickname: Optional[str] = None, + avatar: Optional[str] = None, + ) -> "Author": + data = {"id": user_id} + if nickname: + data["nickname"] = nickname + if avatar: + data["avatar"] = avatar + return Author("author", data) + + @override + def is_text(self) -> bool: + return self.type == "text" + + +class TextData(TypedDict): + text: str + + +@dataclass +class Text(MessageSegment): + data: TextData = field(default_factory=dict) + + @override + def __str__(self) -> str: + return escape(self.data["text"]) + + +class EntityData(TypedDict): + text: str + style: str + + +@dataclass +class Entity(MessageSegment): + data: EntityData = field(default_factory=dict) + + @override + def __str__(self) -> str: + style = self.data["style"] + return f'<{style}>{escape(self.data["text"])}' + + +class AtData(TypedDict): + id: NotRequired[str] + name: NotRequired[str] + role: NotRequired[str] + type: NotRequired[str] + + +@dataclass +class At(MessageSegment): + data: AtData = field(default_factory=dict) + + +class SharpData(TypedDict): + id: str + name: NotRequired[str] + + +@dataclass +class Sharp(MessageSegment): + data: SharpData = field(default_factory=dict) + + +class LinkData(TypedDict): + href: str + + +@dataclass +class Link(MessageSegment): + data: LinkData = field(default_factory=dict) + + @override + def __str__(self): + return f'' + + +class ImageData(TypedDict): + src: str + cache: NotRequired[bool] + timeout: NotRequired[str] + width: NotRequired[int] + height: NotRequired[int] + + +@dataclass +class Image(MessageSegment): + data: ImageData = field(default_factory=dict) + + +class AudioData(TypedDict): + src: str + cache: NotRequired[bool] + timeout: NotRequired[str] + + +@dataclass +class Audio(MessageSegment): + data: AudioData = field(default_factory=dict) + + +class VideoData(TypedDict): + src: str + cache: NotRequired[bool] + timeout: NotRequired[str] + + +@dataclass +class Video(MessageSegment): + data: VideoData = field(default_factory=dict) + + +class FileData(TypedDict): + src: str + cache: NotRequired[bool] + timeout: NotRequired[str] + + +@dataclass +class File(MessageSegment): + data: FileData = field(default_factory=dict) + + +@dataclass +class Br(MessageSegment): + @override + def __str__(self): + return "
" + + +class ParagraphData(TypedDict): + text: str + + +@dataclass +class Paragraph(MessageSegment): + data: ParagraphData = field(default_factory=dict) + + @override + def __str__(self): + return f'

{escape(self.data["text"])}

' + + +class RenderMessageData(TypedDict): + id: NotRequired[str] + forward: NotRequired[bool] + content: NotRequired["Message"] + + +@dataclass +class RenderMessage(MessageSegment): + data: RenderMessageData = field(default_factory=dict) + + @override + def __str__(self): + attr = [] + if "id" in self.data: + attr.append(f'id="{escape(self.data["id"])}"') + if self.data.get("forward"): + attr.append("forward") + if "content" not in self.data: + return f'<{self.type} {" ".join(attr)} />' + else: + return f'<{self.type} {" ".join(attr)}>{self.data["content"]}' + + +class AuthorData(TypedDict): + id: str + nickname: NotRequired[str] + avatar: NotRequired[str] + + +@dataclass +class Author(MessageSegment): + data: AuthorData = field(default_factory=dict) + + +ELEMENT_TYPE_MAP = { + "text": (Text, "text"), + "at": (At, "at"), + "sharp": (Sharp, "sharp"), + "a": (Link, "link"), + "link": (Link, "link"), + "img": (Image, "img"), + "image": (Image, "img"), + "audio": (Audio, "audio"), + "video": (Video, "video"), + "file": (File, "file"), + "br": (Br, "br"), + "author": (Author, "author"), +} + + +class Message(BaseMessage[MessageSegment]): + @classmethod + @override + def get_segment_class(cls) -> Type[MessageSegment]: + return MessageSegment + + @override + def __add__( + self, other: Union[str, MessageSegment, Iterable[MessageSegment]] + ) -> "Message": + return super().__add__( + MessageSegment.text(other) if isinstance(other, str) else other + ) + + @override + def __radd__( + self, other: Union[str, MessageSegment, Iterable[MessageSegment]] + ) -> "Message": + return super().__radd__( + MessageSegment.text(other) if isinstance(other, str) else other + ) + + @staticmethod + @override + def _construct(msg: str) -> Iterable[MessageSegment]: + yield from Message.from_satori_element(parse(msg)) + + @classmethod + def from_satori_element(cls, elements: List[Element]) -> "Message": + msg = Message() + for elem in elements: + if elem.type in ELEMENT_TYPE_MAP: + seg_cls, seg_type = ELEMENT_TYPE_MAP[elem.type] + msg.append(seg_cls(seg_type, elem.attrs.copy())) + elif elem.type in { + "b", + "strong", + "i", + "em", + "u", + "ins", + "s", + "del", + "spl", + "code", + "sup", + "sub", + }: + msg.append( + Entity( + "entity", + {"text": elem.children[0].attrs["text"], "style": elem.type}, + ) + ) + elif elem.type in ("p", "paragraph"): + msg.append( + Paragraph("paragraph", {"text": elem.children[0].attrs["text"]}) + ) + elif elem.type in ("message", "quote"): + data = elem.attrs.copy() + if elem.children: + data["content"] = Message.from_satori_element(elem.children) + msg.append(RenderMessage(elem.type, data)) + else: + msg.append(Text("text", {"text": str(elem)})) + return msg + + def extract_content(self) -> str: + return "".join( + str(seg) + for seg in self + if seg.type in ("text", "entity", "at", "sharp", "link") + ) diff --git a/nonebot/adapters/satori/models.py b/nonebot/adapters/satori/models.py new file mode 100644 index 0000000..9123668 --- /dev/null +++ b/nonebot/adapters/satori/models.py @@ -0,0 +1,217 @@ +from enum import IntEnum +from datetime import datetime +from typing import Any, Dict, List, Union, Literal, Optional + +from pydantic import Field, BaseModel, validator + +from .utils import Element, parse + + +class ChannelType(IntEnum): + TEXT = 0 + VOICE = 1 + CATEGORY = 2 + DIRECT = 3 + + +class Channel(BaseModel): + id: str + name: str + type: ChannelType + parent_id: Optional[str] = None + + +class Guild(BaseModel): + id: str + name: str + avatar: Optional[str] = None + + +class User(BaseModel): + id: str + name: Optional[str] = None + avatar: Optional[str] = None + is_bot: Optional[bool] = None + + +class InnerMember(BaseModel): + user: Optional[User] = None + name: Optional[str] = None + avatar: Optional[str] = None + joined_at: Optional[datetime] = None + + @validator("joined_at", pre=True) + def parse_joined_at(cls, v): + if v is None: + return None + if isinstance(v, datetime): + return v + try: + timestamp = int(v) + except ValueError: + raise ValueError(f"invalid timestamp: {v}") + return datetime.fromtimestamp(timestamp / 1000) + + +class OuterMember(InnerMember): + user: User + joined_at: datetime + + +class Role(BaseModel): + id: str + name: str + + +class LoginStatus(IntEnum): + OFFLINE = 0 + ONLINE = 1 + CONNECT = 2 + DISCONNECT = 3 + RECONNECT = 4 + + +class InnerLogin(BaseModel): + user: Optional[User] = None + self_id: Optional[str] = None + platform: Optional[str] = None + status: LoginStatus + + +class OuterLogin(InnerLogin): + user: User + self_id: str + platform: str + + +class Opcode(IntEnum): + EVENT = 0 + PING = 1 + PONG = 2 + IDENTIFY = 3 + READY = 4 + + +class Payload(BaseModel): + op: Opcode = Field(...) + body: Optional[Dict[str, Any]] = Field(None) + + +class Identify(BaseModel): + token: Optional[str] = None + sequence: Optional[int] = None + + +class Ready(BaseModel): + logins: List[OuterLogin] + + +class IdentifyPayload(Payload): + op: Literal[Opcode.IDENTIFY] = Field(Opcode.IDENTIFY) + body: Identify + + +class ReadyPayload(Payload): + op: Literal[Opcode.READY] = Field(Opcode.READY) + body: Ready + + +class PingPayload(Payload): + op: Literal[Opcode.PING] = Field(Opcode.PING) + + +class PongPayload(Payload): + op: Literal[Opcode.PONG] = Field(Opcode.PONG) + + +class InnerMessage(BaseModel): + id: str + content: Optional[List[Element]] = None + channel: Optional[Channel] = None + guild: Optional[Guild] = None + member: Optional[InnerMember] = None + user: Optional[User] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + @validator("content", pre=True) + def parse_content(cls, v): + if isinstance(v, list): + return v + if v is None: + return None + if not isinstance(v, str): + raise ValueError("content must be str") + return parse(v) + + @validator("created_at", pre=True) + def parse_created_at(cls, v): + if v is None: + return None + if isinstance(v, datetime): + return v + try: + timestamp = int(v) + except ValueError: + raise ValueError(f"invalid timestamp: {v}") + return datetime.fromtimestamp(timestamp / 1000) + + @validator("updated_at", pre=True) + def parse_updated_at(cls, v): + if v is None: + return None + if isinstance(v, datetime): + return v + try: + timestamp = int(v) + except ValueError: + raise ValueError(f"invalid timestamp: {v}") + return datetime.fromtimestamp(timestamp / 1000) + + +class OuterMessage(InnerMessage): + channel: Channel + guild: Guild + member: InnerMember + user: User + created_at: datetime + updated_at: datetime + + +class Event(BaseModel): + id: int + type: str + platform: str + self_id: str + timestamp: datetime + channel: Optional[Channel] = None + guild: Optional[Guild] = None + login: Optional[InnerLogin] = None + member: Optional[InnerMember] = None + message: Optional[InnerMessage] = None + operator: Optional[User] = None + role: Optional[Role] = None + user: Optional[User] = None + + @validator("timestamp", pre=True) + def parse_timestamp(cls, v): + if v is None: + return None + if isinstance(v, datetime): + return v + try: + timestamp = int(v) + except ValueError: + raise ValueError(f"invalid timestamp: {v}") + return datetime.fromtimestamp(timestamp / 1000) + + +class EventPayload(Payload): + op: Literal[Opcode.EVENT] = Field(Opcode.EVENT) + body: Event + + +PayloadType = Union[ + Union[EventPayload, PingPayload, PongPayload, IdentifyPayload, ReadyPayload], + Payload, +] diff --git a/nonebot/adapters/satori/utils.py b/nonebot/adapters/satori/utils.py new file mode 100644 index 0000000..3d0bfac --- /dev/null +++ b/nonebot/adapters/satori/utils.py @@ -0,0 +1,186 @@ +import re +from functools import partial +from typing_extensions import ParamSpec, Concatenate +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Union, + Generic, + TypeVar, + Callable, + Optional, + Awaitable, + overload, +) + +from pydantic import Field, BaseModel +from nonebot.utils import logger_wrapper + +if TYPE_CHECKING: + from .bot import Bot + +B = TypeVar("B", bound="Bot") +R = TypeVar("R") +P = ParamSpec("P") +log = logger_wrapper("Satori") + + +def escape(text: str) -> str: + return ( + text.replace('"', """) + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + ) + + +def unescape(text: str) -> str: + return ( + text.replace(""", '"') + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + ) + + +class Element(BaseModel): + type: str + attrs: Dict[str, Any] = Field(default_factory=dict) + children: List["Element"] = Field(default_factory=list) + source: Optional[str] = None + + def __str__(self): + if self.source: + return self.source + if self.type == "text": + return escape(self.attrs["text"]) + + def _attr(key: str, value: Any): + if value is True: + return key + if value is False: + return f"no-{key}" + if isinstance(value, (int, float)): + return f"{key}={value}" + return f'{key}="{escape(str(value))}"' + + attrs = " ".join(_attr(k, v) for k, v in self.attrs.items()) + if not self.children: + return f"<{self.type} {attrs}/>" + children = "".join(str(c) for c in self.children) + return f"<{self.type} {attrs}>{children}" + + +tag_pat = re.compile(r"|<(/?)([^!\s>/]*)([^>]*?)\s*(/?)>") +attr_pat = re.compile(r"([^\s=]+)(?:=\"([^\"]*)\"|='([^']*)')?", re.S) + + +class Token(BaseModel): + type: str + close: str + empty: str + attrs: Dict[str, Any] + source: str + + +def parse(src: str): + tokens: List[Union[Token, Element]] = [] + + def push_text(text: str): + if text: + tokens.append(Element(type="text", attrs={"text": text})) + + def parse_content(source: str): + push_text(unescape(source)) + + while tag_map := tag_pat.search(src): + parse_content(src[: tag_map.start()]) + src = src[tag_map.end() :] + if tag_map.group(0).startswith("