Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix/0126
Browse files Browse the repository at this point in the history
# Conflicts:
#	pyproject.toml
  • Loading branch information
RF-Tar-Railt committed Sep 22, 2024
2 parents bfa5385 + 0ef197b commit 7a09f8c
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 19 deletions.
16 changes: 8 additions & 8 deletions nonebot/adapters/satori/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def payload_to_json(payload: Payload) -> str:
async def receive_payload(self, info: ClientInfo, ws: WebSocket) -> Payload:
payload = type_validate_python(PayloadType, json.loads(await ws.receive()))
if isinstance(payload, EventPayload):
self.sequences[info.identity] = payload.body.id
self.sequences[info.identity] = payload.body["id"]
return payload

async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Literal[True]]:
Expand Down Expand Up @@ -136,21 +136,21 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter
)
return
for login in resp.body.logins:
if not login.self_id:
if not login.id:
continue
if login.status != LoginStatus.ONLINE:
continue
if login.self_id not in self.bots:
bot = Bot(self, login.self_id, login, info)
if login.id not in self.bots:
bot = Bot(self, login.id, login, info)
self._bots[info.identity].add(bot.self_id)
self.bot_connect(bot)
log(
"INFO",
f"<y>Bot {escape_tag(bot.self_id)}</y> connected",
)
else:
self._bots[info.identity].add(login.self_id)
bot = self.bots[login.self_id]
self._bots[info.identity].add(login.id)
bot = self.bots[login.id]
bot._update(login)
if not self.bots:
log("WARNING", "No bots connected!")
Expand Down Expand Up @@ -230,7 +230,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
)
if isinstance(payload, EventPayload):
try:
event = self.payload_to_event(payload.body)
event = self.payload_to_event(type_validate_python(Event, payload.body))
except Exception as e:
log(
"WARNING",
Expand Down Expand Up @@ -265,7 +265,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
continue
if isinstance(event, (MessageEvent, InteractionEvent)):
event = event.convert()
asyncio.create_task(bot.handle_event(event))
_t = asyncio.create_task(bot.handle_event(event))
elif isinstance(payload, PongPayload):
log("TRACE", "Pong")
continue
Expand Down
25 changes: 21 additions & 4 deletions nonebot/adapters/satori/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Union, Literal, Optional, overload

from yarl import URL
from nonebot.message import handle_event
from nonebot.drivers import Request, Response
from nonebot.compat import model_dump, type_validate_python
Expand All @@ -16,7 +17,7 @@
from .event import Event, MessageEvent
from .models import MessageObject as SatoriMessage
from .message import Author, Message, RenderMessage, MessageSegment
from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, PageResult
from .models import Role, User, Guild, Login, Order, Member, Upload, Channel, Direction, LoginType, PageResult
from .exception import (
ActionFailed,
NetworkError,
Expand Down Expand Up @@ -146,7 +147,7 @@ class Bot(BaseBot):
adapter: "Adapter"

@override
def __init__(self, adapter: "Adapter", self_id: str, login: Login, info: ClientInfo):
def __init__(self, adapter: "Adapter", self_id: str, login: LoginType, info: ClientInfo):
super().__init__(adapter, self_id)

# Bot 配置信息
Expand Down Expand Up @@ -181,7 +182,7 @@ def self_info(self) -> User:
raise RuntimeError(f"Bot {self.self_id} of {self.platform} is not connected!")
return self._self_info.user

def _update(self, login: Login) -> None:
def _update(self, login: LoginType) -> None:
self._self_info = login

def get_authorization_header(self) -> dict[str, str]:
Expand All @@ -190,6 +191,8 @@ def get_authorization_header(self) -> dict[str, str]:
"Authorization": f"Bearer {self.info.token}",
"X-Self-ID": self.self_id,
"X-Platform": self.platform,
"Satori-Platform": self.platform,
"Satori-Login-ID": self.self_id,
}
if not self.info.token:
del header["Authorization"]
Expand Down Expand Up @@ -238,9 +241,23 @@ async def _request(self, request: Request) -> Any:

return self._handle_response(response)

def ensure_url(self, url: str) -> URL:
"""确定链接形式。
若链接符合以下条件之一,则返回链接的代理形式 ({host}/{path}/{version}/proxy/{url}):
- 链接以 "upload://" 开头
- 链接开头出现在 self_info.proxy_urls 中的某一项
"""
if url.startswith("upload"):
return self.info.api_base / "proxy" / url.lstrip("/")
for proxy_url in self._self_info.proxy_urls:
if url.startswith(proxy_url):
return self.info.api_base / "proxy" / url.lstrip("/")
return URL(url)

async def download(self, url: str) -> bytes:
"""访问内部链接。"""
request = Request("GET", self.info.api_base / "proxy" / url.lstrip("/"))
request = Request("GET", self.ensure_url(url))
try:
response = await self.adapter.request(request)
except Exception as e:
Expand Down
52 changes: 49 additions & 3 deletions nonebot/adapters/satori/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ class Login(BaseModel):
features: list[str] = Field(default_factory=list)
proxy_urls: list[str] = Field(default_factory=list)

@property
def id(self) -> Optional[str]:
return self.self_id or (self.user.id if self.user else None)

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore

else:

class Config:
extra = "allow"


class LoginPreview(BaseModel):
user: User
platform: str
status: Optional[LoginStatus] = None
features: list[str] = Field(default_factory=list)
proxy_urls: list[str] = Field(default_factory=list)

@property
def id(self) -> str:
return self.user.id

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore

Expand All @@ -130,6 +154,9 @@ class Config:
extra = "allow"


LoginType = Union[Login, LoginPreview]


class ArgvInteraction(BaseModel):
name: str
arguments: list
Expand Down Expand Up @@ -159,7 +186,7 @@ class Identify(BaseModel):


class Ready(BaseModel):
logins: list[Login]
logins: list[LoginType]


class IdentifyPayload(Payload):
Expand Down Expand Up @@ -243,7 +270,7 @@ class Event(BaseModel):
button: Optional[ButtonInteraction] = None
channel: Optional[Channel] = None
guild: Optional[Guild] = None
login: Optional[Login] = None
login: Optional[LoginType] = None
member: Optional[Member] = None
message: Optional[MessageObject] = None
operator: Optional[User] = None
Expand All @@ -262,6 +289,25 @@ def parse_timestamp(cls, v):
raise ValueError(f"invalid timestamp: {v}") from e
return datetime.fromtimestamp(timestamp / 1000)

@model_validator(mode="before")
def ensure_login(cls, values):
if "self_id" not in values and "platform" not in values:
log(
"WARNING",
"received event without `self_id` and `platform`, "
"this may be caused by Satori Server used protocol version 1.2.",
)
if "login" in values:
values["self_id"] = values["login"]["user"]["id"]
values["platform"] = values["login"]["platform"]
return values
log(
"WARNING",
"received event without login, " "this may be caused by a bug of Satori Server.",
)
return values
return values

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow") # type: ignore

Expand All @@ -273,7 +319,7 @@ class Config:

class EventPayload(Payload):
op: Literal[Opcode.EVENT] = Field(Opcode.EVENT)
body: Event
body: dict


PayloadType = Union[
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nonebot-adapter-satori"
version = "0.12.6"
version = "0.13.0rc1"
description = "Satori Protocol Adapter for Nonebot2"
authors = [
{name = "RF-Tar-Railt",email = "[email protected]"},
Expand Down
11 changes: 8 additions & 3 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
async def test_ws(app: App):
adapter: Adapter = nonebot.get_adapter(Adapter)

for client in adapter.satori_config.satori_clients:
adapter.tasks.append(asyncio.create_task(adapter.ws(client)))

@ws_handlers.put
def identify(json: dict) -> dict:
assert json["op"] == 3
Expand All @@ -36,6 +33,14 @@ def identify(json: dict) -> dict:
},
}

@ws_handlers.put
def _ping(json: dict) -> dict:
assert json == {"op": 1}
return {"op": 2}

for client in adapter.satori_config.satori_clients:
adapter.tasks.append(asyncio.create_task(adapter.ws(client)))

await asyncio.sleep(5)
bots = nonebot.get_bots()
assert "0" in bots
Expand Down

0 comments on commit 7a09f8c

Please sign in to comment.