From 7a7ab984b438e83e814f340c96c1b371b9160100 Mon Sep 17 00:00:00 2001
From: rf_tar_railt <3165388245@qq.com>
Date: Thu, 13 Jun 2024 11:53:07 +0800
Subject: [PATCH] :sparkles: version 0.12.0 impl Satori Protocol V1.1
---
nonebot/adapters/satori/adapter.py | 21 ++--
nonebot/adapters/satori/bot.py | 161 +++++++++++++++++++--------
nonebot/adapters/satori/config.py | 4 +-
nonebot/adapters/satori/element.py | 51 ++++-----
nonebot/adapters/satori/event.py | 20 ++--
nonebot/adapters/satori/exception.py | 4 +-
nonebot/adapters/satori/message.py | 17 ++-
nonebot/adapters/satori/models.py | 50 +++++++--
nonebot/adapters/satori/utils.py | 11 +-
pdm.lock | 36 +++---
pyproject.toml | 14 +--
tests/conftest.py | 6 +-
tests/fake_server.py | 4 +-
tests/test_adapter.py | 10 +-
tests/test_message.py | 13 +--
15 files changed, 257 insertions(+), 165 deletions(-)
diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py
index 183ffc8..81dcf11 100644
--- a/nonebot/adapters/satori/adapter.py
+++ b/nonebot/adapters/satori/adapter.py
@@ -1,7 +1,7 @@
import json
import asyncio
from typing_extensions import override
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal, Optional
from nonebot.utils import escape_tag
from nonebot.exception import WebSocketClosed
@@ -40,15 +40,15 @@
class Adapter(BaseAdapter):
- bots: Dict[str, Bot]
+ bots: dict[str, Bot]
@override
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
# 读取适配器所需的配置项
self.satori_config: Config = get_plugin_config(Config)
- self.tasks: List[asyncio.Task] = [] # 存储 ws 任务
- self.sequences: Dict[str, int] = {} # 存储 连接序列号
+ self.tasks: list[asyncio.Task] = [] # 存储 ws 任务
+ self.sequences: dict[str, int] = {} # 存储 连接序列号
self.setup()
@classmethod
@@ -136,7 +136,7 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter
if login.status != LoginStatus.ONLINE:
continue
if login.self_id not in self.bots:
- bot = Bot(self, login.self_id, login.platform or "satori", info)
+ bot = Bot(self, login.self_id, login, info)
self.bot_connect(bot)
log(
"INFO",
@@ -144,8 +144,7 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter
)
else:
bot = self.bots[login.self_id]
- if login.user:
- bot.on_ready(login.user)
+ bot._update(login)
if not self.bots:
log("WARNING", "No bots connected!")
return
@@ -232,9 +231,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
)
else:
if isinstance(event, LoginAddedEvent):
- bot = Bot(self, event.self_id, event.platform, info)
- if event.user:
- bot.on_ready(event.user)
+ bot = Bot(self, event.self_id, event.login, info)
self.bot_connect(bot)
log(
"INFO",
@@ -247,8 +244,8 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
f"Bot {escape_tag(event.self_id)} disconnected",
)
continue
- elif isinstance(event, LoginUpdatedEvent) and event.user:
- self.bots[event.self_id].on_ready(event.user)
+ elif isinstance(event, LoginUpdatedEvent):
+ self.bots[event.self_id]._update(event.login)
if not (bot := self.bots.get(event.self_id)):
log(
"WARNING",
diff --git a/nonebot/adapters/satori/bot.py b/nonebot/adapters/satori/bot.py
index 1113073..73c8dbf 100644
--- a/nonebot/adapters/satori/bot.py
+++ b/nonebot/adapters/satori/bot.py
@@ -1,7 +1,7 @@
import re
import json
from typing_extensions import override
-from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
+from typing import TYPE_CHECKING, Any, Union, Literal, Optional, overload
from nonebot.message import handle_event
from nonebot.drivers import Request, Response
@@ -12,10 +12,11 @@
from .element import parse
from .utils import API, log
from .config import ClientInfo
+from .models import PageDequeResult
from .event import Event, MessageEvent
from .models import MessageObject as SatoriMessage
from .message import Author, Message, RenderMessage, MessageSegment
-from .models import Role, User, Guild, Login, Member, Channel, PageResult
+from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, PageResult
from .exception import (
ActionFailed,
NetworkError,
@@ -66,11 +67,7 @@ async def _check_reply(
event.to_me = True
else:
return
- if (
- len(message) > index
- and message[index].type == "at"
- and message[index].data.get("id") == str(bot.get_self_id())
- ):
+ if len(message) > index and message[index].type == "at" and message[index].data.get("id") == str(bot.get_self_id()):
event.to_me = True
del message[index]
if len(message) > index and message[index].type == "text":
@@ -149,22 +146,22 @@ class Bot(BaseBot):
adapter: "Adapter"
@override
- def __init__(self, adapter: "Adapter", self_id: str, platform: str, info: ClientInfo):
+ def __init__(self, adapter: "Adapter", self_id: str, login: Login, info: ClientInfo):
super().__init__(adapter, self_id)
# Bot 配置信息
self.info: ClientInfo = info
# Bot 自身所属平台
- self.platform: str = platform
+ self.platform: str = login.platform or "satori"
# Bot 自身信息
- self._self_info: Optional[User] = None
+ self._self_info = login
def __getattr__(self, item):
raise AttributeError(f"'Bot' object has no attribute '{item}'")
def get_self_id(self):
- if self._self_info:
- return self._self_info.id
+ if self._self_info and self._self_info.user:
+ return self._self_info.user.id
return self.self_id
@property
@@ -175,14 +172,14 @@ def ready(self) -> bool:
@property
def self_info(self) -> User:
"""Bot 自身信息,仅当 Bot 连接鉴权完成后可用"""
- if self._self_info is None:
+ if self._self_info.user is None:
raise RuntimeError(f"Bot {self.self_id} of {self.platform} is not connected!")
- return self._self_info
+ return self._self_info.user
- def on_ready(self, user: User) -> None:
- self._self_info = user
+ def _update(self, login: Login) -> None:
+ self._self_info = login
- def get_authorization_header(self) -> Dict[str, str]:
+ def get_authorization_header(self) -> dict[str, str]:
"""获取当前 Bot 的鉴权信息"""
header = {
"Authorization": f"Bearer {self.info.token}",
@@ -200,9 +197,17 @@ async def handle_event(self, event: Event) -> None:
_check_nickname(self, event)
await handle_event(self, event)
- def _handle_response(self, response: Response) -> Any:
+ @overload
+ def _handle_response(self, response: Response) -> dict: ...
+
+ @overload
+ def _handle_response(self, response: Response, noreturn: Literal[True]) -> None: ...
+
+ def _handle_response(self, response: Response, noreturn=False) -> Any:
if 200 <= response.status_code < 300:
- return response.content and json.loads(response.content)
+ if not noreturn:
+ return
+ return json.loads(response.content) if response.content else {}
elif response.status_code == 400:
raise BadRequestException(response)
elif response.status_code == 401:
@@ -228,57 +233,71 @@ async def _request(self, request: Request) -> Any:
return self._handle_response(response)
+ async def download(self, url: str) -> bytes:
+ """访问内部链接。"""
+ request = Request("GET", self.info.api_base / "proxy" / url.lstrip("/"))
+ try:
+ response = await self.adapter.request(request)
+ except Exception as e:
+ raise NetworkError("API request failed") from e
+
+ self._handle_response(response, noreturn=True)
+ return response.content # type: ignore
+
@override
async def send(
self,
event: Event,
message: Union[str, Message, MessageSegment],
**kwargs,
- ) -> List[SatoriMessage]:
+ ) -> list[SatoriMessage]:
if not event.channel:
raise RuntimeError("Event cannot be replied to!")
return await self.send_message(event.channel.id, message)
async def send_message(
self,
- channel_id: str,
+ channel: Union[str, Channel],
message: Union[str, Message, MessageSegment],
- ) -> List[SatoriMessage]:
+ ) -> list[SatoriMessage]:
"""发送消息
参数:
- channel_id: 要发送的频道 ID
- message: 要发送的消息
+ channel (str | Channel): 要发送的频道 ID
+ message (str | Message | MessageSegment): 要发送的消息
"""
+ channel_id = channel.id if isinstance(channel, Channel) else channel
return await self.message_create(channel_id=channel_id, content=str(message))
async def send_private_message(
self,
- user_id: str,
+ user: Union[str, User],
message: Union[str, Message, MessageSegment],
- ) -> List[SatoriMessage]:
+ ) -> list[SatoriMessage]:
"""发送私聊消息
参数:
- user_id: 要发送的用户 ID
- message: 要发送的消息
+ user (str | User): 要发送的用户 ID
+ message (str | Message | MessageSegment): 要发送的消息
"""
+ user_id = user.id if isinstance(user, User) else user
channel = await self.user_channel_create(user_id=user_id)
return await self.message_create(channel_id=channel.id, content=str(message))
async def update_message(
self,
- channel_id: str,
+ channel: Union[str, Channel],
message_id: str,
message: Union[str, Message, MessageSegment],
):
"""更新消息
参数:
- channel_id: 要更新的频道 ID
- message_id: 要更新的消息 ID
- message: 要更新的消息
+ channel (str | Channel): 要更新的频道 ID
+ message_id (str): 要更新的消息 ID
+ message (str | Message | MessageSegment): 更新后的消息
"""
+ channel_id = channel.id if isinstance(channel, Channel) else channel
await self.message_update(channel_id=channel_id, message_id=message_id, content=str(message))
@API
@@ -287,7 +306,7 @@ async def message_create(
*,
channel_id: str,
content: str,
- ) -> List[SatoriMessage]:
+ ) -> list[SatoriMessage]:
request = Request(
"POST",
self.info.api_base / "message.create",
@@ -336,14 +355,28 @@ async def message_update(
@API
async def message_list(
- self, *, channel_id: str, next_token: Optional[str] = None
- ) -> PageResult[SatoriMessage]:
+ self,
+ *,
+ channel_id: str,
+ next_token: Optional[str] = None,
+ direction: Direction = "before",
+ limit: int = 50,
+ order: Order = "asc",
+ ) -> PageDequeResult[SatoriMessage]:
+ if not next_token and direction != "before":
+ raise ValueError("Invalid direction")
request = Request(
"POST",
self.info.api_base / "message.list",
- json={"channel_id": channel_id, "next": next_token},
+ json={
+ "channel_id": channel_id,
+ "next": next_token,
+ "direction": direction,
+ "limit": limit,
+ "order": order,
+ },
)
- return type_validate_python(PageResult[SatoriMessage], await self._request(request))
+ return type_validate_python(PageDequeResult[SatoriMessage], await self._request(request))
@API
async def channel_get(self, *, channel_id: str) -> Channel:
@@ -445,9 +478,7 @@ async def guild_approve(self, *, request_id: str, approve: bool, comment: str) -
await self._request(request)
@API
- async def guild_member_list(
- self, *, guild_id: str, next_token: Optional[str] = None
- ) -> PageResult[Member]:
+ async def guild_member_list(self, *, guild_id: str, next_token: Optional[str] = None) -> PageResult[Member]:
request = Request(
"POST",
self.info.api_base / "guild.member.list",
@@ -665,15 +696,53 @@ async def friend_approve(self, *, request_id: str, approve: bool, comment: str)
await self._request(request)
@API
- async def internal(
- self,
- *,
- action: str,
- **kwargs,
- ) -> Any:
+ async def internal(self, *, action: str, **kwargs) -> Any:
+ """内部接口调用。
+
+ 参数:
+ action (str): 内部接口名称
+ **kwargs: 参数
+ """
request = Request(
"POST",
self.info.api_base / "internal" / action,
json=kwargs,
)
return await self._request(request)
+
+ @API
+ async def admin_login_list(self) -> list[Login]:
+ request = Request(
+ "POST",
+ self.info.api_base / "admin" / "login.list",
+ )
+ res = await self._request(request)
+ return [type_validate_python(Login, i) for i in res]
+
+ @overload
+ async def upload(self, *uploads: Upload) -> list[str]: ...
+
+ @overload
+ async def upload(self, **uploads: Upload) -> dict[str, str]: ...
+
+ async def upload(self, *args: Upload, **kwargs: Upload):
+ """上传文件。
+
+ 如果要发送的消息中含有图片或其他媒体资源,\
+ 可以使用此 API 将文件上传至 Satori 服务器并转换为 URL,以便在消息编码中使用。
+ """
+ if args and kwargs:
+ raise RuntimeError("upload can't accept both args and kwargs")
+ if args:
+ ids = []
+ for upload in args:
+ ids.append(str(id(upload)))
+
+ resp = await self.upload_create(**dict(zip(ids, args)))
+ return list(resp.values())
+ return await self.upload_create(**kwargs)
+
+ @API
+ async def upload_create(self, **kwargs: Upload) -> dict[str, str]:
+ request = Request("POST", self.info.api_base / "upload.create", files={k: v.dump() for k, v in kwargs.items()})
+ return await self._request(request)
diff --git a/nonebot/adapters/satori/config.py b/nonebot/adapters/satori/config.py
index 9eeda2f..322c7fa 100644
--- a/nonebot/adapters/satori/config.py
+++ b/nonebot/adapters/satori/config.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
from yarl import URL
from pydantic import Field, BaseModel
@@ -28,5 +28,5 @@ def ws_base(self):
class Config(BaseModel):
- satori_clients: List[ClientInfo] = Field(default_factory=list)
+ satori_clients: list[ClientInfo] = Field(default_factory=list)
"""client 配置"""
diff --git a/nonebot/adapters/satori/element.py b/nonebot/adapters/satori/element.py
index 6c769f6..3e6e4e7 100644
--- a/nonebot/adapters/satori/element.py
+++ b/nonebot/adapters/satori/element.py
@@ -1,8 +1,9 @@
import re
from enum import IntEnum
+from collections.abc import Iterable
from typing_extensions import TypeAlias
from dataclasses import field, dataclass
-from typing import Any, Dict, List, Union, Literal, TypeVar, Callable, Iterable, Optional, TypedDict, cast
+from typing import Any, Union, Literal, TypeVar, Callable, Optional, TypedDict, cast
T = TypeVar("T")
@@ -28,24 +29,20 @@ def camel_case(source: str) -> str:
def param_case(source: str) -> str:
- return re.sub(
- ".[A-Z]+", lambda mat: mat[0][0] + "-" + mat[0][1:].lower(), uncapitalize(source).replace("_", "-")
- )
+ return re.sub(".[A-Z]+", lambda mat: mat[0][0] + "-" + mat[0][1:].lower(), uncapitalize(source).replace("_", "-"))
def snake_case(source: str) -> str:
- return re.sub(
- ".[A-Z]", lambda mat: mat[0][0] + "_" + mat[0][1:].lower(), uncapitalize(source).replace("-", "_")
- )
+ return re.sub(".[A-Z]", lambda mat: mat[0][0] + "_" + mat[0][1:].lower(), uncapitalize(source).replace("-", "_"))
-def ensure_list(value: Union[T, List[T], None]) -> List[T]:
+def ensure_list(value: Union[T, list[T], None]) -> list[T]:
return value if isinstance(value, list) else [value] if value else []
S = TypeVar("S")
-Fragment: TypeAlias = Union[str, "Element", List[Union[str, "Element"]]]
-Render: TypeAlias = Callable[[dict, List["Element"], S], T]
+Fragment: TypeAlias = Union[str, "Element", list[Union[str, "Element"]]]
+Render: TypeAlias = Callable[[dict, list["Element"], S], T]
Visitor: TypeAlias = Callable[["Element", S], T]
@@ -60,7 +57,7 @@ def make_element(content: Union[str, bool, int, float, "Element"]) -> Optional["
raise ValueError(f"Invalid content: {content!r}")
-def make_elements(content: Fragment) -> List["Element"]:
+def make_elements(content: Fragment) -> list["Element"]:
if isinstance(content, list):
res = [make_element(c) for c in content]
else:
@@ -70,14 +67,14 @@ def make_elements(content: Fragment) -> List["Element"]:
class Element:
type: str
- attrs: Dict[str, Any]
- children: List["Element"]
+ attrs: dict[str, Any]
+ children: list["Element"]
source: Optional[str] = None
def __init__(
self,
type: Union[str, Render[Fragment, Any]],
- attrs: Optional[Dict[str, Any]] = None,
+ attrs: Optional[dict[str, Any]] = None,
*children: Fragment,
) -> None:
self.attrs = {}
@@ -149,8 +146,8 @@ class Selector:
comb_pat = re.compile(" *([ >+~]) *")
-def parse_selector(input: str) -> List[List[Selector]]:
- def _quert(query: str) -> List[Selector]:
+def parse_selector(input: str) -> list[list[Selector]]:
+ def _quert(query: str) -> list[Selector]:
selectors = []
combinator = " "
while mat := comb_pat.search(query):
@@ -168,7 +165,7 @@ def _quert(query: str) -> List[Selector]:
return [_quert(q) for q in input.split(",")]
-def select(source: Union[str, List[Element]], query: Union[str, List[List[Selector]]]) -> List[Element]:
+def select(source: Union[str, list[Element]], query: Union[str, list[list[Selector]]]) -> list[Element]:
if not source or not query:
return []
if isinstance(source, str):
@@ -177,10 +174,10 @@ def select(source: Union[str, List[Element]], query: Union[str, List[List[Select
query = parse_selector(query)
if not query:
return []
- adjacent: List[List[Selector]] = []
+ adjacent: list[list[Selector]] = []
results = []
for index, elem in enumerate(source):
- inner: List[List[Selector]] = []
+ inner: list[list[Selector]] = []
local = [*query, *adjacent]
adjacent = []
matched = False
@@ -231,9 +228,7 @@ def interpolate(expr: str, context: dict) -> Any:
r"(?P)|(?P<(/?)([^!\s>/]*)([^>]*?)\s*(/?)>)|(?P\{(?P[@:/#][^\s\}]*)?[\s\S]*?\})"
)
attr_pat1 = re.compile(r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)')?", re.S)
-attr_pat2 = re.compile(
- r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)'|=\{(?P[^\}]+)\})?", re.S
-)
+attr_pat2 = re.compile(r"([^\s=]+)(?:=\"(?P[^\"]*)\"|='(?P[^']*)'|=\{(?P[^\}]+)\})?", re.S)
class Position(IntEnum):
@@ -250,7 +245,7 @@ class Token:
positon: Position
source: str
extra: str
- children: Dict[str, List[Union[str, "Token"]]] = field(default_factory=dict)
+ children: dict[str, list[Union[str, "Token"]]] = field(default_factory=dict)
class StackItem(TypedDict):
@@ -258,8 +253,8 @@ class StackItem(TypedDict):
slot: str
-def fold_tokens(tokens: List[Union[str, Token]]) -> List[Union[str, Token]]:
- stack: List[StackItem] = [
+def fold_tokens(tokens: list[Union[str, Token]]) -> list[Union[str, Token]]:
+ stack: list[StackItem] = [
{
"token": Token(
type="angle",
@@ -296,8 +291,8 @@ def push_token(*tokens: Union[str, Token]):
return stack[-1]["token"].children["default"]
-def parse_tokens(tokens: List[Union[str, Token]], context: Optional[dict] = None) -> List[Element]:
- result: List[Element] = []
+def parse_tokens(tokens: list[Union[str, Token]], context: Optional[dict] = None) -> list[Element]:
+ result: list[Element] = []
for token in tokens:
if isinstance(token, str):
result.append(Element(type="text", attrs={"text": token}))
@@ -343,7 +338,7 @@ def parse_tokens(tokens: List[Union[str, Token]], context: Optional[dict] = None
def parse(src: str, context: Optional[dict] = None):
- tokens: List[Union[str, Token]] = []
+ tokens: list[Union[str, Token]] = []
def push_text(text: str):
if text:
diff --git a/nonebot/adapters/satori/event.py b/nonebot/adapters/satori/event.py
index 7861678..9af14c7 100644
--- a/nonebot/adapters/satori/event.py
+++ b/nonebot/adapters/satori/event.py
@@ -1,7 +1,7 @@
from enum import Enum
from copy import deepcopy
from typing_extensions import override
-from typing import TYPE_CHECKING, Dict, Type, TypeVar, Optional
+from typing import TYPE_CHECKING, TypeVar, Optional
from nonebot.utils import escape_tag
from nonebot.compat import model_dump, type_validate_python
@@ -84,10 +84,10 @@ def is_tome(self) -> bool:
return False
-EVENT_CLASSES: Dict[str, Type[Event]] = {}
+EVENT_CLASSES: dict[str, type[Event]] = {}
-def register_event_class(event_class: Type[E]) -> Type[E]:
+def register_event_class(event_class: type[E]) -> type[E]:
EVENT_CLASSES[event_class.__type__.value] = event_class
return event_class
@@ -325,7 +325,7 @@ class PrivateMessageCreatedEvent(MessageCreatedEvent, PrivateMessageEvent):
def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
- f"{self.user.name or ''}({self.channel.id}): {self.get_message()!r}"
+ f"{self.user.name or self.user.nick or ''}({self.channel.id}): {self.get_message()!r}"
)
@@ -334,7 +334,7 @@ class PublicMessageCreatedEvent(MessageCreatedEvent, PublicMessageEvent):
def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
- f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})"
+ f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})"
f"@[{self.channel.name or ''}:{self.channel.id}]"
f": {self.get_message()!r}"
)
@@ -343,7 +343,9 @@ def get_event_description(self) -> str:
class PrivateMessageDeletedEvent(MessageDeletedEvent, PrivateMessageEvent):
@override
def get_event_description(self) -> str:
- return escape_tag(f"Message {self.msg_id} from " f"{self.user.name or ''}({self.channel.id}) deleted")
+ return escape_tag(
+ f"Message {self.msg_id} from " f"{self.user.name or self.user.nick or ''}({self.channel.id}) deleted"
+ )
class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent):
@@ -351,7 +353,7 @@ class PublicMessageDeletedEvent(MessageDeletedEvent, PublicMessageEvent):
def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
- f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})"
+ f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})"
f"@[{self.channel.name or ''}:{self.channel.id}] deleted"
)
@@ -361,7 +363,7 @@ class PrivateMessageUpdatedEvent(MessageUpdatedEvent, PrivateMessageEvent):
def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
- f"{self.user.name or ''}({self.channel.id}) updated"
+ f"{self.user.name or self.user.nick or ''}({self.channel.id}) updated"
f": {self.get_message()!r}"
)
@@ -371,7 +373,7 @@ class PublicMessageUpdatedEvent(MessageUpdatedEvent, PublicMessageEvent):
def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
- f"{self.member.name if self.member else (self.user.name or '')}({self.user.id})"
+ f"{(self.member.nick if self.member else None) or (self.user.name or self.user.nick or '')}({self.user.id})"
f"@[{self.channel.name or ''}:{self.channel.id}] updated"
f": {self.get_message()!r}"
)
diff --git a/nonebot/adapters/satori/exception.py b/nonebot/adapters/satori/exception.py
index 6b584d6..f261983 100644
--- a/nonebot/adapters/satori/exception.py
+++ b/nonebot/adapters/satori/exception.py
@@ -19,9 +19,7 @@ def __init__(self, response: Response):
self.content = response.content
def __repr__(self) -> str:
- return (
- f"<{self.__class__.__name__}: {self.status_code}, headers={self.headers}, content={self.content}>"
- )
+ return f"<{self.__class__.__name__}: {self.status_code}, headers={self.headers}, content={self.content}>"
def __str__(self):
return self.__repr__()
diff --git a/nonebot/adapters/satori/message.py b/nonebot/adapters/satori/message.py
index 4a189bd..7cf1eb8 100644
--- a/nonebot/adapters/satori/message.py
+++ b/nonebot/adapters/satori/message.py
@@ -2,9 +2,10 @@
from io import BytesIO
from pathlib import Path
from base64 import b64encode
+from collections.abc import Iterable
from dataclasses import InitVar, field, dataclass
+from typing import Any, Union, Optional, TypedDict
from typing_extensions import Self, NotRequired, override
-from typing import Any, Dict, List, Type, Tuple, Union, Iterable, Optional, TypedDict
from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment
@@ -47,7 +48,7 @@ def __call__(self, *segments: Union[str, Iterable["MessageSegment"], "MessageSeg
@classmethod
@override
- def get_message_class(cls) -> Type["Message"]:
+ def get_message_class(cls) -> type["Message"]:
return Message
@staticmethod
@@ -371,7 +372,7 @@ def is_text(self) -> bool:
class TextData(TypedDict):
text: str
- styles: Dict[Tuple[int, int], List[str]]
+ styles: dict[tuple[int, int], list[str]]
@dataclass
@@ -728,7 +729,7 @@ class Custom(MessageSegment):
}
-def handle(element: Element, upper_styles: Optional[List[str]] = None):
+def handle(element: Element, upper_styles: Optional[list[str]] = None):
tag = element.tag()
if tag in ELEMENT_TYPE_MAP:
seg_cls, seg_type = ELEMENT_TYPE_MAP[tag]
@@ -771,15 +772,13 @@ def handle(element: Element, upper_styles: Optional[List[str]] = None):
*(handle(child, [*(upper_styles or [])]) for child in element.children),
)
else:
- yield Custom(tag, element.attrs.copy())(
- *(handle(child, [*(upper_styles or [])]) for child in element.children)
- )
+ yield Custom(tag, element.attrs.copy())(*(handle(child, [*(upper_styles or [])]) for child in element.children))
class Message(BaseMessage[MessageSegment]):
@classmethod
@override
- def get_segment_class(cls) -> Type[MessageSegment]:
+ def get_segment_class(cls) -> type[MessageSegment]:
return MessageSegment
def __init__(
@@ -808,7 +807,7 @@ def _construct(msg: str) -> Iterable[MessageSegment]:
yield MessageSegment.text(msg)
@classmethod
- def from_satori_element(cls, elements: List[Element]) -> "Message":
+ def from_satori_element(cls, elements: list[Element]) -> "Message":
msg = Message()
for elem in elements:
diff --git a/nonebot/adapters/satori/models.py b/nonebot/adapters/satori/models.py
index 3280d8f..a265918 100644
--- a/nonebot/adapters/satori/models.py
+++ b/nonebot/adapters/satori/models.py
@@ -1,6 +1,10 @@
+import mimetypes
+from os import PathLike
from enum import IntEnum
+from pathlib import Path
from datetime import datetime
-from typing import Any, Dict, List, Union, Generic, Literal, TypeVar, Optional
+from typing_extensions import TypeAlias
+from typing import IO, Any, Union, Generic, Literal, TypeVar, Optional
from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict
@@ -63,7 +67,6 @@ class Config:
class Member(BaseModel):
user: Optional[User] = None
- name: Optional[str] = None
nick: Optional[str] = None
avatar: Optional[str] = None
joined_at: Optional[datetime] = None
@@ -115,6 +118,8 @@ class Login(BaseModel):
self_id: Optional[str] = None
platform: Optional[str] = None
status: LoginStatus
+ features: list[str] = Field(default_factory=list)
+ proxy_urls: list[str] = Field(default_factory=list)
if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore
@@ -145,7 +150,7 @@ class Opcode(IntEnum):
class Payload(BaseModel):
op: Opcode = Field(...)
- body: Optional[Dict[str, Any]] = Field(None)
+ body: Optional[dict[str, Any]] = Field(None)
class Identify(BaseModel):
@@ -154,7 +159,7 @@ class Identify(BaseModel):
class Ready(BaseModel):
- logins: List[Login]
+ logins: list[Login]
class IdentifyPayload(Payload):
@@ -282,19 +287,50 @@ class EventPayload(Payload):
if PYDANTIC_V2:
- class PageResult(BaseModel, Generic[T]): # type: ignore
- data: List[T]
+ class PageResult(BaseModel, Generic[T]):
+ data: list[T]
next: Optional[str] = None
model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore
+ class PageDequeResult(PageResult[T]):
+ prev: Optional[str] = None
+
else:
from pydantic.generics import GenericModel
class PageResult(GenericModel, Generic[T]):
- data: List[T]
+ data: list[T]
next: Optional[str] = None
class Config:
extra = "allow"
+
+ class PageDequeResult(PageResult[T]):
+ prev: Optional[str] = None
+
+
+Direction: TypeAlias = Literal["before", "after", "around"]
+Order: TypeAlias = Literal["asc", "desc"]
+
+
+class Upload:
+ def __init__(
+ self, file: Union[bytes, IO[bytes], PathLike], mimetype: str = "image/png", name: Optional[str] = None
+ ):
+ self.file = file
+ self.mimetype = mimetype
+
+ if isinstance(self.file, PathLike):
+ self.mimetype = mimetypes.guess_type(str(self.file))[0] or self.mimetype
+ self.name = Path(self.file).name
+ else:
+ self.name = name
+
+ def dump(self):
+ file = self.file
+
+ if isinstance(file, PathLike):
+ file = open(file, "rb")
+ return self.name, file, self.mimetype
diff --git a/nonebot/adapters/satori/utils.py b/nonebot/adapters/satori/utils.py
index 078136e..fb56457 100644
--- a/nonebot/adapters/satori/utils.py
+++ b/nonebot/adapters/satori/utils.py
@@ -1,6 +1,7 @@
from functools import partial
+from collections.abc import Awaitable
from typing_extensions import ParamSpec, Concatenate
-from typing import TYPE_CHECKING, Type, Generic, TypeVar, Callable, Optional, Awaitable, overload
+from typing import TYPE_CHECKING, Generic, TypeVar, Callable, Optional, overload
from nonebot.utils import logger_wrapper
@@ -17,17 +18,17 @@ class API(Generic[B, P, R]):
def __init__(self, func: Callable[Concatenate[B, P], Awaitable[R]]) -> None:
self.func = func
- def __set_name__(self, owner: Type[B], name: str) -> None:
+ def __set_name__(self, owner: type[B], name: str) -> None:
self.name = name
@overload
- def __get__(self, obj: None, objtype: Type[B]) -> "API[B, P, R]": ...
+ def __get__(self, obj: None, objtype: type[B]) -> "API[B, P, R]": ...
@overload
- def __get__(self, obj: B, objtype: Optional[Type[B]]) -> Callable[P, Awaitable[R]]: ...
+ def __get__(self, obj: B, objtype: Optional[type[B]]) -> Callable[P, Awaitable[R]]: ...
def __get__(
- self, obj: Optional[B], objtype: Optional[Type[B]] = None
+ self, obj: Optional[B], objtype: Optional[type[B]] = None
) -> "API[B, P, R] | Callable[P, Awaitable[R]]":
if obj is None:
return self
diff --git a/pdm.lock b/pdm.lock
index c0d765b..025511a 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -914,27 +914,27 @@ files = [
[[package]]
name = "ruff"
-version = "0.4.7"
+version = "0.4.8"
requires_python = ">=3.7"
summary = "An extremely fast Python linter and code formatter, written in Rust."
files = [
- {file = "ruff-0.4.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e089371c67892a73b6bb1525608e89a2aca1b77b5440acf7a71dda5dac958f9e"},
- {file = "ruff-0.4.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:10f973d521d910e5f9c72ab27e409e839089f955be8a4c8826601a6323a89753"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c3d110970001dfa494bcd95478e62286c751126dfb15c3c46e7915fc49694f"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa9773c6c00f4958f73b317bc0fd125295110c3776089f6ef318f4b775f0abe4"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07fc80bbb61e42b3b23b10fda6a2a0f5a067f810180a3760c5ef1b456c21b9db"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fa4dafe3fe66d90e2e2b63fa1591dd6e3f090ca2128daa0be33db894e6c18648"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7c0083febdec17571455903b184a10026603a1de078428ba155e7ce9358c5f6"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad1b20e66a44057c326168437d680a2166c177c939346b19c0d6b08a62a37589"},
- {file = "ruff-0.4.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbf5d818553add7511c38b05532d94a407f499d1a76ebb0cad0374e32bc67202"},
- {file = "ruff-0.4.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:50e9651578b629baec3d1513b2534de0ac7ed7753e1382272b8d609997e27e83"},
- {file = "ruff-0.4.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8874a9df7766cb956b218a0a239e0a5d23d9e843e4da1e113ae1d27ee420877a"},
- {file = "ruff-0.4.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9de9a6e49f7d529decd09381c0860c3f82fa0b0ea00ea78409b785d2308a567"},
- {file = "ruff-0.4.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:13a1768b0691619822ae6d446132dbdfd568b700ecd3652b20d4e8bc1e498f78"},
- {file = "ruff-0.4.7-py3-none-win32.whl", hash = "sha256:769e5a51df61e07e887b81e6f039e7ed3573316ab7dd9f635c5afaa310e4030e"},
- {file = "ruff-0.4.7-py3-none-win_amd64.whl", hash = "sha256:9e3ab684ad403a9ed1226894c32c3ab9c2e0718440f6f50c7c5829932bc9e054"},
- {file = "ruff-0.4.7-py3-none-win_arm64.whl", hash = "sha256:10f2204b9a613988e3484194c2c9e96a22079206b22b787605c255f130db5ed7"},
- {file = "ruff-0.4.7.tar.gz", hash = "sha256:2331d2b051dc77a289a653fcc6a42cce357087c5975738157cd966590b18b5e1"},
+ {file = "ruff-0.4.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7663a6d78f6adb0eab270fa9cf1ff2d28618ca3a652b60f2a234d92b9ec89066"},
+ {file = "ruff-0.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eeceb78da8afb6de0ddada93112869852d04f1cd0f6b80fe464fd4e35c330913"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aad360893e92486662ef3be0a339c5ca3c1b109e0134fcd37d534d4be9fb8de3"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:284c2e3f3396fb05f5f803c9fffb53ebbe09a3ebe7dda2929ed8d73ded736deb"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7354f921e3fbe04d2a62d46707e569f9315e1a613307f7311a935743c51a764"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:72584676164e15a68a15778fd1b17c28a519e7a0622161eb2debdcdabdc71883"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9678d5c9b43315f323af2233a04d747409d1e3aa6789620083a82d1066a35199"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704977a658131651a22b5ebeb28b717ef42ac6ee3b11e91dc87b633b5d83142b"},
+ {file = "ruff-0.4.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d05f8d6f0c3cce5026cecd83b7a143dcad503045857bc49662f736437380ad45"},
+ {file = "ruff-0.4.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6ea874950daca5697309d976c9afba830d3bf0ed66887481d6bca1673fc5b66a"},
+ {file = "ruff-0.4.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fc95aac2943ddf360376be9aa3107c8cf9640083940a8c5bd824be692d2216dc"},
+ {file = "ruff-0.4.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:384154a1c3f4bf537bac69f33720957ee49ac8d484bfc91720cc94172026ceed"},
+ {file = "ruff-0.4.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9d5ce97cacc99878aa0d084c626a15cd21e6b3d53fd6f9112b7fc485918e1fa"},
+ {file = "ruff-0.4.8-py3-none-win32.whl", hash = "sha256:6d795d7639212c2dfd01991259460101c22aabf420d9b943f153ab9d9706e6a9"},
+ {file = "ruff-0.4.8-py3-none-win_amd64.whl", hash = "sha256:e14a3a095d07560a9d6769a72f781d73259655919d9b396c650fc98a8157555d"},
+ {file = "ruff-0.4.8-py3-none-win_arm64.whl", hash = "sha256:14019a06dbe29b608f6b7cbcec300e3170a8d86efaddb7b23405cb7f7dcaf780"},
+ {file = "ruff-0.4.8.tar.gz", hash = "sha256:16d717b1d57b2e2fd68bd0bf80fb43931b79d05a7131aa477d66fc40fbd86268"},
]
[[package]]
diff --git a/pyproject.toml b/pyproject.toml
index 284de19..dc55627 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,12 +1,12 @@
[project]
name = "nonebot-adapter-satori"
-version = "0.11.5"
+version = "0.12.0"
description = "Satori Protocol Adapter for Nonebot2"
authors = [
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},
]
dependencies = [
- "nonebot2>=2.2.0",
+ "nonebot2>=2.3.0",
]
requires-python = ">=3.9"
readme = "README.md"
@@ -27,7 +27,7 @@ dev = [
"black>=24.1.1",
"ruff>=0.2.1",
"pre-commit>=3.5.0",
- "nonebot2[httpx,websockets]>=2.2.0",
+ "nonebot2[httpx,websockets]>=2.3.0",
]
test = [
"nonebug>=0.3.5",
@@ -46,14 +46,14 @@ format = { composite = ["isort ./","black ./","ruff check ./"] }
[tool.black]
-line-length = 110
+line-length = 120
include = '\.pyi?$'
extend-exclude = '''
'''
[tool.isort]
profile = "black"
-line_length = 110
+line_length = 120
length_sort = true
skip_gitignore = true
force_sort_within_sections = true
@@ -61,8 +61,8 @@ extra_standard_library = ["typing_extensions"]
[tool.ruff]
-line-length = 110
-target-version = "py38"
+line-length = 120
+target-version = "py39"
[tool.ruff.lint]
select = ["E", "W", "F", "UP", "C", "T", "Q"]
diff --git a/tests/conftest.py b/tests/conftest.py
index d7d959f..aac483b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
import threading
from pathlib import Path
-from typing import Generator
+from collections.abc import Generator
import pytest
from nonebot.drivers import URL
@@ -11,9 +11,7 @@
import nonebot
import nonebot.adapters
-nonebot.adapters.__path__.append( # type: ignore
- str((Path(__file__).parent.parent / "nonebot" / "adapters").resolve())
-)
+nonebot.adapters.__path__.append(str((Path(__file__).parent.parent / "nonebot" / "adapters").resolve())) # type: ignore
from nonebot.adapters.satori import Adapter as SatoriAdapter
diff --git a/tests/fake_server.py b/tests/fake_server.py
index 04b2371..cf16a1f 100644
--- a/tests/fake_server.py
+++ b/tests/fake_server.py
@@ -2,7 +2,7 @@
import base64
import socket
from queue import Empty, Queue
-from typing import Dict, List, Union, TypeVar, Callable
+from typing import Union, TypeVar, Callable
from wsproto.events import Ping
from werkzeug import Request, Response
@@ -32,7 +32,7 @@ def json_safe(string, content_type="application/octet-stream") -> str:
).decode("utf-8")
-def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
+def flattern(d: "MultiDict[K, V]") -> dict[K, Union[V, list[V]]]:
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index 97527ac..8873428 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -6,6 +6,7 @@
import nonebot
from nonebot.adapters.satori import Bot, Adapter
+from nonebot.adapters.satori.models import Login, LoginStatus
from nonebot.adapters.satori.event import PublicMessageCreatedEvent
@@ -16,14 +17,13 @@ async def test_adapter(app: App):
@cmd.handle()
async def handle(bot: Bot):
- await bot.send_message(
- channel_id="67890",
- message="hello",
- )
+ await bot.send_message(channel="67890", message="hello")
async with app.test_matcher(cmd) as ctx:
adapter: Adapter = nonebot.get_adapter(Adapter)
- bot: Bot = ctx.create_bot(base=Bot, adapter=adapter, self_id="0", platform="test", info=None)
+ bot: Bot = ctx.create_bot(
+ base=Bot, adapter=adapter, self_id="0", login=Login(status=LoginStatus.CONNECT), info=None
+ )
ctx.receive_event(
bot,
diff --git a/tests/test_message.py b/tests/test_message.py
index 0322a9a..c02622e 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -14,12 +14,11 @@ def test_message():
assert Message.from_satori_element(parse(code))[0].data["chronocat:seq"] == "1"
assert Message("1234")[0].data["test:aaa"] is True
assert (
- Message.from_satori_element(parse(""))[0].data["test:bbb"]
- == "foo"
+ Message.from_satori_element(parse(""))[0].data["test:bbb"] == "foo"
)
- assert Message.from_satori_element(
- parse("")
- )[0].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat"}
+ assert Message.from_satori_element(parse(""))[
+ 0
+ ].data == {"id": "265", "name": "[辣眼睛]", "platform": "chronocat"}
test_message1 = MessageSegment(type="chronocat:face", data={"id": 12}) + "\n" + "Hello Yoshi"
assert str(test_message1) == '\nHello Yoshi'
@@ -59,8 +58,6 @@ def test_message_fallback():
"""
msg = Message.from_satori_element(parse(code))
- assert (
- str(msg[0].children) == '当前平台不支持发送视频,请在这里观看视频!'
- )
+ assert str(msg[0].children) == '当前平台不支持发送视频,请在这里观看视频!'
assert msg.extract_plain_text() == "当前平台不支持发送视频,请在这里观看视频!"
assert list(msg.query("link"))[0].data["text"] == "http://aa.com/a.mp4"