Skip to content

Commit

Permalink
✨ version 0.12.0
Browse files Browse the repository at this point in the history
impl Satori Protocol V1.1
  • Loading branch information
RF-Tar-Railt committed Jun 13, 2024
1 parent 01a6b19 commit 7a7ab98
Show file tree
Hide file tree
Showing 15 changed files with 257 additions and 165 deletions.
21 changes: 9 additions & 12 deletions nonebot/adapters/satori/adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import asyncio
from typing_extensions import override
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional

from nonebot.utils import escape_tag
from nonebot.exception import WebSocketClosed
Expand Down Expand Up @@ -40,15 +40,15 @@


class Adapter(BaseAdapter):
bots: Dict[str, Bot]
bots: dict[str, Bot]

@override
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
# 读取适配器所需的配置项
self.satori_config: Config = get_plugin_config(Config)
self.tasks: List[asyncio.Task] = [] # 存储 ws 任务
self.sequences: Dict[str, int] = {} # 存储 连接序列号
self.tasks: list[asyncio.Task] = [] # 存储 ws 任务
self.sequences: dict[str, int] = {} # 存储 连接序列号
self.setup()

@classmethod
Expand Down Expand Up @@ -136,16 +136,15 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter
if login.status != LoginStatus.ONLINE:
continue
if login.self_id not in self.bots:
bot = Bot(self, login.self_id, login.platform or "satori", info)
bot = Bot(self, login.self_id, login, info)
self.bot_connect(bot)
log(
"INFO",
f"<y>Bot {escape_tag(bot.self_id)}</y> connected",
)
else:
bot = self.bots[login.self_id]
if login.user:
bot.on_ready(login.user)
bot._update(login)
if not self.bots:
log("WARNING", "No bots connected!")
return
Expand Down Expand Up @@ -232,9 +231,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
)
else:
if isinstance(event, LoginAddedEvent):
bot = Bot(self, event.self_id, event.platform, info)
if event.user:
bot.on_ready(event.user)
bot = Bot(self, event.self_id, event.login, info)
self.bot_connect(bot)
log(
"INFO",
Expand All @@ -247,8 +244,8 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
f"<y>Bot {escape_tag(event.self_id)}</y> disconnected",
)
continue
elif isinstance(event, LoginUpdatedEvent) and event.user:
self.bots[event.self_id].on_ready(event.user)
elif isinstance(event, LoginUpdatedEvent):
self.bots[event.self_id]._update(event.login)
if not (bot := self.bots.get(event.self_id)):
log(
"WARNING",
Expand Down
161 changes: 115 additions & 46 deletions nonebot/adapters/satori/bot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import json
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
from typing import TYPE_CHECKING, Any, Union, Literal, Optional, overload

from nonebot.message import handle_event
from nonebot.drivers import Request, Response
Expand All @@ -12,10 +12,11 @@
from .element import parse
from .utils import API, log
from .config import ClientInfo
from .models import PageDequeResult
from .event import Event, MessageEvent
from .models import MessageObject as SatoriMessage
from .message import Author, Message, RenderMessage, MessageSegment
from .models import Role, User, Guild, Login, Member, Channel, PageResult
from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, PageResult
from .exception import (
ActionFailed,
NetworkError,
Expand Down Expand Up @@ -66,11 +67,7 @@ async def _check_reply(
event.to_me = True
else:
return
if (
len(message) > index
and message[index].type == "at"
and message[index].data.get("id") == str(bot.get_self_id())
):
if len(message) > index and message[index].type == "at" and message[index].data.get("id") == str(bot.get_self_id()):
event.to_me = True
del message[index]
if len(message) > index and message[index].type == "text":
Expand Down Expand Up @@ -149,22 +146,22 @@ class Bot(BaseBot):
adapter: "Adapter"

@override
def __init__(self, adapter: "Adapter", self_id: str, platform: str, info: ClientInfo):
def __init__(self, adapter: "Adapter", self_id: str, login: Login, info: ClientInfo):
super().__init__(adapter, self_id)

# Bot 配置信息
self.info: ClientInfo = info
# Bot 自身所属平台
self.platform: str = platform
self.platform: str = login.platform or "satori"
# Bot 自身信息
self._self_info: Optional[User] = None
self._self_info = login

def __getattr__(self, item):
raise AttributeError(f"'Bot' object has no attribute '{item}'")

def get_self_id(self):
if self._self_info:
return self._self_info.id
if self._self_info and self._self_info.user:
return self._self_info.user.id
return self.self_id

@property
Expand All @@ -175,14 +172,14 @@ def ready(self) -> bool:
@property
def self_info(self) -> User:
"""Bot 自身信息,仅当 Bot 连接鉴权完成后可用"""
if self._self_info is None:
if self._self_info.user is None:
raise RuntimeError(f"Bot {self.self_id} of {self.platform} is not connected!")
return self._self_info
return self._self_info.user

def on_ready(self, user: User) -> None:
self._self_info = user
def _update(self, login: Login) -> None:
self._self_info = login

def get_authorization_header(self) -> Dict[str, str]:
def get_authorization_header(self) -> dict[str, str]:
"""获取当前 Bot 的鉴权信息"""
header = {
"Authorization": f"Bearer {self.info.token}",
Expand All @@ -200,9 +197,17 @@ async def handle_event(self, event: Event) -> None:
_check_nickname(self, event)
await handle_event(self, event)

def _handle_response(self, response: Response) -> Any:
@overload
def _handle_response(self, response: Response) -> dict: ...

@overload
def _handle_response(self, response: Response, noreturn: Literal[True]) -> None: ...

def _handle_response(self, response: Response, noreturn=False) -> Any:
if 200 <= response.status_code < 300:
return response.content and json.loads(response.content)
if not noreturn:
return
return json.loads(response.content) if response.content else {}
elif response.status_code == 400:
raise BadRequestException(response)
elif response.status_code == 401:
Expand All @@ -228,57 +233,71 @@ async def _request(self, request: Request) -> Any:

return self._handle_response(response)

async def download(self, url: str) -> bytes:
"""访问内部链接。"""
request = Request("GET", self.info.api_base / "proxy" / url.lstrip("/"))
try:
response = await self.adapter.request(request)
except Exception as e:
raise NetworkError("API request failed") from e

self._handle_response(response, noreturn=True)
return response.content # type: ignore

@override
async def send(
self,
event: Event,
message: Union[str, Message, MessageSegment],
**kwargs,
) -> List[SatoriMessage]:
) -> list[SatoriMessage]:
if not event.channel:
raise RuntimeError("Event cannot be replied to!")
return await self.send_message(event.channel.id, message)

async def send_message(
self,
channel_id: str,
channel: Union[str, Channel],
message: Union[str, Message, MessageSegment],
) -> List[SatoriMessage]:
) -> list[SatoriMessage]:
"""发送消息
参数:
channel_id: 要发送的频道 ID
message: 要发送的消息
channel (str | Channel): 要发送的频道 ID
message (str | Message | MessageSegment): 要发送的消息
"""
channel_id = channel.id if isinstance(channel, Channel) else channel
return await self.message_create(channel_id=channel_id, content=str(message))

async def send_private_message(
self,
user_id: str,
user: Union[str, User],
message: Union[str, Message, MessageSegment],
) -> List[SatoriMessage]:
) -> list[SatoriMessage]:
"""发送私聊消息
参数:
user_id: 要发送的用户 ID
message: 要发送的消息
user (str | User): 要发送的用户 ID
message (str | Message | MessageSegment): 要发送的消息
"""
user_id = user.id if isinstance(user, User) else user
channel = await self.user_channel_create(user_id=user_id)
return await self.message_create(channel_id=channel.id, content=str(message))

async def update_message(
self,
channel_id: str,
channel: Union[str, Channel],
message_id: str,
message: Union[str, Message, MessageSegment],
):
"""更新消息
参数:
channel_id: 要更新的频道 ID
message_id: 要更新的消息 ID
message: 要更新的消息
channel (str | Channel): 要更新的频道 ID
message_id (str): 要更新的消息 ID
message (str | Message | MessageSegment): 更新后的消息
"""
channel_id = channel.id if isinstance(channel, Channel) else channel
await self.message_update(channel_id=channel_id, message_id=message_id, content=str(message))

@API
Expand All @@ -287,7 +306,7 @@ async def message_create(
*,
channel_id: str,
content: str,
) -> List[SatoriMessage]:
) -> list[SatoriMessage]:
request = Request(
"POST",
self.info.api_base / "message.create",
Expand Down Expand Up @@ -336,14 +355,28 @@ async def message_update(

@API
async def message_list(
self, *, channel_id: str, next_token: Optional[str] = None
) -> PageResult[SatoriMessage]:
self,
*,
channel_id: str,
next_token: Optional[str] = None,
direction: Direction = "before",
limit: int = 50,
order: Order = "asc",
) -> PageDequeResult[SatoriMessage]:
if not next_token and direction != "before":
raise ValueError("Invalid direction")
request = Request(
"POST",
self.info.api_base / "message.list",
json={"channel_id": channel_id, "next": next_token},
json={
"channel_id": channel_id,
"next": next_token,
"direction": direction,
"limit": limit,
"order": order,
},
)
return type_validate_python(PageResult[SatoriMessage], await self._request(request))
return type_validate_python(PageDequeResult[SatoriMessage], await self._request(request))

@API
async def channel_get(self, *, channel_id: str) -> Channel:
Expand Down Expand Up @@ -445,9 +478,7 @@ async def guild_approve(self, *, request_id: str, approve: bool, comment: str) -
await self._request(request)

@API
async def guild_member_list(
self, *, guild_id: str, next_token: Optional[str] = None
) -> PageResult[Member]:
async def guild_member_list(self, *, guild_id: str, next_token: Optional[str] = None) -> PageResult[Member]:
request = Request(
"POST",
self.info.api_base / "guild.member.list",
Expand Down Expand Up @@ -665,15 +696,53 @@ async def friend_approve(self, *, request_id: str, approve: bool, comment: str)
await self._request(request)

@API
async def internal(
self,
*,
action: str,
**kwargs,
) -> Any:
async def internal(self, *, action: str, **kwargs) -> Any:
"""内部接口调用。
参数:
action (str): 内部接口名称
**kwargs: 参数
"""
request = Request(
"POST",
self.info.api_base / "internal" / action,
json=kwargs,
)
return await self._request(request)

@API
async def admin_login_list(self) -> list[Login]:
request = Request(
"POST",
self.info.api_base / "admin" / "login.list",
)
res = await self._request(request)
return [type_validate_python(Login, i) for i in res]

@overload
async def upload(self, *uploads: Upload) -> list[str]: ...

@overload
async def upload(self, **uploads: Upload) -> dict[str, str]: ...

async def upload(self, *args: Upload, **kwargs: Upload):
"""上传文件。
如果要发送的消息中含有图片或其他媒体资源,\
可以使用此 API 将文件上传至 Satori 服务器并转换为 URL,以便在消息编码中使用。
"""
if args and kwargs:
raise RuntimeError("upload can't accept both args and kwargs")
if args:
ids = []
for upload in args:
ids.append(str(id(upload)))

resp = await self.upload_create(**dict(zip(ids, args)))
return list(resp.values())
return await self.upload_create(**kwargs)

@API
async def upload_create(self, **kwargs: Upload) -> dict[str, str]:
request = Request("POST", self.info.api_base / "upload.create", files={k: v.dump() for k, v in kwargs.items()})
return await self._request(request)
4 changes: 2 additions & 2 deletions nonebot/adapters/satori/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

from yarl import URL
from pydantic import Field, BaseModel
Expand Down Expand Up @@ -28,5 +28,5 @@ def ws_base(self):


class Config(BaseModel):
satori_clients: List[ClientInfo] = Field(default_factory=list)
satori_clients: list[ClientInfo] = Field(default_factory=list)
"""client 配置"""
Loading

0 comments on commit 7a7ab98

Please sign in to comment.