diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md
new file mode 100644
index 0000000..1d0b478
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug.md
@@ -0,0 +1,34 @@
+---
+name: Bug 报告
+about: 有关 bug 的报告
+title: "[Bug]"
+labels: bug, triage
+assignees: ""
+---
+
+## 请确认:
+
+* [ ] 问题的标题明确
+* [ ] 我翻阅过其他的 issue 并且找不到类似的问题
+* [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并仍然认为这是一个Bug
+
+# Bug
+
+## 问题
+
+
+## 如何复现
+
+
+## 预期行为
+
+
+## 使用环境:
+- 操作系统 (Windows/Linux/Mac):
+- Python 版本:
+- Nonebot2 版本:
+- 适配器版本:
+- 使用的 Satori 服务端 (例如 Chronocat):
+
+## 日志/截图
+
diff --git a/.github/ISSUE_TEMPLATE/feature.md b/.github/ISSUE_TEMPLATE/feature.md
new file mode 100644
index 0000000..9be8fcc
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature.md
@@ -0,0 +1,27 @@
+---
+name: Feature 特性请求
+about: 为适配器加份菜
+title: "[Feature] "
+labels: enhancement, triage
+assignees: ""
+---
+
+## 请确认:
+
+* [ ] 新特性的目的明确
+* [ ] 我已经阅读了[相关文档](https://satori.js.org/zh-CN/) 并且找不到类似特性
+
+
+## Feature
+### 概要
+
+
+
+### 是否已有相关实现
+
+暂无
+
+
+### 其他内容
+
+暂无
diff --git a/.github/actions/setup-python/action.yml b/.github/actions/setup-python/action.yml
new file mode 100644
index 0000000..ebf8f36
--- /dev/null
+++ b/.github/actions/setup-python/action.yml
@@ -0,0 +1,21 @@
+name: Setup Python
+description: Setup Python
+
+inputs:
+ python-version:
+ description: Python version
+ required: false
+ default: "3.10"
+
+runs:
+ using: "composite"
+ steps:
+ - uses: pdm-project/setup-pdm@v3
+ name: Setup PDM
+ with:
+ python-version: ${{ inputs.python-version }}
+ architecture: "x64"
+ cache: true
+
+ - run: pdm sync -G:all
+ shell: bash
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..a341b06
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,36 @@
+name: Release
+
+on:
+ push:
+ tags:
+ - v*
+
+jobs:
+ release:
+ runs-on: ubuntu-latest
+ permissions:
+ id-token: write
+ contents: write
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup Python environment
+ uses: ./.github/actions/setup-python
+
+ - name: Get Version
+ id: version
+ run: |
+ echo "VERSION=$(pdm show --version)" >> $GITHUB_OUTPUT
+ echo "TAG_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT
+ echo "TAG_NAME=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+
+ - name: Check Version
+ if: steps.version.outputs.VERSION != steps.version.outputs.TAG_VERSION
+ run: exit 1
+
+ - name: Publish Package
+ run: |
+ pdm publish
+ gh release upload --clobber ${{ steps.version.outputs.TAG_NAME }} dist/*.tar.gz dist/*.whl
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..ec3f9c2
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,17 @@
+name: Ruff Lint
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ ruff:
+ name: Ruff Lint
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Run Ruff Lint
+ uses: chartboost/ruff-action@v1
diff --git a/.gitignore b/.gitignore
index 68bc17f..4b3ab30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -108,6 +108,8 @@ ipython_config.py
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
+.pdm-python
+.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..73fa8ca
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,26 @@
+default_install_hook_types: [pre-commit, prepare-commit-msg]
+ci:
+ autofix_commit_msg: ":rotating_light: auto fix by pre-commit hooks"
+ autofix_prs: true
+ autoupdate_branch: master
+ autoupdate_schedule: monthly
+ autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks"
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.0.276
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+ stages: [commit]
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ stages: [commit]
+
+ - repo: https://ghproxy.com/github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+ stages: [commit]
diff --git a/README.md b/README.md
index cab78d5..db1a12c 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,71 @@
-# adapter-satori
-NoneBot2 Satori 适配器 / Satori Protocol adapter for nonebot2
+
+
+# NoneBot-Adapter-Satori
+
+_✨ NoneBot2 Satori Protocol适配器 / Satori Protocol Adapter for NoneBot2 ✨_
+
+
+
+## 协议介绍
+
+[Satori Protocol](https://satori.js.org/zh-CN/)
+
+## 协议端
+
+目前提供了 `satori` 协议实现的有:
+- [Chronocat](https://chronocat.vercel.app)
+- Satori
+- Koishi (搭配 `@koishijs/plugin-server`)
+
+## 配置
+
+修改 NoneBot 配置文件 `.env` 或者 `.env.*`。
+
+### Driver
+
+参考 [driver](https://nonebot.dev/docs/appendices/config#driver) 配置项,添加 `ForwardDriver` 支持。
+
+如:
+
+```dotenv
+DRIVER=~httpx+~websockets
+DRIVER=~aiohttp
+```
+
+### SATORI_CLIENTS
+
+配置连接配置,如:
+
+```dotenv
+SATORI_CLIENTS='
+[
+ {
+ "port": "localhost",
+ "host": "5500",
+ "token": "xxx"
+ }
+]
+'
+```
+
+`host` 与 `port` 为 Satori 服务端的监听地址与端口,
+
+`token` 由 Satori 服务端决定是否需要。
+
+
+## 示例
+
+```python
+from nonebot import on_command
+from nonebot.adapters.satori import Bot
+from nonebot.adapters.satori.event import MessageEvent
+from nonebot.adapters.satori.message import MessageSegment
+
+
+matcher = on_command("test")
+
+@matcher.handle()
+async def handle_receive(bot: Bot, event: MessageEvent):
+ if event.is_private:
+ await bot.send(event, MessageSegment.text("Hello, world!"))
+```
diff --git a/nonebot/adapters/satori/__init__.py b/nonebot/adapters/satori/__init__.py
new file mode 100644
index 0000000..2bdb6f9
--- /dev/null
+++ b/nonebot/adapters/satori/__init__.py
@@ -0,0 +1,5 @@
+from .bot import Bot as Bot
+from .adapter import Adapter as Adapter
+from .message import Message as Message
+from .event import MessageEvent as MessageEvent
+from .message import MessageSegment as MessageSegment
diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py
new file mode 100644
index 0000000..01fbb44
--- /dev/null
+++ b/nonebot/adapters/satori/adapter.py
@@ -0,0 +1,282 @@
+import asyncio
+from typing_extensions import override
+from typing import Any, Dict, List, Literal, Optional
+
+from pydantic import parse_raw_as
+from nonebot.utils import escape_tag
+from nonebot.exception import WebSocketClosed
+from nonebot.drivers import (
+ Driver,
+ Request,
+ WebSocket,
+ HTTPClientMixin,
+ WebSocketClientMixin,
+)
+
+from nonebot.adapters import Adapter as BaseAdapter
+
+from .bot import Bot
+from .utils import API, log
+from .config import Config, ClientInfo
+from .exception import ApiNotAvailable
+from .models import Event as SatoriEvent
+from .event import (
+ EVENT_CLASSES,
+ Event,
+ MessageEvent,
+ LoginAddedEvent,
+ LoginRemovedEvent,
+ LoginUpdatedEvent,
+)
+from .models import (
+ Payload,
+ PayloadType,
+ PingPayload,
+ PongPayload,
+ EventPayload,
+ ReadyPayload,
+ IdentifyPayload,
+)
+
+
+class Adapter(BaseAdapter):
+ @override
+ def __init__(self, driver: Driver, **kwargs: Any):
+ super().__init__(driver, **kwargs)
+ # 读取适配器所需的配置项
+ self.satori_config: Config = Config.parse_obj(self.config)
+ self.tasks: List[asyncio.Task] = [] # 存储 ws 任务
+ self.sequences: Dict[str, int] = {} # 存储 连接序列号
+ self.setup()
+
+ @classmethod
+ @override
+ def get_name(cls) -> str:
+ """适配器名称"""
+ return "Satori"
+
+ def setup(self) -> None:
+ if not isinstance(self.driver, HTTPClientMixin):
+ # 判断用户配置的Driver类型是否符合适配器要求,不符合时应抛出异常
+ raise RuntimeError(
+ f"Current driver {self.config.driver} "
+ f"doesn't support http client requests!"
+ f"{self.get_name()} Adapter need a HTTPClient Driver to work."
+ )
+ if not isinstance(self.driver, WebSocketClientMixin):
+ raise RuntimeError(
+ f"Current driver {self.config.driver} does not support "
+ "websocket client! "
+ f"{self.get_name()} Adapter need a WebSocketClient Driver to work."
+ )
+ # 在 NoneBot 启动和关闭时进行相关操作
+ self.driver.on_startup(self.startup)
+ self.driver.on_shutdown(self.shutdown)
+
+ async def startup(self) -> None:
+ """定义启动时的操作,例如和平台建立连接"""
+ for client in self.satori_config.satori_clients:
+ self.tasks.append(asyncio.create_task(self.ws(client)))
+
+ async def shutdown(self) -> None:
+ for task in self.tasks:
+ if not task.done():
+ task.cancel()
+
+ await asyncio.gather(
+ *(asyncio.wait_for(task, timeout=10) for task in self.tasks),
+ return_exceptions=True,
+ )
+
+ @staticmethod
+ def payload_to_json(payload: Payload) -> str:
+ return payload.__config__.json_dumps(
+ payload.dict(), default=payload.__json_encoder__
+ )
+
+ async def receive_payload(self, info: ClientInfo, ws: WebSocket) -> Payload:
+ payload = parse_raw_as(PayloadType, await ws.receive())
+ if isinstance(payload, EventPayload):
+ self.sequences[info.identity] = payload.body.id
+ return payload
+
+ async def _authenticate(
+ self, info: ClientInfo, ws: WebSocket
+ ) -> Optional[Literal[True]]:
+ """鉴权连接"""
+ payload = IdentifyPayload.parse_obj(
+ {
+ "body": {
+ "token": info.token,
+ },
+ }
+ )
+ if info.identity in self.sequences:
+ payload.body.sequence = self.sequences[info.identity]
+
+ try:
+ await ws.send(self.payload_to_json(payload))
+ except Exception as e:
+ log(
+ "ERROR",
+ "Error while sending "
+ + "Identify event",
+ e,
+ )
+ return
+
+ resp = await self.receive_payload(info, ws)
+ if not isinstance(resp, ReadyPayload):
+ log(
+ "ERROR",
+ "Received unexpected payload while authenticating: "
+ f"{escape_tag(repr(resp))}",
+ )
+ return
+ for login in resp.body.logins:
+ if login.self_id not in self.bots:
+ bot = Bot(self, login.self_id, login.platform, info)
+ self.bot_connect(bot)
+ log(
+ "INFO",
+ f"Bot {escape_tag(bot.self_id)} connected",
+ )
+ else:
+ bot = self.bots[login.self_id]
+ bot.on_ready(login.user)
+
+ return True
+
+ async def _heartbeat(self, info: ClientInfo, ws: WebSocket):
+ """心跳"""
+ while True:
+ log("TRACE", f"Heartbeat {self.sequences[info.identity]}")
+ payload = PingPayload.parse_obj({})
+ try:
+ await ws.send(self.payload_to_json(payload))
+ except Exception as e:
+ log("WARNING", "Error while sending heartbeat, Ignored!", e)
+ await asyncio.sleep(9)
+
+ async def ws(self, info: ClientInfo) -> None:
+ ws_url = f"ws://{info.host}:{info.port}/v1/events"
+ req = Request("GET", ws_url, timeout=60.0)
+ heartbeat_task: Optional["asyncio.Task"] = None
+ while True:
+ try:
+ async with self.websocket(req) as ws:
+ log(
+ "DEBUG",
+ f"WebSocket Connection to "
+ f"{escape_tag(str(ws_url))} established",
+ )
+ try:
+ if not await self._authenticate(info, ws):
+ await asyncio.sleep(3)
+ continue
+ heartbeat_task = asyncio.create_task(self._heartbeat(info, ws))
+ await self._loop(info, ws)
+ except WebSocketClosed as e:
+ log(
+ "ERROR",
+ "WebSocket Closed",
+ e,
+ )
+ except Exception as e:
+ log(
+ "ERROR",
+ "Error while process data from websocket "
+ f"{escape_tag(str(ws_url))}. "
+ f"Trying to reconnect...",
+ e,
+ )
+ finally:
+ if heartbeat_task:
+ heartbeat_task.cancel()
+ heartbeat_task = None
+ bots = list(self.bots.values())
+ for bot in bots:
+ self.bot_disconnect(bot)
+ bots.clear()
+ except Exception as e:
+ log(
+ "ERROR",
+ (
+ ""
+ "Error while setup websocket to "
+ f"{escape_tag(str(ws_url))}. Trying to reconnect..."
+ ""
+ ),
+ e,
+ )
+ await asyncio.sleep(3) # 重连间隔
+
+ async def _loop(self, info: ClientInfo, ws: WebSocket):
+ while True:
+ payload = await self.receive_payload(info, ws)
+ log(
+ "TRACE",
+ f"Received payload: {escape_tag(repr(payload))}",
+ )
+ if isinstance(payload, EventPayload):
+ try:
+ event = self.payload_to_event(payload.body)
+ except Exception as e:
+ log(
+ "WARNING",
+ f"Failed to parse event {escape_tag(repr(payload))}",
+ e,
+ )
+ else:
+ if isinstance(event, LoginAddedEvent):
+ bot = Bot(self, event.self_id, event.platform, info)
+ bot.on_ready(event.user)
+ self.bot_connect(bot)
+ log(
+ "INFO",
+ f"Bot {escape_tag(bot.self_id)} connected",
+ )
+ elif isinstance(event, LoginRemovedEvent):
+ self.bot_disconnect(self.bots[event.self_id])
+ log(
+ "INFO",
+ f"Bot {escape_tag(event.self_id)} disconnected",
+ )
+ continue
+ elif isinstance(event, LoginUpdatedEvent):
+ self.bots[event.self_id].on_ready(event.user)
+ if not (bot := self.bots.get(event.self_id)):
+ log(
+ "WARNING",
+ f"Received event for unknown bot "
+ f"{escape_tag(event.self_id)}",
+ )
+ continue
+ if isinstance(event, MessageEvent):
+ event = event.convert()
+ asyncio.create_task(bot.handle_event(event))
+ elif isinstance(payload, PongPayload):
+ log("TRACE", "Pong")
+ continue
+ else:
+ log(
+ "WARNING",
+ f"Unknown payload from server: {escape_tag(repr(payload))}",
+ )
+
+ @staticmethod
+ def payload_to_event(payload: SatoriEvent) -> Event:
+ EventClass = EVENT_CLASSES.get(payload.type, None)
+ if EventClass is None:
+ log("WARNING", f"Unknown payload type: {payload.type}")
+ event = Event.parse_obj(payload)
+ return event
+ return EventClass.parse_obj(payload)
+
+ @override
+ async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
+ log("DEBUG", f"Bot {bot.self_id} calling API {api}")
+ api_handler: Optional[API] = getattr(bot.__class__, api, None)
+ if api_handler is None:
+ raise ApiNotAvailable
+ return await api_handler(bot, **data)
diff --git a/nonebot/adapters/satori/bot.py b/nonebot/adapters/satori/bot.py
new file mode 100644
index 0000000..dd665aa
--- /dev/null
+++ b/nonebot/adapters/satori/bot.py
@@ -0,0 +1,544 @@
+import re
+import json
+from typing_extensions import override
+from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
+
+from nonebot.message import handle_event
+from nonebot.drivers import Request, Response
+
+from nonebot.adapters import Bot as BaseBot
+
+from .utils import API, log
+from .config import ClientInfo
+from .event import Event, MessageEvent
+from .message import Message, MessageSegment
+from .models import InnerMessage as SatoriMessage
+from .models import Role, User, Guild, Channel, OuterLogin, OuterMember
+from .exception import (
+ ActionFailed,
+ NetworkError,
+ NotFoundException,
+ ForbiddenException,
+ BadRequestException,
+ UnauthorizedException,
+ MethodNotAllowedException,
+)
+
+if TYPE_CHECKING:
+ from .adapter import Adapter
+
+
+def _check_reply(
+ bot: "Bot",
+ event: MessageEvent,
+) -> None:
+ """检查消息中存在的回复,赋值 `event.reply`, `event.to_me`。
+
+ 参数:
+ bot: Bot 对象
+ event: MessageEvent 对象
+ """
+ message = event.get_message()
+ try:
+ index = message.index("quote")
+ except ValueError:
+ return
+
+ msg_seg = message[index]
+
+ event.reply = msg_seg # type: ignore
+
+ del message[index]
+ if (
+ len(message) > index
+ and message[index].type == "at"
+ and message[index].data.get("id") == str(bot.self_info.id)
+ ):
+ del message[index]
+ if len(message) > index and message[index].type == "text":
+ message[index].data["text"] = message[index].data["text"].lstrip()
+ if not message[index].data["text"]:
+ del message[index]
+ if not message:
+ message.append(MessageSegment.text(""))
+
+
+def _check_at_me(
+ bot: "Bot",
+ event: MessageEvent,
+):
+ def _is_at_me_seg(segment: MessageSegment) -> bool:
+ return segment.type == "at" and segment.data.get("id") == str(bot.self_info.id)
+
+ message = event.get_message()
+
+ # ensure message is not empty
+ if not message:
+ message.append(MessageSegment.text(""))
+
+ deleted = False
+ if _is_at_me_seg(message[0]):
+ message.pop(0)
+ deleted = True
+ if message and message[0].type == "text":
+ message[0].data["text"] = message[0].data["text"].lstrip("\xa0").lstrip()
+ if not message[0].data["text"]:
+ del message[0]
+
+ if not deleted:
+ # check the last segment
+ i = -1
+ last_msg_seg = message[i]
+ if (
+ last_msg_seg.type == "text"
+ and not last_msg_seg.data["text"].strip()
+ and len(message) >= 2
+ ):
+ i -= 1
+ last_msg_seg = message[i]
+
+ if _is_at_me_seg(last_msg_seg):
+ deleted = True
+ del message[i:]
+
+ if not message:
+ message.append(MessageSegment.text(""))
+
+
+def _check_nickname(bot: "Bot", event: MessageEvent) -> None:
+ """检查消息开头是否存在昵称,去除并赋值 `event.to_me`。
+
+ 参数:
+ bot: Bot 对象
+ event: MessageEvent 对象
+ """
+ message = event.get_message()
+ first_msg_seg = message[0]
+ if first_msg_seg.type != "text":
+ return
+
+ nicknames = {re.escape(n) for n in bot.config.nickname}
+ if not nicknames:
+ return
+
+ # check if the user is calling me with my nickname
+ nickname_regex = "|".join(nicknames)
+ first_text = first_msg_seg.data["text"]
+ if m := re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, re.IGNORECASE):
+ log("DEBUG", f"User is calling me {m[1]}")
+ event.to_me = True
+ first_msg_seg.data["text"] = first_text[m.end() :]
+
+
+class Bot(BaseBot):
+ adapter: "Adapter"
+
+ @override
+ def __init__(
+ self, adapter: "Adapter", self_id: str, platform: str, info: ClientInfo
+ ):
+ super().__init__(adapter, self_id)
+
+ # Bot 配置信息
+ self.info: ClientInfo = info
+ # Bot 自身所属平台
+ self.platform: str = platform
+ # Bot 自身信息
+ self._self_info: Optional[User] = None
+
+ def __getattr__(self, item):
+ raise AttributeError(f"'Bot' object has no attribute '{item}'")
+
+ @property
+ def ready(self) -> bool:
+ """Bot 是否已连接"""
+ return self._self_info is not None
+
+ @property
+ def self_info(self) -> User:
+ """Bot 自身信息,仅当 Bot 连接鉴权完成后可用"""
+ if self._self_info is None:
+ raise RuntimeError(
+ f"Bot {self.self_id} of {self.platform} is not connected!"
+ )
+ return self._self_info
+
+ def on_ready(self, user: User) -> None:
+ self._self_info = user
+
+ def get_authorization_header(self) -> Dict[str, str]:
+ """获取当前 Bot 的鉴权信息"""
+ return {
+ "Authorization": f"Bearer {self.info.token}",
+ "X-Self-ID": self.self_id,
+ "X-Platform": self.platform,
+ }
+
+ async def handle_event(self, event: Event) -> None:
+ if isinstance(event, MessageEvent):
+ _check_reply(self, event)
+ _check_at_me(self, event)
+ _check_nickname(self, event)
+ await handle_event(self, event)
+
+ def _handle_response(self, response: Response) -> Any:
+ if 200 <= response.status_code < 300:
+ return response.content and json.loads(response.content)
+ elif response.status_code == 400:
+ raise BadRequestException(response)
+ elif response.status_code == 401:
+ raise UnauthorizedException(response)
+ elif response.status_code == 403:
+ raise ForbiddenException(response)
+ elif response.status_code == 404:
+ raise NotFoundException(response)
+ elif response.status_code == 405:
+ raise MethodNotAllowedException(response)
+ else:
+ raise ActionFailed(response)
+
+ async def _request(self, request: Request) -> Any:
+ request.headers.update(self.get_authorization_header())
+
+ try:
+ response = await self.adapter.request(request)
+ except Exception as e:
+ raise NetworkError("API request failed") from e
+
+ return self._handle_response(response)
+
+ @override
+ async def send(
+ self,
+ event: Event,
+ message: Union[str, Message, MessageSegment],
+ **kwargs,
+ ) -> List[SatoriMessage]:
+ if not event.channel:
+ raise RuntimeError("Event cannot be replied to!")
+ return await self.send_message(event.channel.id, message)
+
+ async def send_message(
+ self,
+ channel_id: str,
+ message: Union[str, Message, MessageSegment],
+ ):
+ """发送消息
+
+ 参数:
+ channel_id: 要发送的频道 ID
+ message: 要发送的消息
+ """
+ return await self.message_create(channel_id=channel_id, content=str(message))
+
+ async def update_message(
+ self,
+ channel_id: str,
+ message_id: str,
+ message: Union[str, Message, MessageSegment],
+ ):
+ """更新消息
+
+ 参数:
+ channel_id: 要更新的频道 ID
+ message_id: 要更新的消息 ID
+ message: 要更新的消息
+ """
+ return await self.message_update(
+ channel_id=channel_id, message_id=message_id, content=str(message)
+ )
+
+ @API
+ async def message_create(
+ self,
+ *,
+ channel_id: str,
+ content: str,
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "message.create",
+ json={"channel_id": channel_id, "content": content},
+ )
+ res = await self._request(request)
+ return [SatoriMessage.parse_obj(i) for i in res]
+
+ @API
+ async def message_get(self, *, channel_id: str, message_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "message.get",
+ json={"channel_id": channel_id, "message_id": message_id},
+ )
+ res = await self._request(request)
+ return SatoriMessage.parse_obj(res)
+
+ @API
+ async def message_delete(self, *, channel_id: str, message_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "message.delete",
+ json={"channel_id": channel_id, "message_id": message_id},
+ )
+ await self._request(request)
+
+ @API
+ async def message_update(
+ self,
+ *,
+ channel_id: str,
+ message_id: str,
+ content: str,
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "message.update",
+ json={
+ "channel_id": channel_id,
+ "message_id": message_id,
+ "content": content,
+ },
+ )
+ await self._request(request)
+
+ @API
+ async def message_list(self, *, channel_id: str, next_token: Optional[str] = None):
+ request = Request(
+ "POST",
+ self.info.api_base / "message.list",
+ json={"channel_id": channel_id, "next": next_token},
+ )
+ res = await self._request(request)
+ return [SatoriMessage.parse_obj(i) for i in res]
+
+ @API
+ async def channel_get(self, *, channel_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "channel.get",
+ json={"channel_id": channel_id},
+ )
+ res = await self._request(request)
+ return Channel.parse_obj(res)
+
+ @API
+ async def channel_list(self, *, guild_id: str, next_token: Optional[str] = None):
+ request = Request(
+ "POST",
+ self.info.api_base / "channel.list",
+ json={"guild_id": guild_id, "next": next_token},
+ )
+ res = await self._request(request)
+ return [Channel.parse_obj(i) for i in res]
+
+ @API
+ async def channel_create(self, *, guild_id: str, data: Channel):
+ request = Request(
+ "POST",
+ self.info.api_base / "channel.create",
+ json={"guild_id": guild_id, "data": data.dict()},
+ )
+ return Channel.parse_obj(await self._request(request))
+
+ @API
+ async def channel_update(
+ self,
+ *,
+ channel_id: str,
+ data: Channel,
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "channel.update",
+ json={"channel_id": channel_id, "data": data.dict()},
+ )
+ await self._request(request)
+
+ @API
+ async def channel_delete(self, *, channel_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "channel.delete",
+ json={"channel_id": channel_id},
+ )
+ await self._request(request)
+
+ @API
+ async def user_channel_create(self, *, user_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "user.channel.create",
+ json={"user_id": user_id},
+ )
+ return Channel.parse_obj(await self._request(request))
+
+ @API
+ async def guild_get(self, *, guild_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.get",
+ json={"guild_id": guild_id},
+ )
+ return Guild.parse_obj(await self._request(request))
+
+ @API
+ async def guild_list(self, *, next_token: Optional[str] = None):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.list",
+ json={"next": next_token},
+ )
+ return [Guild.parse_obj(i) for i in await self._request(request)]
+
+ @API
+ async def guild_approve(self, *, request_id: str, approve: bool, comment: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.approve",
+ json={"message_id": request_id, "approve": approve, "comment": comment},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_member_list(
+ self, *, guild_id: str, next_token: Optional[str] = None
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.list",
+ json={"guild_id": guild_id, "next": next_token},
+ )
+ return [OuterMember.parse_obj(i) for i in await self._request(request)]
+
+ @API
+ async def guild_member_get(self, *, guild_id: str, user_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.get",
+ json={"guild_id": guild_id, "user_id": user_id},
+ )
+ return OuterMember.parse_obj(await self._request(request))
+
+ @API
+ async def guild_member_kick(
+ self, *, guild_id: str, user_id: str, permanent: bool = False
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.kick",
+ json={"guild_id": guild_id, "user_id": user_id, "permanent": permanent},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_member_approve(
+ self, *, request_id: str, approve: bool, comment: str
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.approve",
+ json={"message_id": request_id, "approve": approve, "comment": comment},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_member_role_set(self, *, guild_id: str, user_id: str, role_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.role.set",
+ json={"guild_id": guild_id, "user_id": user_id, "role_id": role_id},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_member_role_unset(
+ self, *, guild_id: str, user_id: str, role_id: str
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.member.role.unset",
+ json={"guild_id": guild_id, "user_id": user_id, "role_id": role_id},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_role_list(self, guild_id: str, next_token: Optional[str] = None):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.role.list",
+ json={"guild_id": guild_id, "next": next_token},
+ )
+ return [Role.parse_obj(i) for i in await self._request(request)]
+
+ @API
+ async def guild_role_create(
+ self,
+ *,
+ guild_id: str,
+ role: Role,
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.role.create",
+ json={"guild_id": guild_id, "role": role.dict()},
+ )
+ return Role.parse_obj(await self._request(request))
+
+ @API
+ async def guild_role_update(
+ self,
+ *,
+ guild_id: str,
+ role_id: str,
+ role: Role,
+ ):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.role.update",
+ json={"guild_id": guild_id, "role_id": role_id, "role": role.dict()},
+ )
+ await self._request(request)
+
+ @API
+ async def guild_role_delete(self, *, guild_id: str, role_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "guild.role.delete",
+ json={"guild_id": guild_id, "role_id": role_id},
+ )
+ await self._request(request)
+
+ @API
+ async def login_get(self):
+ request = Request(
+ "POST",
+ self.info.api_base / "login.get",
+ )
+ return OuterLogin.parse_obj(await self._request(request))
+
+ @API
+ async def user_get(self, *, user_id: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "user.get",
+ json={"user_id": user_id},
+ )
+ return User.parse_obj(await self._request(request))
+
+ @API
+ async def friend_list(self, *, next_token: Optional[str] = None):
+ request = Request(
+ "POST",
+ self.info.api_base / "friend.list",
+ json={"next": next_token},
+ )
+ return [User.parse_obj(i) for i in await self._request(request)]
+
+ @API
+ async def friend_approve(self, *, request_id: str, approve: bool, comment: str):
+ request = Request(
+ "POST",
+ self.info.api_base / "friend.approve",
+ json={"message_id": request_id, "approve": approve, "comment": comment},
+ )
+ await self._request(request)
diff --git a/nonebot/adapters/satori/config.py b/nonebot/adapters/satori/config.py
new file mode 100644
index 0000000..56038da
--- /dev/null
+++ b/nonebot/adapters/satori/config.py
@@ -0,0 +1,27 @@
+from typing import List, Optional
+
+from yarl import URL
+from pydantic import Extra, Field, BaseModel
+
+
+class ClientInfo(BaseModel):
+ host: str = "localhost"
+ port: int
+ token: Optional[str] = None
+
+ @property
+ def identity(self):
+ return f"{self.host}:{self.port}#{self.token}"
+
+ @property
+ def api_base(self):
+ return URL(f"http://{self.host}:{self.port}") / "v1"
+
+ @property
+ def ws_base(self):
+ return URL(f"ws://{self.host}:{self.port}") / "v1"
+
+
+class Config(BaseModel, extra=Extra.ignore):
+ satori_clients: List[ClientInfo] = Field(default_factory=list)
+ """client 配置"""
diff --git a/nonebot/adapters/satori/event.py b/nonebot/adapters/satori/event.py
new file mode 100644
index 0000000..754dd4f
--- /dev/null
+++ b/nonebot/adapters/satori/event.py
@@ -0,0 +1,421 @@
+from enum import Enum
+from copy import deepcopy
+from typing_extensions import override
+from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Optional
+
+from pydantic import root_validator
+from nonebot.utils import escape_tag
+
+from nonebot.adapters import Event as BaseEvent
+
+from .models import Role, User
+from .models import Event as SatoriEvent
+from .message import Message, RenderMessage
+from .models import InnerMessage as SatoriMessage
+from .models import Guild, Channel, InnerLogin, ChannelType, InnerMember
+
+E = TypeVar("E", bound="Event")
+
+
+class EventType(str, Enum):
+ FRIEND_REQUEST = "friend-request"
+ GUILD_ADDED = "guild-added"
+ GUILD_MEMBER_ADDED = "guild-member-added"
+ GUILD_MEMBER_REMOVED = "guild-member-removed"
+ GUILD_MEMBER_REQUEST = "guild-member-request"
+ GUILD_MEMBER_UPDATED = "guild-member-updated"
+ GUILD_REMOVED = "guild-removed"
+ GUILD_REQUEST = "guild-request"
+ GUILD_ROLE_CREATED = "guild-role-created"
+ GUILD_ROLE_DELETED = "guild-role-deleted"
+ GUILD_ROLE_UPDATED = "guild-role-updated"
+ GUILD_UPDATED = "guild-updated"
+ LOGIN_ADDED = "login-added"
+ LOGIN_REMOVED = "login-removed"
+ LOGIN_UPDATED = "login-updated"
+ MESSAGE_CREATED = "message-created"
+ MESSAGE_DELETED = "message-deleted"
+ MESSAGE_UPDATED = "message-updated"
+ REACTION_ADDED = "reaction-added"
+ REACTION_REMOVED = "reaction-removed"
+
+
+class Event(BaseEvent, SatoriEvent):
+ __type__: EventType
+
+ @override
+ def get_type(self) -> str:
+ return ""
+
+ @override
+ def get_event_name(self) -> str:
+ return self.type
+
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(str(self.dict()))
+
+ @override
+ def get_message(self) -> Message:
+ raise ValueError("Event has no message!")
+
+ @override
+ def get_user_id(self) -> str:
+ raise ValueError("Event has no context!")
+
+ @override
+ def get_session_id(self) -> str:
+ raise ValueError("Event has no context!")
+
+ @override
+ def is_tome(self) -> bool:
+ return False
+
+
+EVENT_CLASSES: Dict[str, Type[Event]] = {}
+
+
+def register_event_class(event_class: Type[E]) -> Type[E]:
+ EVENT_CLASSES[event_class.__type__.value] = event_class
+ return event_class
+
+
+class NoticeEvent(Event):
+ @override
+ def get_type(self) -> str:
+ return "notice"
+
+
+class FriendEvent(NoticeEvent):
+ channel: Channel
+ user: User
+
+ @override
+ def get_user_id(self) -> str:
+ return self.user.id
+
+ @override
+ def get_session_id(self) -> str:
+ return self.user.id
+
+
+@register_event_class
+class FriendRequestEvent(FriendEvent):
+ __type__ = EventType.FRIEND_REQUEST
+
+
+class GuildEvent(NoticeEvent):
+ channel: Channel
+ guild: Guild
+
+ @override
+ def get_session_id(self) -> str:
+ return f"{self.guild.id}/{self.channel.id}"
+
+
+@register_event_class
+class GuildAddedEvent(GuildEvent):
+ __type__ = EventType.GUILD_ADDED
+
+
+@register_event_class
+class GuildRemovedEvent(GuildEvent):
+ __type__ = EventType.GUILD_REMOVED
+
+
+@register_event_class
+class GuildRequestEvent(GuildEvent):
+ __type__ = EventType.GUILD_REQUEST
+
+
+@register_event_class
+class GuildUpdatedEvent(GuildEvent):
+ __type__ = EventType.GUILD_UPDATED
+
+
+class GuildInnerMemberEvent(GuildEvent):
+ user: User
+
+ @override
+ def get_user_id(self) -> str:
+ return self.member.user.id if self.member else self.user.id
+
+ @override
+ def get_session_id(self) -> str:
+ return f"{self.guild.id}/{self.channel.id}/{self.get_user_id()}"
+
+
+@register_event_class
+class GuildInnerMemberAddedEvent(GuildInnerMemberEvent):
+ __type__ = EventType.GUILD_MEMBER_ADDED
+
+
+@register_event_class
+class GuildInnerMemberRemovedEvent(GuildInnerMemberEvent):
+ member: InnerMember
+ __type__ = EventType.GUILD_MEMBER_REMOVED
+
+
+@register_event_class
+class GuildInnerMemberRequestEvent(GuildInnerMemberEvent):
+ __type__ = EventType.GUILD_MEMBER_REQUEST
+
+
+@register_event_class
+class GuildInnerMemberUpdatedEvent(GuildInnerMemberEvent):
+ member: InnerMember
+ __type__ = EventType.GUILD_MEMBER_UPDATED
+
+
+class GuildRoleEvent(GuildEvent):
+ role: Role
+
+ @override
+ def get_session_id(self) -> str:
+ return f"{self.guild.id}/{self.channel.id}/{self.role.id}"
+
+
+@register_event_class
+class GuildRoleCreatedEvent(GuildRoleEvent):
+ __type__ = EventType.GUILD_ROLE_CREATED
+
+
+@register_event_class
+class GuildRoleDeletedEvent(GuildRoleEvent):
+ __type__ = EventType.GUILD_ROLE_DELETED
+
+
+@register_event_class
+class GuildRoleUpdatedEvent(GuildRoleEvent):
+ __type__ = EventType.GUILD_ROLE_UPDATED
+
+
+class LoginEvent(NoticeEvent):
+ login: InnerLogin
+ user: User
+
+ @override
+ def get_user_id(self) -> str:
+ return self.user.id
+
+ @override
+ def get_session_id(self) -> str:
+ return self.user.id
+
+
+@register_event_class
+class LoginAddedEvent(LoginEvent):
+ __type__ = EventType.LOGIN_ADDED
+
+
+@register_event_class
+class LoginRemovedEvent(LoginEvent):
+ __type__ = EventType.LOGIN_REMOVED
+
+
+@register_event_class
+class LoginUpdatedEvent(LoginEvent):
+ __type__ = EventType.LOGIN_UPDATED
+
+
+class MessageEvent(Event):
+ channel: Channel
+ user: User
+ message: SatoriMessage
+ to_me: bool = False
+ reply: Optional[RenderMessage] = None
+
+ if TYPE_CHECKING:
+ _message: Message
+ original_message: Message
+
+ @override
+ def get_type(self) -> str:
+ return "message"
+
+ @override
+ def is_tome(self) -> bool:
+ return self.to_me
+
+ @override
+ def get_message(self) -> Message:
+ return self._message
+
+ @root_validator
+ def generate_message(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ values["_message"] = Message.from_satori_element(values["message"].content)
+ values["original_message"] = deepcopy(values["_message"])
+ return values
+
+ @property
+ def msg_id(self) -> str:
+ return self.message.id
+
+ def convert(self) -> "MessageEvent":
+ raise NotImplementedError
+
+
+@register_event_class
+class MessageCreatedEvent(MessageEvent):
+ __type__ = EventType.MESSAGE_CREATED
+
+ def convert(self):
+ if self.channel.type == ChannelType.DIRECT:
+ return PrivateMessageCreatedEvent.parse_obj(self)
+ else:
+ return PublicMessageCreatedEvent.parse_obj(self)
+
+
+@register_event_class
+class MessageDeletedEvent(MessageEvent):
+ __type__ = EventType.MESSAGE_DELETED
+
+ def convert(self):
+ if self.channel.type == ChannelType.DIRECT:
+ return PrivateMessageDeletedEvent.parse_obj(self)
+ else:
+ return PublicMessageDeletedEvent.parse_obj(self)
+
+
+@register_event_class
+class MessageUpdatedEvent(MessageEvent):
+ __type__ = EventType.MESSAGE_UPDATED
+
+ def convert(self):
+ if self.channel.type == ChannelType.DIRECT:
+ return PrivateMessageUpdatedEvent.parse_obj(self)
+ else:
+ return PublicMessageUpdatedEvent.parse_obj(self)
+
+
+class PrivateMessageEvent(MessageEvent):
+ @override
+ def is_tome(self) -> bool:
+ return True
+
+ @override
+ def get_session_id(self) -> str:
+ return self.channel.id
+
+ @override
+ def get_user_id(self) -> str:
+ return self.user.id
+
+
+class PublicMessageEvent(MessageEvent):
+ guild: Guild
+ member: InnerMember
+
+ @override
+ def get_session_id(self) -> str:
+ return f"{self.guild.id}/{self.channel.id}/{self.user.id}"
+
+ @override
+ def get_user_id(self) -> str:
+ return self.user.id
+
+
+class PrivateMessageCreatedEvent(MessageCreatedEvent, PrivateMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name}({self.channel.id}): {self.get_message()!r}"
+ )
+
+
+class PublicMessageCreatedEvent(MessageCreatedEvent, PublicMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name or ''}({self.channel.id})"
+ f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}]"
+ f": {self.get_message()!r}"
+ )
+
+
+class PrivateMessageDeletedEvent(MessageDeletedEvent, PrivateMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name}({self.channel.id}) deleted"
+ )
+
+
+class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name or ''}({self.channel.id})"
+ f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}] deleted"
+ )
+
+
+class PrivateMessageUpdatedEvent(MessageUpdatedEvent, PrivateMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name}({self.channel.id}) updated"
+ f": {self.get_message()!r}"
+ )
+
+
+class PublicMessageUpdatedEvent(MessageUpdatedEvent, PublicMessageEvent):
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Message {self.msg_id} from "
+ f"{self.user.name or ''}({self.channel.id})"
+ f"@[{self.channel.name}:{self.guild.id}/{self.channel.id}] updated"
+ f": {self.get_message()!r}"
+ )
+
+
+class ReactionEvent(NoticeEvent):
+ channel: Channel
+ user: User
+ message: SatoriMessage
+
+ if TYPE_CHECKING:
+ _message: Message
+
+ @override
+ def get_user_id(self) -> str:
+ return self.user.id
+
+ @override
+ def get_session_id(self) -> str:
+ return f"{self.channel.id}/{self.user.id}"
+
+ @root_validator
+ def generate_message(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ values["_message"] = Message.from_satori_element(values["message"]["content"])
+ return values
+
+ @property
+ def msg_id(self) -> str:
+ return self.message.id
+
+
+@register_event_class
+class ReactionAddedEvent(ReactionEvent):
+ __type__ = EventType.REACTION_ADDED
+
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(
+ f"Reaction added to {self.msg_id} by {self.user.name}({self.channel.id})"
+ )
+
+
+@register_event_class
+class ReactionRemovedEvent(ReactionEvent):
+ __type__ = EventType.REACTION_REMOVED
+
+ @override
+ def get_event_description(self) -> str:
+ return escape_tag(f"Reaction removed from {self.msg_id}")
diff --git a/nonebot/adapters/satori/exception.py b/nonebot/adapters/satori/exception.py
new file mode 100644
index 0000000..5d6e9aa
--- /dev/null
+++ b/nonebot/adapters/satori/exception.py
@@ -0,0 +1,75 @@
+import json
+from typing import Optional
+
+from nonebot.drivers import Response
+from nonebot.exception import AdapterException
+from nonebot.exception import ActionFailed as BaseActionFailed
+from nonebot.exception import NetworkError as BaseNetworkError
+from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
+
+
+class SatoriAdapterException(AdapterException):
+ def __init__(self):
+ super().__init__("satori")
+
+
+class ActionFailed(BaseActionFailed, SatoriAdapterException):
+ def __init__(self, response: Response):
+ self.status_code: int = response.status_code
+ self.code: Optional[int] = None
+ self.message: Optional[str] = None
+ self.data: Optional[dict] = None
+ if response.content:
+ body = json.loads(response.content)
+ self._prepare_body(body)
+
+ def __repr__(self) -> str:
+ return (
+ f""
+ )
+
+ def __str__(self):
+ return self.__repr__()
+
+ def _prepare_body(self, body: dict):
+ self.code = body.get("code", None)
+ self.message = body.get("message", None)
+ self.data = body.get("data", None)
+
+
+class BadRequestException(ActionFailed):
+ pass
+
+
+class UnauthorizedException(ActionFailed):
+ pass
+
+
+class ForbiddenException(ActionFailed):
+ pass
+
+
+class NotFoundException(ActionFailed):
+ pass
+
+
+class MethodNotAllowedException(ActionFailed):
+ pass
+
+
+class NetworkError(BaseNetworkError, SatoriAdapterException):
+ def __init__(self, msg: Optional[str] = None):
+ super().__init__()
+ self.msg: Optional[str] = msg
+ """错误原因"""
+
+ def __repr__(self):
+ return f""
+
+ def __str__(self):
+ return self.__repr__()
+
+
+class ApiNotAvailable(BaseApiNotAvailable, SatoriAdapterException):
+ pass
diff --git a/nonebot/adapters/satori/message.py b/nonebot/adapters/satori/message.py
new file mode 100644
index 0000000..2d95730
--- /dev/null
+++ b/nonebot/adapters/satori/message.py
@@ -0,0 +1,422 @@
+from dataclasses import field, dataclass
+from typing_extensions import NotRequired, override
+from typing import Any, List, Type, Union, Iterable, Optional, TypedDict
+
+from nonebot.adapters import Message as BaseMessage
+from nonebot.adapters import MessageSegment as BaseMessageSegment
+
+from .utils import Element, parse, escape
+
+
+class MessageSegment(BaseMessageSegment["Message"]):
+ def __str__(self) -> str:
+ def _attr(key: str, value: Any):
+ if value is True:
+ return key
+ if value is False:
+ return f"no-{key}"
+ if isinstance(value, (int, float)):
+ return f"{key}={value}"
+ return f'{key}="{escape(str(value))}"'
+
+ attrs = " ".join(_attr(k, v) for k, v in self.data.items())
+ return f"<{self.type} {attrs} />"
+
+ @classmethod
+ @override
+ def get_message_class(cls) -> Type["Message"]:
+ return Message
+
+ @staticmethod
+ def text(content: str) -> "Text":
+ return Text("text", {"text": content})
+
+ @staticmethod
+ def entity(content: str, style: str) -> "Entity":
+ return Entity("entity", {"text": content, "style": style})
+
+ @staticmethod
+ def at(
+ user_id: str,
+ name: Optional[str] = None,
+ ) -> "At":
+ data = {"id": user_id}
+ if name:
+ data["name"] = name
+ return At("at", data)
+
+ @staticmethod
+ def at_role(
+ role: str,
+ name: Optional[str] = None,
+ ) -> "At":
+ data = {"role": role}
+ if name:
+ data["name"] = name
+ return At("at_role", data)
+
+ @staticmethod
+ def at_all(here: bool = False) -> "At":
+ return At("at", {"type": "here" if here else "all"})
+
+ @staticmethod
+ def sharp(channel_id: str, name: Optional[str] = None) -> "Sharp":
+ data = {"id": channel_id}
+ if name:
+ data["name"] = name
+ return Sharp("sharp", data)
+
+ @staticmethod
+ def link(href: str) -> "Link":
+ return Link("link", {"href": href})
+
+ @staticmethod
+ def image(
+ src: str, cache: Optional[bool] = None, timeout: Optional[str] = None
+ ) -> "Image":
+ data = {"src": src}
+ if cache is not None:
+ data["cache"] = cache
+ if timeout is not None:
+ data["timeout"] = timeout
+ return Image("img", data)
+
+ @staticmethod
+ def audio(
+ src: str, cache: Optional[bool] = None, timeout: Optional[str] = None
+ ) -> "Audio":
+ data = {"src": src}
+ if cache is not None:
+ data["cache"] = cache
+ if timeout is not None:
+ data["timeout"] = timeout
+ return Audio("audio", data)
+
+ @staticmethod
+ def video(
+ src: str, cache: Optional[bool] = None, timeout: Optional[str] = None
+ ) -> "Video":
+ data = {"src": src}
+ if cache is not None:
+ data["cache"] = cache
+ if timeout is not None:
+ data["timeout"] = timeout
+ return Video("video", data)
+
+ @staticmethod
+ def file(
+ src: str, cache: Optional[bool] = None, timeout: Optional[str] = None
+ ) -> "File":
+ data = {"src": src}
+ if cache is not None:
+ data["cache"] = cache
+ if timeout is not None:
+ data["timeout"] = timeout
+ return File("file", data)
+
+ @staticmethod
+ def br() -> "Br":
+ return Br("br", {})
+
+ @staticmethod
+ def paragraph(text: str) -> "Paragraph":
+ return Paragraph("paragraph", {"text": text})
+
+ @staticmethod
+ def message(
+ mid: Optional[str] = None,
+ forward: Optional[bool] = None,
+ content: Optional["Message"] = None,
+ ) -> "RenderMessage":
+ data = {}
+ if mid:
+ data["id"] = mid
+ if forward is not None:
+ data["forward"] = forward
+ if content:
+ data["content"] = content
+ return RenderMessage("message", data)
+
+ @staticmethod
+ def quote(
+ mid: str,
+ forward: Optional[bool] = None,
+ content: Optional["Message"] = None,
+ ) -> "RenderMessage":
+ data = {"id": mid}
+ if forward is not None:
+ data["forward"] = forward
+ if content:
+ data["content"] = content
+ return RenderMessage("quote", data)
+
+ @staticmethod
+ def author(
+ user_id: str,
+ nickname: Optional[str] = None,
+ avatar: Optional[str] = None,
+ ) -> "Author":
+ data = {"id": user_id}
+ if nickname:
+ data["nickname"] = nickname
+ if avatar:
+ data["avatar"] = avatar
+ return Author("author", data)
+
+ @override
+ def is_text(self) -> bool:
+ return self.type == "text"
+
+
+class TextData(TypedDict):
+ text: str
+
+
+@dataclass
+class Text(MessageSegment):
+ data: TextData = field(default_factory=dict)
+
+ @override
+ def __str__(self) -> str:
+ return escape(self.data["text"])
+
+
+class EntityData(TypedDict):
+ text: str
+ style: str
+
+
+@dataclass
+class Entity(MessageSegment):
+ data: EntityData = field(default_factory=dict)
+
+ @override
+ def __str__(self) -> str:
+ style = self.data["style"]
+ return f'<{style}>{escape(self.data["text"])}{style}>'
+
+
+class AtData(TypedDict):
+ id: NotRequired[str]
+ name: NotRequired[str]
+ role: NotRequired[str]
+ type: NotRequired[str]
+
+
+@dataclass
+class At(MessageSegment):
+ data: AtData = field(default_factory=dict)
+
+
+class SharpData(TypedDict):
+ id: str
+ name: NotRequired[str]
+
+
+@dataclass
+class Sharp(MessageSegment):
+ data: SharpData = field(default_factory=dict)
+
+
+class LinkData(TypedDict):
+ href: str
+
+
+@dataclass
+class Link(MessageSegment):
+ data: LinkData = field(default_factory=dict)
+
+ @override
+ def __str__(self):
+ return f''
+
+
+class ImageData(TypedDict):
+ src: str
+ cache: NotRequired[bool]
+ timeout: NotRequired[str]
+ width: NotRequired[int]
+ height: NotRequired[int]
+
+
+@dataclass
+class Image(MessageSegment):
+ data: ImageData = field(default_factory=dict)
+
+
+class AudioData(TypedDict):
+ src: str
+ cache: NotRequired[bool]
+ timeout: NotRequired[str]
+
+
+@dataclass
+class Audio(MessageSegment):
+ data: AudioData = field(default_factory=dict)
+
+
+class VideoData(TypedDict):
+ src: str
+ cache: NotRequired[bool]
+ timeout: NotRequired[str]
+
+
+@dataclass
+class Video(MessageSegment):
+ data: VideoData = field(default_factory=dict)
+
+
+class FileData(TypedDict):
+ src: str
+ cache: NotRequired[bool]
+ timeout: NotRequired[str]
+
+
+@dataclass
+class File(MessageSegment):
+ data: FileData = field(default_factory=dict)
+
+
+@dataclass
+class Br(MessageSegment):
+ @override
+ def __str__(self):
+ return "
"
+
+
+class ParagraphData(TypedDict):
+ text: str
+
+
+@dataclass
+class Paragraph(MessageSegment):
+ data: ParagraphData = field(default_factory=dict)
+
+ @override
+ def __str__(self):
+ return f'{escape(self.data["text"])}
'
+
+
+class RenderMessageData(TypedDict):
+ id: NotRequired[str]
+ forward: NotRequired[bool]
+ content: NotRequired["Message"]
+
+
+@dataclass
+class RenderMessage(MessageSegment):
+ data: RenderMessageData = field(default_factory=dict)
+
+ @override
+ def __str__(self):
+ attr = []
+ if "id" in self.data:
+ attr.append(f'id="{escape(self.data["id"])}"')
+ if self.data.get("forward"):
+ attr.append("forward")
+ if "content" not in self.data:
+ return f'<{self.type} {" ".join(attr)} />'
+ else:
+ return f'<{self.type} {" ".join(attr)}>{self.data["content"]}{self.type}>'
+
+
+class AuthorData(TypedDict):
+ id: str
+ nickname: NotRequired[str]
+ avatar: NotRequired[str]
+
+
+@dataclass
+class Author(MessageSegment):
+ data: AuthorData = field(default_factory=dict)
+
+
+ELEMENT_TYPE_MAP = {
+ "text": (Text, "text"),
+ "at": (At, "at"),
+ "sharp": (Sharp, "sharp"),
+ "a": (Link, "link"),
+ "link": (Link, "link"),
+ "img": (Image, "img"),
+ "image": (Image, "img"),
+ "audio": (Audio, "audio"),
+ "video": (Video, "video"),
+ "file": (File, "file"),
+ "br": (Br, "br"),
+ "author": (Author, "author"),
+}
+
+
+class Message(BaseMessage[MessageSegment]):
+ @classmethod
+ @override
+ def get_segment_class(cls) -> Type[MessageSegment]:
+ return MessageSegment
+
+ @override
+ def __add__(
+ self, other: Union[str, MessageSegment, Iterable[MessageSegment]]
+ ) -> "Message":
+ return super().__add__(
+ MessageSegment.text(other) if isinstance(other, str) else other
+ )
+
+ @override
+ def __radd__(
+ self, other: Union[str, MessageSegment, Iterable[MessageSegment]]
+ ) -> "Message":
+ return super().__radd__(
+ MessageSegment.text(other) if isinstance(other, str) else other
+ )
+
+ @staticmethod
+ @override
+ def _construct(msg: str) -> Iterable[MessageSegment]:
+ yield from Message.from_satori_element(parse(msg))
+
+ @classmethod
+ def from_satori_element(cls, elements: List[Element]) -> "Message":
+ msg = Message()
+ for elem in elements:
+ if elem.type in ELEMENT_TYPE_MAP:
+ seg_cls, seg_type = ELEMENT_TYPE_MAP[elem.type]
+ msg.append(seg_cls(seg_type, elem.attrs.copy()))
+ elif elem.type in {
+ "b",
+ "strong",
+ "i",
+ "em",
+ "u",
+ "ins",
+ "s",
+ "del",
+ "spl",
+ "code",
+ "sup",
+ "sub",
+ }:
+ msg.append(
+ Entity(
+ "entity",
+ {"text": elem.children[0].attrs["text"], "style": elem.type},
+ )
+ )
+ elif elem.type in ("p", "paragraph"):
+ msg.append(
+ Paragraph("paragraph", {"text": elem.children[0].attrs["text"]})
+ )
+ elif elem.type in ("message", "quote"):
+ data = elem.attrs.copy()
+ if elem.children:
+ data["content"] = Message.from_satori_element(elem.children)
+ msg.append(RenderMessage(elem.type, data))
+ else:
+ msg.append(Text("text", {"text": str(elem)}))
+ return msg
+
+ def extract_content(self) -> str:
+ return "".join(
+ str(seg)
+ for seg in self
+ if seg.type in ("text", "entity", "at", "sharp", "link")
+ )
diff --git a/nonebot/adapters/satori/models.py b/nonebot/adapters/satori/models.py
new file mode 100644
index 0000000..9123668
--- /dev/null
+++ b/nonebot/adapters/satori/models.py
@@ -0,0 +1,217 @@
+from enum import IntEnum
+from datetime import datetime
+from typing import Any, Dict, List, Union, Literal, Optional
+
+from pydantic import Field, BaseModel, validator
+
+from .utils import Element, parse
+
+
+class ChannelType(IntEnum):
+ TEXT = 0
+ VOICE = 1
+ CATEGORY = 2
+ DIRECT = 3
+
+
+class Channel(BaseModel):
+ id: str
+ name: str
+ type: ChannelType
+ parent_id: Optional[str] = None
+
+
+class Guild(BaseModel):
+ id: str
+ name: str
+ avatar: Optional[str] = None
+
+
+class User(BaseModel):
+ id: str
+ name: Optional[str] = None
+ avatar: Optional[str] = None
+ is_bot: Optional[bool] = None
+
+
+class InnerMember(BaseModel):
+ user: Optional[User] = None
+ name: Optional[str] = None
+ avatar: Optional[str] = None
+ joined_at: Optional[datetime] = None
+
+ @validator("joined_at", pre=True)
+ def parse_joined_at(cls, v):
+ if v is None:
+ return None
+ if isinstance(v, datetime):
+ return v
+ try:
+ timestamp = int(v)
+ except ValueError:
+ raise ValueError(f"invalid timestamp: {v}")
+ return datetime.fromtimestamp(timestamp / 1000)
+
+
+class OuterMember(InnerMember):
+ user: User
+ joined_at: datetime
+
+
+class Role(BaseModel):
+ id: str
+ name: str
+
+
+class LoginStatus(IntEnum):
+ OFFLINE = 0
+ ONLINE = 1
+ CONNECT = 2
+ DISCONNECT = 3
+ RECONNECT = 4
+
+
+class InnerLogin(BaseModel):
+ user: Optional[User] = None
+ self_id: Optional[str] = None
+ platform: Optional[str] = None
+ status: LoginStatus
+
+
+class OuterLogin(InnerLogin):
+ user: User
+ self_id: str
+ platform: str
+
+
+class Opcode(IntEnum):
+ EVENT = 0
+ PING = 1
+ PONG = 2
+ IDENTIFY = 3
+ READY = 4
+
+
+class Payload(BaseModel):
+ op: Opcode = Field(...)
+ body: Optional[Dict[str, Any]] = Field(None)
+
+
+class Identify(BaseModel):
+ token: Optional[str] = None
+ sequence: Optional[int] = None
+
+
+class Ready(BaseModel):
+ logins: List[OuterLogin]
+
+
+class IdentifyPayload(Payload):
+ op: Literal[Opcode.IDENTIFY] = Field(Opcode.IDENTIFY)
+ body: Identify
+
+
+class ReadyPayload(Payload):
+ op: Literal[Opcode.READY] = Field(Opcode.READY)
+ body: Ready
+
+
+class PingPayload(Payload):
+ op: Literal[Opcode.PING] = Field(Opcode.PING)
+
+
+class PongPayload(Payload):
+ op: Literal[Opcode.PONG] = Field(Opcode.PONG)
+
+
+class InnerMessage(BaseModel):
+ id: str
+ content: Optional[List[Element]] = None
+ channel: Optional[Channel] = None
+ guild: Optional[Guild] = None
+ member: Optional[InnerMember] = None
+ user: Optional[User] = None
+ created_at: Optional[datetime] = None
+ updated_at: Optional[datetime] = None
+
+ @validator("content", pre=True)
+ def parse_content(cls, v):
+ if isinstance(v, list):
+ return v
+ if v is None:
+ return None
+ if not isinstance(v, str):
+ raise ValueError("content must be str")
+ return parse(v)
+
+ @validator("created_at", pre=True)
+ def parse_created_at(cls, v):
+ if v is None:
+ return None
+ if isinstance(v, datetime):
+ return v
+ try:
+ timestamp = int(v)
+ except ValueError:
+ raise ValueError(f"invalid timestamp: {v}")
+ return datetime.fromtimestamp(timestamp / 1000)
+
+ @validator("updated_at", pre=True)
+ def parse_updated_at(cls, v):
+ if v is None:
+ return None
+ if isinstance(v, datetime):
+ return v
+ try:
+ timestamp = int(v)
+ except ValueError:
+ raise ValueError(f"invalid timestamp: {v}")
+ return datetime.fromtimestamp(timestamp / 1000)
+
+
+class OuterMessage(InnerMessage):
+ channel: Channel
+ guild: Guild
+ member: InnerMember
+ user: User
+ created_at: datetime
+ updated_at: datetime
+
+
+class Event(BaseModel):
+ id: int
+ type: str
+ platform: str
+ self_id: str
+ timestamp: datetime
+ channel: Optional[Channel] = None
+ guild: Optional[Guild] = None
+ login: Optional[InnerLogin] = None
+ member: Optional[InnerMember] = None
+ message: Optional[InnerMessage] = None
+ operator: Optional[User] = None
+ role: Optional[Role] = None
+ user: Optional[User] = None
+
+ @validator("timestamp", pre=True)
+ def parse_timestamp(cls, v):
+ if v is None:
+ return None
+ if isinstance(v, datetime):
+ return v
+ try:
+ timestamp = int(v)
+ except ValueError:
+ raise ValueError(f"invalid timestamp: {v}")
+ return datetime.fromtimestamp(timestamp / 1000)
+
+
+class EventPayload(Payload):
+ op: Literal[Opcode.EVENT] = Field(Opcode.EVENT)
+ body: Event
+
+
+PayloadType = Union[
+ Union[EventPayload, PingPayload, PongPayload, IdentifyPayload, ReadyPayload],
+ Payload,
+]
diff --git a/nonebot/adapters/satori/utils.py b/nonebot/adapters/satori/utils.py
new file mode 100644
index 0000000..3d0bfac
--- /dev/null
+++ b/nonebot/adapters/satori/utils.py
@@ -0,0 +1,186 @@
+import re
+from functools import partial
+from typing_extensions import ParamSpec, Concatenate
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Type,
+ Union,
+ Generic,
+ TypeVar,
+ Callable,
+ Optional,
+ Awaitable,
+ overload,
+)
+
+from pydantic import Field, BaseModel
+from nonebot.utils import logger_wrapper
+
+if TYPE_CHECKING:
+ from .bot import Bot
+
+B = TypeVar("B", bound="Bot")
+R = TypeVar("R")
+P = ParamSpec("P")
+log = logger_wrapper("Satori")
+
+
+def escape(text: str) -> str:
+ return (
+ text.replace('"', """)
+ .replace("&", "&")
+ .replace("<", "<")
+ .replace(">", ">")
+ )
+
+
+def unescape(text: str) -> str:
+ return (
+ text.replace(""", '"')
+ .replace("&", "&")
+ .replace("<", "<")
+ .replace(">", ">")
+ )
+
+
+class Element(BaseModel):
+ type: str
+ attrs: Dict[str, Any] = Field(default_factory=dict)
+ children: List["Element"] = Field(default_factory=list)
+ source: Optional[str] = None
+
+ def __str__(self):
+ if self.source:
+ return self.source
+ if self.type == "text":
+ return escape(self.attrs["text"])
+
+ def _attr(key: str, value: Any):
+ if value is True:
+ return key
+ if value is False:
+ return f"no-{key}"
+ if isinstance(value, (int, float)):
+ return f"{key}={value}"
+ return f'{key}="{escape(str(value))}"'
+
+ attrs = " ".join(_attr(k, v) for k, v in self.attrs.items())
+ if not self.children:
+ return f"<{self.type} {attrs}/>"
+ children = "".join(str(c) for c in self.children)
+ return f"<{self.type} {attrs}>{children}{self.type}>"
+
+
+tag_pat = re.compile(r"|<(/?)([^!\s>/]*)([^>]*?)\s*(/?)>")
+attr_pat = re.compile(r"([^\s=]+)(?:=\"([^\"]*)\"|='([^']*)')?", re.S)
+
+
+class Token(BaseModel):
+ type: str
+ close: str
+ empty: str
+ attrs: Dict[str, Any]
+ source: str
+
+
+def parse(src: str):
+ tokens: List[Union[Token, Element]] = []
+
+ def push_text(text: str):
+ if text:
+ tokens.append(Element(type="text", attrs={"text": text}))
+
+ def parse_content(source: str):
+ push_text(unescape(source))
+
+ while tag_map := tag_pat.search(src):
+ parse_content(src[: tag_map.start()])
+ src = src[tag_map.end() :]
+ if tag_map.group(0).startswith("