From 6049dfbe05ce417103c8c2268cf98006a04fe5bc Mon Sep 17 00:00:00 2001 From: uy_sun Date: Tue, 12 Sep 2023 17:12:19 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=9A=E8=BF=87=20nonebot=5Fplugin=5Fuserinf?= =?UTF-8?q?o=20=E8=8E=B7=E5=8F=96=E7=94=A8=E6=88=B7=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/user/__init__.py | 23 ++++++++++++----------- src/utils/annotated.py | 6 ++++++ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/plugins/user/__init__.py b/src/plugins/user/__init__.py index e6b4fd83..1f571d3d 100644 --- a/src/plugins/user/__init__.py +++ b/src/plugins/user/__init__.py @@ -2,7 +2,6 @@ from typing import cast from expiringdict import ExpiringDict -from nonebot.params import Depends from nonebot_plugin_alconna import ( Alconna, AlconnaQuery, @@ -12,16 +11,18 @@ on_alconna, ) from nonebot_plugin_datastore import create_session -from nonebot_plugin_session import Session, SessionLevel, extract_session +from nonebot_plugin_session import SessionLevel from sqlalchemy import select from sqlalchemy.orm import selectinload +from src.utils.annotated import MyUserInfo, Session + from .models import Bind, User -async def create_user(pid: str, platform: str): +async def create_user(pid: str, platform: str, nickname: str): async with create_session() as session: - user = User(name=pid) + user = User(name=nickname) session.add(user) bind = Bind( pid=pid, @@ -71,15 +72,14 @@ async def set_user(pid: str, platform: str, aid: int): @user_cmd.handle() -async def _(session: Session = Depends(extract_session)): - if session.platform == "unknown": +async def _(session: Session, user_info: MyUserInfo): + if session.platform == "unknown" or not session.id1: await bind_cmd.finish("不支持的平台") - - assert session.id1 and session.platform + return user = await get_user(session.id1, session.platform) if not user: - user = await create_user(session.id1, session.platform) + user = await create_user(session.id1, session.platform, user_info.user_name) await user_cmd.finish(f"{user.id} {user.name}") @@ -97,9 +97,10 @@ async def _(session: Session = Depends(extract_session)): @bind_cmd.handle() async def _( + session: Session, + user_info: MyUserInfo, token: str | None = None, remove: Query[bool] = AlconnaQuery("r.value", default=False), - session: Session = Depends(extract_session), ): if ( session.platform == "unknown" @@ -111,7 +112,7 @@ async def _( user = await get_user(session.id1, session.platform) if not user: - user = await create_user(session.id1, session.platform) + user = await create_user(session.id1, session.platform, user_info.user_name) if remove.result: async with create_session() as db_session: diff --git a/src/utils/annotated.py b/src/utils/annotated.py index e4ca3c72..d1cd7ec1 100644 --- a/src/utils/annotated.py +++ b/src/utils/annotated.py @@ -2,6 +2,10 @@ from nonebot.params import Depends from nonebot_plugin_datastore import get_session +from nonebot_plugin_session import Session as _Session +from nonebot_plugin_session import extract_session +from nonebot_plugin_userinfo import EventUserInfo +from nonebot_plugin_userinfo import UserInfo as _MyUserInfo from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from .depends import ( @@ -24,3 +28,5 @@ OptionalPlainTextArgs = Annotated[str | None, Depends(get_plaintext_args)] Platform = Annotated[str, Depends(get_platform)] OptionalPlatform = Annotated[str | None, Depends(get_platform)] +MyUserInfo = Annotated[_MyUserInfo, EventUserInfo()] +Session = Annotated[_Session, Depends(extract_session)]