From 7a7ab984b438e83e814f340c96c1b371b9160100 Mon Sep 17 00:00:00 2001 From: rf_tar_railt <3165388245@qq.com> Date: Thu, 13 Jun 2024 11:53:07 +0800 Subject: [PATCH] :sparkles: version 0.12.0 impl Satori Protocol V1.1 --- nonebot/adapters/satori/adapter.py | 21 ++-- nonebot/adapters/satori/bot.py | 161 +++++++++++++++++++-------- nonebot/adapters/satori/config.py | 4 +- nonebot/adapters/satori/element.py | 51 ++++----- nonebot/adapters/satori/event.py | 20 ++-- nonebot/adapters/satori/exception.py | 4 +- nonebot/adapters/satori/message.py | 17 ++- nonebot/adapters/satori/models.py | 50 +++++++-- nonebot/adapters/satori/utils.py | 11 +- pdm.lock | 36 +++--- pyproject.toml | 14 +-- tests/conftest.py | 6 +- tests/fake_server.py | 4 +- tests/test_adapter.py | 10 +- tests/test_message.py | 13 +-- 15 files changed, 257 insertions(+), 165 deletions(-) diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py index 183ffc8..81dcf11 100644 --- a/nonebot/adapters/satori/adapter.py +++ b/nonebot/adapters/satori/adapter.py @@ -1,7 +1,7 @@ import json import asyncio from typing_extensions import override -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from nonebot.utils import escape_tag from nonebot.exception import WebSocketClosed @@ -40,15 +40,15 @@ class Adapter(BaseAdapter): - bots: Dict[str, Bot] + bots: dict[str, Bot] @override def __init__(self, driver: Driver, **kwargs: Any): super().__init__(driver, **kwargs) # 读取适配器所需的配置项 self.satori_config: Config = get_plugin_config(Config) - self.tasks: List[asyncio.Task] = [] # 存储 ws 任务 - self.sequences: Dict[str, int] = {} # 存储 连接序列号 + self.tasks: list[asyncio.Task] = [] # 存储 ws 任务 + self.sequences: dict[str, int] = {} # 存储 连接序列号 self.setup() @classmethod @@ -136,7 +136,7 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter if login.status != LoginStatus.ONLINE: continue if login.self_id not in self.bots: - bot = Bot(self, login.self_id, login.platform or "satori", info) + bot = Bot(self, login.self_id, login, info) self.bot_connect(bot) log( "INFO", @@ -144,8 +144,7 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter ) else: bot = self.bots[login.self_id] - if login.user: - bot.on_ready(login.user) + bot._update(login) if not self.bots: log("WARNING", "No bots connected!") return @@ -232,9 +231,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): ) else: if isinstance(event, LoginAddedEvent): - bot = Bot(self, event.self_id, event.platform, info) - if event.user: - bot.on_ready(event.user) + bot = Bot(self, event.self_id, event.login, info) self.bot_connect(bot) log( "INFO", @@ -247,8 +244,8 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): f"Bot {escape_tag(event.self_id)} disconnected", ) continue - elif isinstance(event, LoginUpdatedEvent) and event.user: - self.bots[event.self_id].on_ready(event.user) + elif isinstance(event, LoginUpdatedEvent): + self.bots[event.self_id]._update(event.login) if not (bot := self.bots.get(event.self_id)): log( "WARNING", diff --git a/nonebot/adapters/satori/bot.py b/nonebot/adapters/satori/bot.py index 1113073..73c8dbf 100644 --- a/nonebot/adapters/satori/bot.py +++ b/nonebot/adapters/satori/bot.py @@ -1,7 +1,7 @@ import re import json from typing_extensions import override -from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional +from typing import TYPE_CHECKING, Any, Union, Literal, Optional, overload from nonebot.message import handle_event from nonebot.drivers import Request, Response @@ -12,10 +12,11 @@ from .element import parse from .utils import API, log from .config import ClientInfo +from .models import PageDequeResult from .event import Event, MessageEvent from .models import MessageObject as SatoriMessage from .message import Author, Message, RenderMessage, MessageSegment -from .models import Role, User, Guild, Login, Member, Channel, PageResult +from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, PageResult from .exception import ( ActionFailed, NetworkError, @@ -66,11 +67,7 @@ async def _check_reply( event.to_me = True else: return - if ( - len(message) > index - and message[index].type == "at" - and message[index].data.get("id") == str(bot.get_self_id()) - ): + if len(message) > index and message[index].type == "at" and message[index].data.get("id") == str(bot.get_self_id()): event.to_me = True del message[index] if len(message) > index and message[index].type == "text": @@ -149,22 +146,22 @@ class Bot(BaseBot): adapter: "Adapter" @override - def __init__(self, adapter: "Adapter", self_id: str, platform: str, info: ClientInfo): + def __init__(self, adapter: "Adapter", self_id: str, login: Login, info: ClientInfo): super().__init__(adapter, self_id) # Bot 配置信息 self.info: ClientInfo = info # Bot 自身所属平台 - self.platform: str = platform + self.platform: str = login.platform or "satori" # Bot 自身信息 - self._self_info: Optional[User] = None + self._self_info = login def __getattr__(self, item): raise AttributeError(f"'Bot' object has no attribute '{item}'") def get_self_id(self): - if self._self_info: - return self._self_info.id + if self._self_info and self._self_info.user: + return self._self_info.user.id return self.self_id @property @@ -175,14 +172,14 @@ def ready(self) -> bool: @property def self_info(self) -> User: """Bot 自身信息,仅当 Bot 连接鉴权完成后可用""" - if self._self_info is None: + if self._self_info.user is None: raise RuntimeError(f"Bot {self.self_id} of {self.platform} is not connected!") - return self._self_info + return self._self_info.user - def on_ready(self, user: User) -> None: - self._self_info = user + def _update(self, login: Login) -> None: + self._self_info = login - def get_authorization_header(self) -> Dict[str, str]: + def get_authorization_header(self) -> dict[str, str]: """获取当前 Bot 的鉴权信息""" header = { "Authorization": f"Bearer {self.info.token}", @@ -200,9 +197,17 @@ async def handle_event(self, event: Event) -> None: _check_nickname(self, event) await handle_event(self, event) - def _handle_response(self, response: Response) -> Any: + @overload + def _handle_response(self, response: Response) -> dict: ... + + @overload + def _handle_response(self, response: Response, noreturn: Literal[True]) -> None: ... + + def _handle_response(self, response: Response, noreturn=False) -> Any: if 200 <= response.status_code < 300: - return response.content and json.loads(response.content) + if not noreturn: + return + return json.loads(response.content) if response.content else {} elif response.status_code == 400: raise BadRequestException(response) elif response.status_code == 401: @@ -228,57 +233,71 @@ async def _request(self, request: Request) -> Any: return self._handle_response(response) + async def download(self, url: str) -> bytes: + """访问内部链接。""" + request = Request("GET", self.info.api_base / "proxy" / url.lstrip("/")) + try: + response = await self.adapter.request(request) + except Exception as e: + raise NetworkError("API request failed") from e + + self._handle_response(response, noreturn=True) + return response.content # type: ignore + @override async def send( self, event: Event, message: Union[str, Message, MessageSegment], **kwargs, - ) -> List[SatoriMessage]: + ) -> list[SatoriMessage]: if not event.channel: raise RuntimeError("Event cannot be replied to!") return await self.send_message(event.channel.id, message) async def send_message( self, - channel_id: str, + channel: Union[str, Channel], message: Union[str, Message, MessageSegment], - ) -> List[SatoriMessage]: + ) -> list[SatoriMessage]: """发送消息 参数: - channel_id: 要发送的频道 ID - message: 要发送的消息 + channel (str | Channel): 要发送的频道 ID + message (str | Message | MessageSegment): 要发送的消息 """ + channel_id = channel.id if isinstance(channel, Channel) else channel return await self.message_create(channel_id=channel_id, content=str(message)) async def send_private_message( self, - user_id: str, + user: Union[str, User], message: Union[str, Message, MessageSegment], - ) -> List[SatoriMessage]: + ) -> list[SatoriMessage]: """发送私聊消息 参数: - user_id: 要发送的用户 ID - message: 要发送的消息 + user (str | User): 要发送的用户 ID + message (str | Message | MessageSegment): 要发送的消息 """ + user_id = user.id if isinstance(user, User) else user channel = await self.user_channel_create(user_id=user_id) return await self.message_create(channel_id=channel.id, content=str(message)) async def update_message( self, - channel_id: str, + channel: Union[str, Channel], message_id: str, message: Union[str, Message, MessageSegment], ): """更新消息 参数: - channel_id: 要更新的频道 ID - message_id: 要更新的消息 ID - message: 要更新的消息 + channel (str | Channel): 要更新的频道 ID + message_id (str): 要更新的消息 ID + message (str | Message | MessageSegment): 更新后的消息 """ + channel_id = channel.id if isinstance(channel, Channel) else channel await self.message_update(channel_id=channel_id, message_id=message_id, content=str(message)) @API @@ -287,7 +306,7 @@ async def message_create( *, channel_id: str, content: str, - ) -> List[SatoriMessage]: + ) -> list[SatoriMessage]: request = Request( "POST", self.info.api_base / "message.create", @@ -336,14 +355,28 @@ async def message_update( @API async def message_list( - self, *, channel_id: str, next_token: Optional[str] = None - ) -> PageResult[SatoriMessage]: + self, + *, + channel_id: str, + next_token: Optional[str] = None, + direction: Direction = "before", + limit: int = 50, + order: Order = "asc", + ) -> PageDequeResult[SatoriMessage]: + if not next_token and direction != "before": + raise ValueError("Invalid direction") request = Request( "POST", self.info.api_base / "message.list", - json={"channel_id": channel_id, "next": next_token}, + json={ + "channel_id": channel_id, + "next": next_token, + "direction": direction, + "limit": limit, + "order": order, + }, ) - return type_validate_python(PageResult[SatoriMessage], await self._request(request)) + return type_validate_python(PageDequeResult[SatoriMessage], await self._request(request)) @API async def channel_get(self, *, channel_id: str) -> Channel: @@ -445,9 +478,7 @@ async def guild_approve(self, *, request_id: str, approve: bool, comment: str) - await self._request(request) @API - async def guild_member_list( - self, *, guild_id: str, next_token: Optional[str] = None - ) -> PageResult[Member]: + async def guild_member_list(self, *, guild_id: str, next_token: Optional[str] = None) -> PageResult[Member]: request = Request( "POST", self.info.api_base / "guild.member.list", @@ -665,15 +696,53 @@ async def friend_approve(self, *, request_id: str, approve: bool, comment: str) await self._request(request) @API - async def internal( - self, - *, - action: str, - **kwargs, - ) -> Any: + async def internal(self, *, action: str, **kwargs) -> Any: + """内部接口调用。 + + 参数: + action (str): 内部接口名称 + **kwargs: 参数 + """ request = Request( "POST", self.info.api_base / "internal" / action, json=kwargs, ) return await self._request(request) + + @API + async def admin_login_list(self) -> list[Login]: + request = Request( + "POST", + self.info.api_base / "admin" / "login.list", + ) + res = await self._request(request) + return [type_validate_python(Login, i) for i in res] + + @overload + async def upload(self, *uploads: Upload) -> list[str]: ... + + @overload + async def upload(self, **uploads: Upload) -> dict[str, str]: ... + + async def upload(self, *args: Upload, **kwargs: Upload): + """上传文件。 + + 如果要发送的消息中含有图片或其他媒体资源,\ + 可以使用此 API 将文件上传至 Satori 服务器并转换为 URL,以便在消息编码中使用。 + """ + if args and kwargs: + raise RuntimeError("upload can't accept both args and kwargs") + if args: + ids = [] + for upload in args: + ids.append(str(id(upload))) + + resp = await self.upload_create(**dict(zip(ids, args))) + return list(resp.values()) + return await self.upload_create(**kwargs) + + @API + async def upload_create(self, **kwargs: Upload) -> dict[str, str]: + request = Request("POST", self.info.api_base / "upload.create", files={k: v.dump() for k, v in kwargs.items()}) + return await self._request(request) diff --git a/nonebot/adapters/satori/config.py b/nonebot/adapters/satori/config.py index 9eeda2f..322c7fa 100644 --- a/nonebot/adapters/satori/config.py +++ b/nonebot/adapters/satori/config.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from yarl import URL from pydantic import Field, BaseModel @@ -28,5 +28,5 @@ def ws_base(self): class Config(BaseModel): - satori_clients: List[ClientInfo] = Field(default_factory=list) + satori_clients: list[ClientInfo] = Field(default_factory=list) """client 配置""" diff --git a/nonebot/adapters/satori/element.py b/nonebot/adapters/satori/element.py index 6c769f6..3e6e4e7 100644 --- a/nonebot/adapters/satori/element.py +++ b/nonebot/adapters/satori/element.py @@ -1,8 +1,9 @@ import re from enum import IntEnum +from collections.abc import Iterable from typing_extensions import TypeAlias from dataclasses import field, dataclass -from typing import Any, Dict, List, Union, Literal, TypeVar, Callable, Iterable, Optional, TypedDict, cast +from typing import Any, Union, Literal, TypeVar, Callable, Optional, TypedDict, cast T = TypeVar("T") @@ -28,24 +29,20 @@ def camel_case(source: str) -> str: def param_case(source: str) -> str: - return re.sub( - ".[A-Z]+", lambda mat: mat[0][0] + "-" + mat[0][1:].lower(), uncapitalize(source).replace("_", "-") - ) + return re.sub(".[A-Z]+", lambda mat: mat[0][0] + "-" + mat[0][1:].lower(), uncapitalize(source).replace("_", "-")) def snake_case(source: str) -> str: - return re.sub( - ".[A-Z]", lambda mat: mat[0][0] + "_" + mat[0][1:].lower(), uncapitalize(source).replace("-", "_") - ) + return re.sub(".[A-Z]", lambda mat: mat[0][0] + "_" + mat[0][1:].lower(), uncapitalize(source).replace("-", "_")) -def ensure_list(value: Union[T, List[T], None]) -> List[T]: +def ensure_list(value: Union[T, list[T], None]) -> list[T]: return value if isinstance(value, list) else [value] if value else [] S = TypeVar("S") -Fragment: TypeAlias = Union[str, "Element", List[Union[str, "Element"]]] -Render: TypeAlias = Callable[[dict, List["Element"], S], T] +Fragment: TypeAlias = Union[str, "Element", list[Union[str, "Element"]]] +Render: TypeAlias = Callable[[dict, list["Element"], S], T] Visitor: TypeAlias = Callable[["Element", S], T] @@ -60,7 +57,7 @@ def make_element(content: Union[str, bool, int, float, "Element"]) -> Optional[" raise ValueError(f"Invalid content: {content!r}") -def make_elements(content: Fragment) -> List["Element"]: +def make_elements(content: Fragment) -> list["Element"]: if isinstance(content, list): res = [make_element(c) for c in content] else: @@ -70,14 +67,14 @@ def make_elements(content: Fragment) -> List["Element"]: class Element: type: str - attrs: Dict[str, Any] - children: List["Element"] + attrs: dict[str, Any] + children: list["Element"] source: Optional[str] = None def __init__( self, type: Union[str, Render[Fragment, Any]], - attrs: Optional[Dict[str, Any]] = None, + attrs: Optional[dict[str, Any]] = None, *children: Fragment, ) -> None: self.attrs = {} @@ -149,8 +146,8 @@ class Selector: comb_pat = re.compile(" *([ >+~]) *") -def parse_selector(input: str) -> List[List[Selector]]: - def _quert(query: str) -> List[Selector]: +def parse_selector(input: str) -> list[list[Selector]]: + def _quert(query: str) -> list[Selector]: selectors = [] combinator = " " while mat := comb_pat.search(query): @@ -168,7 +165,7 @@ def _quert(query: str) -> List[Selector]: return [_quert(q) for q in input.split(",")] -def select(source: Union[str, List[Element]], query: Union[str, List[List[Selector]]]) -> List[Element]: +def select(source: Union[str, list[Element]], query: Union[str, list[list[Selector]]]) -> list[Element]: if not source or not query: return [] if isinstance(source, str): @@ -177,10 +174,10 @@ def select(source: Union[str, List[Element]], query: Union[str, List[List[Select query = parse_selector(query) if not query: return [] - adjacent: List[List[Selector]] = [] + adjacent: list[list[Selector]] = [] results = [] for index, elem in enumerate(source): - inner: List[List[Selector]] = [] + inner: list[list[Selector]] = [] local = [*query, *adjacent] adjacent = [] matched = False @@ -231,9 +228,7 @@ def interpolate(expr: str, context: dict) -> Any: r"(?P)|(?P<(/?)([^!\s>/]*)([^>]*?)\s*(/?)>)|(?P\{(?P[@:/#][^\s\}]*)?[\s\S]*?\})" ) attr_pat1 = re.compile(r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)')?", re.S) -attr_pat2 = re.compile( - r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)'|=\{(?P[^\}]+)\})?", re.S -) +attr_pat2 = re.compile(r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)'|=\{(?P[^\}]+)\})?", re.S) class Position(IntEnum): @@ -250,7 +245,7 @@ class Token: positon: Position source: str extra: str - children: Dict[str, List[Union[str, "Token"]]] = field(default_factory=dict) + children: dict[str, list[Union[str, "Token"]]] = field(default_factory=dict) class StackItem(TypedDict): @@ -258,8 +253,8 @@ class StackItem(TypedDict): slot: str -def fold_tokens(tokens: List[Union[str, Token]]) -> List[Union[str, Token]]: - stack: List[StackItem] = [ +def fold_tokens(tokens: list[Union[str, Token]]) -> list[Union[str, Token]]: + stack: list[StackItem] = [ { "token": Token( type="angle", @@ -296,8 +291,8 @@ def push_token(*tokens: Union[str, Token]): return stack[-1]["token"].children["default"] -def parse_tokens(tokens: List[Union[str, Token]], context: Optional[dict] = None) -> List[Element]: - result: List[Element] = [] +def parse_tokens(tokens: list[Union[str, Token]], context: Optional[dict] = None) -> list[Element]: + result: list[Element] = [] for token in tokens: if isinstance(token, str): result.append(Element(type="text", attrs={"text": token})) @@ -343,7 +338,7 @@ def parse_tokens(tokens: List[Union[str, Token]], context: Optional[dict] = None def parse(src: str, context: Optional[dict] = None): - tokens: List[Union[str, Token]] = [] + tokens: list[Union[str, Token]] = [] def push_text(text: str): if text: diff --git a/nonebot/adapters/satori/event.py b/nonebot/adapters/satori/event.py index 7861678..9af14c7 100644 --- a/nonebot/adapters/satori/event.py +++ b/nonebot/adapters/satori/event.py @@ -1,7 +1,7 @@ from enum import Enum from copy import deepcopy from typing_extensions import override -from typing import TYPE_CHECKING, Dict, Type, TypeVar, Optional +from typing import TYPE_CHECKING, TypeVar, Optional from nonebot.utils import escape_tag from nonebot.compat import model_dump, type_validate_python @@ -84,10 +84,10 @@ def is_tome(self) -> bool: return False -EVENT_CLASSES: Dict[str, Type[Event]] = {} +EVENT_CLASSES: dict[str, type[Event]] = {} -def register_event_class(event_class: Type[E]) -> Type[E]: +def register_event_class(event_class: type[E]) -> type[E]: EVENT_CLASSES[event_class.__type__.value] = event_class return event_class @@ -325,7 +325,7 @@ class PrivateMessageCreatedEvent(MessageCreatedEvent, PrivateMessageEvent): def get_event_description(self) -> str: return escape_tag( f"Message {self.msg_id} from " - f"{self.user.name or ''}({self.channel.id}): {self.get_message()!r}" + f"{self.user.name or self.user.nick or ''}({self.channel.id}): {self.get_message()!r}" ) @@ -334,7 +334,7 @@ class PublicMessageCreatedEvent(MessageCreatedEvent, PublicMessageEvent): def get_event_description(self) -> str: return escape_tag( f"Message {self.msg_id} from " - f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})" + f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})" f"@[{self.channel.name or ''}:{self.channel.id}]" f": {self.get_message()!r}" ) @@ -343,7 +343,9 @@ def get_event_description(self) -> str: class PrivateMessageDeletedEvent(MessageDeletedEvent, PrivateMessageEvent): @override def get_event_description(self) -> str: - return escape_tag(f"Message {self.msg_id} from " f"{self.user.name or ''}({self.channel.id}) deleted") + return escape_tag( + f"Message {self.msg_id} from " f"{self.user.name or self.user.nick or ''}({self.channel.id}) deleted" + ) class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent): @@ -351,7 +353,7 @@ class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent): def get_event_description(self) -> str: return escape_tag( f"Message {self.msg_id} from " - f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})" + f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})" f"@[{self.channel.name or ''}:{self.channel.id}] deleted" ) @@ -361,7 +363,7 @@ class PrivateMessageUpdatedEvent(MessageUpdatedEvent, PrivateMessageEvent): def get_event_description(self) -> str: return escape_tag( f"Message {self.msg_id} from " - f"{self.user.name or ''}({self.channel.id}) updated" + f"{self.user.name or self.user.nick or ''}({self.channel.id}) updated" f": {self.get_message()!r}" ) @@ -371,7 +373,7 @@ class PublicMessageUpdatedEvent(MessageUpdatedEvent, PublicMessageEvent): def get_event_description(self) -> str: return escape_tag( f"Message {self.msg_id} from " - f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})" + f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})" f"@[{self.channel.name or ''}:{self.channel.id}] updated" f": {self.get_message()!r}" ) diff --git a/nonebot/adapters/satori/exception.py b/nonebot/adapters/satori/exception.py index 6b584d6..f261983 100644 --- a/nonebot/adapters/satori/exception.py +++ b/nonebot/adapters/satori/exception.py @@ -19,9 +19,7 @@ def __init__(self, response: Response): self.content = response.content def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__}: {self.status_code}, headers={self.headers}, content={self.content}>" - ) + return f"<{self.__class__.__name__}: {self.status_code}, headers={self.headers}, content={self.content}>" def __str__(self): return self.__repr__() diff --git a/nonebot/adapters/satori/message.py b/nonebot/adapters/satori/message.py index 4a189bd..7cf1eb8 100644 --- a/nonebot/adapters/satori/message.py +++ b/nonebot/adapters/satori/message.py @@ -2,9 +2,10 @@ from io import BytesIO from pathlib import Path from base64 import b64encode +from collections.abc import Iterable from dataclasses import InitVar, field, dataclass +from typing import Any, Union, Optional, TypedDict from typing_extensions import Self, NotRequired, override -from typing import Any, Dict, List, Type, Tuple, Union, Iterable, Optional, TypedDict from nonebot.adapters import Message as BaseMessage from nonebot.adapters import MessageSegment as BaseMessageSegment @@ -47,7 +48,7 @@ def __call__(self, *segments: Union[str, Iterable["MessageSegment"], "MessageSeg @classmethod @override - def get_message_class(cls) -> Type["Message"]: + def get_message_class(cls) -> type["Message"]: return Message @staticmethod @@ -371,7 +372,7 @@ def is_text(self) -> bool: class TextData(TypedDict): text: str - styles: Dict[Tuple[int, int], List[str]] + styles: dict[tuple[int, int], list[str]] @dataclass @@ -728,7 +729,7 @@ class Custom(MessageSegment): } -def handle(element: Element, upper_styles: Optional[List[str]] = None): +def handle(element: Element, upper_styles: Optional[list[str]] = None): tag = element.tag() if tag in ELEMENT_TYPE_MAP: seg_cls, seg_type = ELEMENT_TYPE_MAP[tag] @@ -771,15 +772,13 @@ def handle(element: Element, upper_styles: Optional[List[str]] = None): *(handle(child, [*(upper_styles or [])]) for child in element.children), ) else: - yield Custom(tag, element.attrs.copy())( - *(handle(child, [*(upper_styles or [])]) for child in element.children) - ) + yield Custom(tag, element.attrs.copy())(*(handle(child, [*(upper_styles or [])]) for child in element.children)) class Message(BaseMessage[MessageSegment]): @classmethod @override - def get_segment_class(cls) -> Type[MessageSegment]: + def get_segment_class(cls) -> type[MessageSegment]: return MessageSegment def __init__( @@ -808,7 +807,7 @@ def _construct(msg: str) -> Iterable[MessageSegment]: yield MessageSegment.text(msg) @classmethod - def from_satori_element(cls, elements: List[Element]) -> "Message": + def from_satori_element(cls, elements: list[Element]) -> "Message": msg = Message() for elem in elements: diff --git a/nonebot/adapters/satori/models.py b/nonebot/adapters/satori/models.py index 3280d8f..a265918 100644 --- a/nonebot/adapters/satori/models.py +++ b/nonebot/adapters/satori/models.py @@ -1,6 +1,10 @@ +import mimetypes +from os import PathLike from enum import IntEnum +from pathlib import Path from datetime import datetime -from typing import Any, Dict, List, Union, Generic, Literal, TypeVar, Optional +from typing_extensions import TypeAlias +from typing import IO, Any, Union, Generic, Literal, TypeVar, Optional from pydantic import Field, BaseModel from nonebot.compat import PYDANTIC_V2, ConfigDict @@ -63,7 +67,6 @@ class Config: class Member(BaseModel): user: Optional[User] = None - name: Optional[str] = None nick: Optional[str] = None avatar: Optional[str] = None joined_at: Optional[datetime] = None @@ -115,6 +118,8 @@ class Login(BaseModel): self_id: Optional[str] = None platform: Optional[str] = None status: LoginStatus + features: list[str] = Field(default_factory=list) + proxy_urls: list[str] = Field(default_factory=list) if PYDANTIC_V2: model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore @@ -145,7 +150,7 @@ class Opcode(IntEnum): class Payload(BaseModel): op: Opcode = Field(...) - body: Optional[Dict[str, Any]] = Field(None) + body: Optional[dict[str, Any]] = Field(None) class Identify(BaseModel): @@ -154,7 +159,7 @@ class Identify(BaseModel): class Ready(BaseModel): - logins: List[Login] + logins: list[Login] class IdentifyPayload(Payload): @@ -282,19 +287,50 @@ class EventPayload(Payload): if PYDANTIC_V2: - class PageResult(BaseModel, Generic[T]): # type: ignore - data: List[T] + class PageResult(BaseModel, Generic[T]): + data: list[T] next: Optional[str] = None model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore + class PageDequeResult(PageResult[T]): + prev: Optional[str] = None + else: from pydantic.generics import GenericModel class PageResult(GenericModel, Generic[T]): - data: List[T] + data: list[T] next: Optional[str] = None class Config: extra = "allow" + + class PageDequeResult(PageResult[T]): + prev: Optional[str] = None + + +Direction: TypeAlias = Literal["before", "after", "around"] +Order: TypeAlias = Literal["asc", "desc"] + + +class Upload: + def __init__( + self, file: Union[bytes, IO[bytes], PathLike], mimetype: str = "image/png", name: Optional[str] = None + ): + self.file = file + self.mimetype = mimetype + + if isinstance(self.file, PathLike): + self.mimetype = mimetypes.guess_type(str(self.file))[0] or self.mimetype + self.name = Path(self.file).name + else: + self.name = name + + def dump(self): + file = self.file + + if isinstance(file, PathLike): + file = open(file, "rb") + return self.name, file, self.mimetype diff --git a/nonebot/adapters/satori/utils.py b/nonebot/adapters/satori/utils.py index 078136e..fb56457 100644 --- a/nonebot/adapters/satori/utils.py +++ b/nonebot/adapters/satori/utils.py @@ -1,6 +1,7 @@ from functools import partial +from collections.abc import Awaitable from typing_extensions import ParamSpec, Concatenate -from typing import TYPE_CHECKING, Type, Generic, TypeVar, Callable, Optional, Awaitable, overload +from typing import TYPE_CHECKING, Generic, TypeVar, Callable, Optional, overload from nonebot.utils import logger_wrapper @@ -17,17 +18,17 @@ class API(Generic[B, P, R]): def __init__(self, func: Callable[Concatenate[B, P], Awaitable[R]]) -> None: self.func = func - def __set_name__(self, owner: Type[B], name: str) -> None: + def __set_name__(self, owner: type[B], name: str) -> None: self.name = name @overload - def __get__(self, obj: None, objtype: Type[B]) -> "API[B, P, R]": ... + def __get__(self, obj: None, objtype: type[B]) -> "API[B, P, R]": ... @overload - def __get__(self, obj: B, objtype: Optional[Type[B]]) -> Callable[P, Awaitable[R]]: ... + def __get__(self, obj: B, objtype: Optional[type[B]]) -> Callable[P, Awaitable[R]]: ... def __get__( - self, obj: Optional[B], objtype: Optional[Type[B]] = None + self, obj: Optional[B], objtype: Optional[type[B]] = None ) -> "API[B, P, R] | Callable[P, Awaitable[R]]": if obj is None: return self diff --git a/pdm.lock b/pdm.lock index c0d765b..025511a 100644 --- a/pdm.lock +++ b/pdm.lock @@ -914,27 +914,27 @@ files = [ [[package]] name = "ruff" -version = "0.4.7" +version = "0.4.8" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." files = [ - {file = "ruff-0.4.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e089371c67892a73b6bb1525608e89a2aca1b77b5440acf7a71dda5dac958f9e"}, - {file = "ruff-0.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:10f973d521d910e5f9c72ab27e409e839089f955be8a4c8826601a6323a89753"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c3d110970001dfa494bcd95478e62286c751126dfb15c3c46e7915fc49694f"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa9773c6c00f4958f73b317bc0fd125295110c3776089f6ef318f4b775f0abe4"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07fc80bbb61e42b3b23b10fda6a2a0f5a067f810180a3760c5ef1b456c21b9db"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fa4dafe3fe66d90e2e2b63fa1591dd6e3f090ca2128daa0be33db894e6c18648"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7c0083febdec17571455903b184a10026603a1de078428ba155e7ce9358c5f6"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad1b20e66a44057c326168437d680a2166c177c939346b19c0d6b08a62a37589"}, - {file = "ruff-0.4.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf5d818553add7511c38b05532d94a407f499d1a76ebb0cad0374e32bc67202"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:50e9651578b629baec3d1513b2534de0ac7ed7753e1382272b8d609997e27e83"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8874a9df7766cb956b218a0a239e0a5d23d9e843e4da1e113ae1d27ee420877a"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9de9a6e49f7d529decd09381c0860c3f82fa0b0ea00ea78409b785d2308a567"}, - {file = "ruff-0.4.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:13a1768b0691619822ae6d446132dbdfd568b700ecd3652b20d4e8bc1e498f78"}, - {file = "ruff-0.4.7-py3-none-win32.whl", hash = "sha256:769e5a51df61e07e887b81e6f039e7ed3573316ab7dd9f635c5afaa310e4030e"}, - {file = "ruff-0.4.7-py3-none-win_amd64.whl", hash = "sha256:9e3ab684ad403a9ed1226894c32c3ab9c2e0718440f6f50c7c5829932bc9e054"}, - {file = "ruff-0.4.7-py3-none-win_arm64.whl", hash = "sha256:10f2204b9a613988e3484194c2c9e96a22079206b22b787605c255f130db5ed7"}, - {file = "ruff-0.4.7.tar.gz", hash = "sha256:2331d2b051dc77a289a653fcc6a42cce357087c5975738157cd966590b18b5e1"}, + {file = "ruff-0.4.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7663a6d78f6adb0eab270fa9cf1ff2d28618ca3a652b60f2a234d92b9ec89066"}, + {file = "ruff-0.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eeceb78da8afb6de0ddada93112869852d04f1cd0f6b80fe464fd4e35c330913"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aad360893e92486662ef3be0a339c5ca3c1b109e0134fcd37d534d4be9fb8de3"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:284c2e3f3396fb05f5f803c9fffb53ebbe09a3ebe7dda2929ed8d73ded736deb"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7354f921e3fbe04d2a62d46707e569f9315e1a613307f7311a935743c51a764"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:72584676164e15a68a15778fd1b17c28a519e7a0622161eb2debdcdabdc71883"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9678d5c9b43315f323af2233a04d747409d1e3aa6789620083a82d1066a35199"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704977a658131651a22b5ebeb28b717ef42ac6ee3b11e91dc87b633b5d83142b"}, + {file = "ruff-0.4.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d05f8d6f0c3cce5026cecd83b7a143dcad503045857bc49662f736437380ad45"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6ea874950daca5697309d976c9afba830d3bf0ed66887481d6bca1673fc5b66a"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fc95aac2943ddf360376be9aa3107c8cf9640083940a8c5bd824be692d2216dc"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:384154a1c3f4bf537bac69f33720957ee49ac8d484bfc91720cc94172026ceed"}, + {file = "ruff-0.4.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9d5ce97cacc99878aa0d084c626a15cd21e6b3d53fd6f9112b7fc485918e1fa"}, + {file = "ruff-0.4.8-py3-none-win32.whl", hash = "sha256:6d795d7639212c2dfd01991259460101c22aabf420d9b943f153ab9d9706e6a9"}, + {file = "ruff-0.4.8-py3-none-win_amd64.whl", hash = "sha256:e14a3a095d07560a9d6769a72f781d73259655919d9b396c650fc98a8157555d"}, + {file = "ruff-0.4.8-py3-none-win_arm64.whl", hash = "sha256:14019a06dbe29b608f6b7cbcec300e3170a8d86efaddb7b23405cb7f7dcaf780"}, + {file = "ruff-0.4.8.tar.gz", hash = "sha256:16d717b1d57b2e2fd68bd0bf80fb43931b79d05a7131aa477d66fc40fbd86268"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 284de19..dc55627 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [project] name = "nonebot-adapter-satori" -version = "0.11.5" +version = "0.12.0" description = "Satori Protocol Adapter for Nonebot2" authors = [ {name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"}, ] dependencies = [ - "nonebot2>=2.2.0", + "nonebot2>=2.3.0", ] requires-python = ">=3.9" readme = "README.md" @@ -27,7 +27,7 @@ dev = [ "black>=24.1.1", "ruff>=0.2.1", "pre-commit>=3.5.0", - "nonebot2[httpx,websockets]>=2.2.0", + "nonebot2[httpx,websockets]>=2.3.0", ] test = [ "nonebug>=0.3.5", @@ -46,14 +46,14 @@ format = { composite = ["isort ./","black ./","ruff check ./"] } [tool.black] -line-length = 110 +line-length = 120 include = '\.pyi?$' extend-exclude = ''' ''' [tool.isort] profile = "black" -line_length = 110 +line_length = 120 length_sort = true skip_gitignore = true force_sort_within_sections = true @@ -61,8 +61,8 @@ extra_standard_library = ["typing_extensions"] [tool.ruff] -line-length = 110 -target-version = "py38" +line-length = 120 +target-version = "py39" [tool.ruff.lint] select = ["E", "W", "F", "UP", "C", "T", "Q"] diff --git a/tests/conftest.py b/tests/conftest.py index d7d959f..aac483b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import threading from pathlib import Path -from typing import Generator +from collections.abc import Generator import pytest from nonebot.drivers import URL @@ -11,9 +11,7 @@ import nonebot import nonebot.adapters -nonebot.adapters.__path__.append( # type: ignore - str((Path(__file__).parent.parent / "nonebot" / "adapters").resolve()) -) +nonebot.adapters.__path__.append(str((Path(__file__).parent.parent / "nonebot" / "adapters").resolve())) # type: ignore from nonebot.adapters.satori import Adapter as SatoriAdapter diff --git a/tests/fake_server.py b/tests/fake_server.py index 04b2371..cf16a1f 100644 --- a/tests/fake_server.py +++ b/tests/fake_server.py @@ -2,7 +2,7 @@ import base64 import socket from queue import Empty, Queue -from typing import Dict, List, Union, TypeVar, Callable +from typing import Union, TypeVar, Callable from wsproto.events import Ping from werkzeug import Request, Response @@ -32,7 +32,7 @@ def json_safe(string, content_type="application/octet-stream") -> str: ).decode("utf-8") -def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]: +def flattern(d: "MultiDict[K, V]") -> dict[K, Union[V, list[V]]]: return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 97527ac..8873428 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -6,6 +6,7 @@ import nonebot from nonebot.adapters.satori import Bot, Adapter +from nonebot.adapters.satori.models import Login, LoginStatus from nonebot.adapters.satori.event import PublicMessageCreatedEvent @@ -16,14 +17,13 @@ async def test_adapter(app: App): @cmd.handle() async def handle(bot: Bot): - await bot.send_message( - channel_id="67890", - message="hello", - ) + await bot.send_message(channel="67890", message="hello") async with app.test_matcher(cmd) as ctx: adapter: Adapter = nonebot.get_adapter(Adapter) - bot: Bot = ctx.create_bot(base=Bot, adapter=adapter, self_id="0", platform="test", info=None) + bot: Bot = ctx.create_bot( + base=Bot, adapter=adapter, self_id="0", login=Login(status=LoginStatus.CONNECT), info=None + ) ctx.receive_event( bot, diff --git a/tests/test_message.py b/tests/test_message.py index 0322a9a..c02622e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -14,12 +14,11 @@ def test_message(): assert Message.from_satori_element(parse(code))[0].data["chronocat:seq"] == "1" assert Message("1234")[0].data["test:aaa"] is True assert ( - Message.from_satori_element(parse(""))[0].data["test:bbb"] - == "foo" + Message.from_satori_element(parse(""))[0].data["test:bbb"] == "foo" ) - assert Message.from_satori_element( - parse("") - )[0].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat"} + assert Message.from_satori_element(parse(""))[ + 0 + ].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat"} test_message1 = MessageSegment(type="chronocat:face", data={"id": 12}) + "\n" + "Hello Yoshi" assert str(test_message1) == '\nHello Yoshi' @@ -59,8 +58,6 @@ def test_message_fallback(): """ msg = Message.from_satori_element(parse(code)) - assert ( - str(msg[0].children) == '当前平台不支持发送视频,请在这里观看视频!' - ) + assert str(msg[0].children) == '当前平台不支持发送视频,请在这里观看视频!' assert msg.extract_plain_text() == "当前平台不支持发送视频,请在这里观看视频!" assert list(msg.query("link"))[0].data["text"] == "http://aa.com/a.mp4"