diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py index 125486b..a03ce3c 100644 --- a/nonebot/adapters/satori/adapter.py +++ b/nonebot/adapters/satori/adapter.py @@ -104,7 +104,7 @@ def payload_to_json(payload: Payload) -> str: async def receive_payload(self, info: ClientInfo, ws: WebSocket) -> Payload: payload = type_validate_python(PayloadType, json.loads(await ws.receive())) if isinstance(payload, EventPayload): - self.sequences[info.identity] = payload.body.id + self.sequences[info.identity] = payload.body["id"] return payload async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Literal[True]]: @@ -136,12 +136,12 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter ) return for login in resp.body.logins: - if not login.self_id: + if not login.id: continue if login.status != LoginStatus.ONLINE: continue - if login.self_id not in self.bots: - bot = Bot(self, login.self_id, login, info) + if login.id not in self.bots: + bot = Bot(self, login.id, login, info) self._bots[info.identity].add(bot.self_id) self.bot_connect(bot) log( @@ -149,8 +149,8 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter f"Bot {escape_tag(bot.self_id)} connected", ) else: - self._bots[info.identity].add(login.self_id) - bot = self.bots[login.self_id] + self._bots[info.identity].add(login.id) + bot = self.bots[login.id] bot._update(login) if not self.bots: log("WARNING", "No bots connected!") @@ -230,7 +230,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): ) if isinstance(payload, EventPayload): try: - event = self.payload_to_event(payload.body) + event = self.payload_to_event(type_validate_python(Event, payload.body)) except Exception as e: log( "WARNING", @@ -265,7 +265,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): continue if isinstance(event, (MessageEvent, InteractionEvent)): event = event.convert() - asyncio.create_task(bot.handle_event(event)) + _t = asyncio.create_task(bot.handle_event(event)) elif isinstance(payload, PongPayload): log("TRACE", "Pong") continue diff --git a/nonebot/adapters/satori/bot.py b/nonebot/adapters/satori/bot.py index bcb9012..2179742 100644 --- a/nonebot/adapters/satori/bot.py +++ b/nonebot/adapters/satori/bot.py @@ -3,6 +3,7 @@ from typing_extensions import override from typing import TYPE_CHECKING, Any, Union, Literal, Optional, overload +from yarl import URL from nonebot.message import handle_event from nonebot.drivers import Request, Response from nonebot.compat import model_dump, type_validate_python @@ -16,7 +17,7 @@ from .event import Event, MessageEvent from .models import MessageObject as SatoriMessage from .message import Author, Message, RenderMessage, MessageSegment -from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, PageResult +from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, LoginType, PageResult from .exception import ( ActionFailed, NetworkError, @@ -146,7 +147,7 @@ class Bot(BaseBot): adapter: "Adapter" @override - def __init__(self, adapter: "Adapter", self_id: str, login: Login, info: ClientInfo): + def __init__(self, adapter: "Adapter", self_id: str, login: LoginType, info: ClientInfo): super().__init__(adapter, self_id) # Bot 配置信息 @@ -181,7 +182,7 @@ def self_info(self) -> User: raise RuntimeError(f"Bot {self.self_id} of {self.platform} is not connected!") return self._self_info.user - def _update(self, login: Login) -> None: + def _update(self, login: LoginType) -> None: self._self_info = login def get_authorization_header(self) -> dict[str, str]: @@ -190,6 +191,8 @@ def get_authorization_header(self) -> dict[str, str]: "Authorization": f"Bearer {self.info.token}", "X-Self-ID": self.self_id, "X-Platform": self.platform, + "Satori-Platform": self.platform, + "Satori-Login-ID": self.self_id, } if not self.info.token: del header["Authorization"] @@ -238,9 +241,23 @@ async def _request(self, request: Request) -> Any: return self._handle_response(response) + def ensure_url(self, url: str) -> URL: + """确定链接形式。 + + 若链接符合以下条件之一,则返回链接的代理形式 ({host}/{path}/{version}/proxy/{url}): + - 链接以 "upload://" 开头 + - 链接开头出现在 self_info.proxy_urls 中的某一项 + """ + if url.startswith("upload"): + return self.info.api_base / "proxy" / url.lstrip("/") + for proxy_url in self._self_info.proxy_urls: + if url.startswith(proxy_url): + return self.info.api_base / "proxy" / url.lstrip("/") + return URL(url) + async def download(self, url: str) -> bytes: """访问内部链接。""" - request = Request("GET", self.info.api_base / "proxy" / url.lstrip("/")) + request = Request("GET", self.ensure_url(url)) try: response = await self.adapter.request(request) except Exception as e: diff --git a/nonebot/adapters/satori/models.py b/nonebot/adapters/satori/models.py index 5470092..d8fb93c 100644 --- a/nonebot/adapters/satori/models.py +++ b/nonebot/adapters/satori/models.py @@ -121,6 +121,30 @@ class Login(BaseModel): features: list[str] = Field(default_factory=list) proxy_urls: list[str] = Field(default_factory=list) + @property + def id(self) -> Optional[str]: + return self.self_id or (self.user.id if self.user else None) + + if PYDANTIC_V2: + model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore + + else: + + class Config: + extra = "allow" + + +class LoginPreview(BaseModel): + user: User + platform: str + status: Optional[LoginStatus] = None + features: list[str] = Field(default_factory=list) + proxy_urls: list[str] = Field(default_factory=list) + + @property + def id(self) -> str: + return self.user.id + if PYDANTIC_V2: model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore @@ -130,6 +154,9 @@ class Config: extra = "allow" +LoginType = Union[Login, LoginPreview] + + class ArgvInteraction(BaseModel): name: str arguments: list @@ -159,7 +186,7 @@ class Identify(BaseModel): class Ready(BaseModel): - logins: list[Login] + logins: list[LoginType] class IdentifyPayload(Payload): @@ -243,7 +270,7 @@ class Event(BaseModel): button: Optional[ButtonInteraction] = None channel: Optional[Channel] = None guild: Optional[Guild] = None - login: Optional[Login] = None + login: Optional[LoginType] = None member: Optional[Member] = None message: Optional[MessageObject] = None operator: Optional[User] = None @@ -262,6 +289,25 @@ def parse_timestamp(cls, v): raise ValueError(f"invalid timestamp: {v}") from e return datetime.fromtimestamp(timestamp / 1000) + @model_validator(mode="before") + def ensure_login(cls, values): + if "self_id" not in values and "platform" not in values: + log( + "WARNING", + "received event without `self_id` and `platform`, " + "this may be caused by Satori Server used protocol version 1.2.", + ) + if "login" in values: + values["self_id"] = values["login"]["user"]["id"] + values["platform"] = values["login"]["platform"] + return values + log( + "WARNING", + "received event without login, " "this may be caused by a bug of Satori Server.", + ) + return values + return values + if PYDANTIC_V2: model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore @@ -273,7 +319,7 @@ class Config: class EventPayload(Payload): op: Literal[Opcode.EVENT] = Field(Opcode.EVENT) - body: Event + body: dict PayloadType = Union[ diff --git a/pyproject.toml b/pyproject.toml index fe165e7..a179e44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nonebot-adapter-satori" -version = "0.12.6" +version = "0.13.0rc1" description = "Satori Protocol Adapter for Nonebot2" authors = [ {name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"}, diff --git a/tests/test_connection.py b/tests/test_connection.py index 316d2d6..e0292c8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,9 +12,6 @@ async def test_ws(app: App): adapter: Adapter = nonebot.get_adapter(Adapter) - for client in adapter.satori_config.satori_clients: - adapter.tasks.append(asyncio.create_task(adapter.ws(client))) - @ws_handlers.put def identify(json: dict) -> dict: assert json["op"] == 3 @@ -36,6 +33,14 @@ def identify(json: dict) -> dict: }, } + @ws_handlers.put + def _ping(json: dict) -> dict: + assert json == {"op": 1} + return {"op": 2} + + for client in adapter.satori_config.satori_clients: + adapter.tasks.append(asyncio.create_task(adapter.ws(client))) + await asyncio.sleep(5) bots = nonebot.get_bots() assert "0" in bots