Skip to content

Commit

Permalink
增加一个 UserSession
Browse files Browse the repository at this point in the history
  • Loading branch information
he0119 committed Sep 13, 2023
1 parent d154648 commit 92c92db
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 53 deletions.
66 changes: 15 additions & 51 deletions src/plugins/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,17 @@
Query,
on_alconna,
)
from nonebot_plugin_datastore import create_session
from nonebot_plugin_session import SessionLevel
from sqlalchemy import select

from src.utils.annotated import MyUserInfo, Session

from .models import Bind
from .utils import create_user, get_user, set_user
from .annotated import UserSession
from .utils import get_user, remove_bind, set_bind

user_cmd = on_alconna(Alconna("user"), use_cmd_start=True)


@user_cmd.handle()
async def _(session: Session, user_info: MyUserInfo):
if session.platform == "unknown" or not session.id1:
await bind_cmd.finish("不支持的平台")
return

user = await get_user(session.id1, session.platform)
if not user:
user = await create_user(session.id1, session.platform, user_info.user_name)

await user_cmd.finish(f"用户名: {user.name}\n创建日期: {user.created_at}")
async def _(session: UserSession):
await user_cmd.finish(f"用户名: {session.name}\n创建日期: {session.created_at}")


tokens = cast(
Expand All @@ -48,44 +36,21 @@ async def _(session: Session, user_info: MyUserInfo):

@bind_cmd.handle()
async def _(
session: Session,
user_info: MyUserInfo,
session: UserSession,
token: str | None = None,
remove: Query[bool] = AlconnaQuery("r.value", default=False),
):
if (
session.platform == "unknown"
or session.level == SessionLevel.LEVEL0
or not session.id1
):
await bind_cmd.finish("不支持的平台")
return

user = await get_user(session.id1, session.platform)
if not user:
user = await create_user(session.id1, session.platform, user_info.user_name)

if remove.result:
async with create_session() as db_session:
bind = (
await db_session.scalars(
select(Bind)
.where(Bind.pid == session.id1)
.where(Bind.platform == session.platform)
)
).one()

if bind.aid == bind.bid:
await bind_cmd.finish("不能解绑最初绑定的账号")
else:
bind.aid = bind.bid
await db_session.commit()
await bind_cmd.finish("解绑成功")
result = await remove_bind(session.pid, session.platform)
if result:
await bind_cmd.finish("解绑成功")
else:
await bind_cmd.finish("不能解绑最初绑定的账号")

# 生成令牌
if not token:
token = f"nonebot/{random.randint(100000, 999999)}"
tokens[token] = (session.id1, session.platform, user.id, session.level)
tokens[token] = (session.pid, session.platform, session.uid, session.level)
await bind_cmd.finish(
f"命令 bind 可用于在多个平台间绑定用户数据。绑定过程中,原始平台的用户数据将完全保留,而目标平台的用户数据将被原始平台的数据所覆盖。\n请确认当前平台是你的目标平台,并在 5 分钟内使用你的账号在原始平台内向机器人发送以下文本:\n/bind {token}\n绑定完成后,你可以随时使用「bind -r」来解除绑定状态。"
)
Expand All @@ -98,26 +63,25 @@ async def _(
# 此时 pid 和 platform 为目标平台的信息
if level == SessionLevel.LEVEL2 or level == SessionLevel.LEVEL3:
token = f"nonebot/{random.randint(100000, 999999)}"
tokens[token] = (session.id1, session.platform, user_id, None)
tokens[token] = (session.pid, session.platform, user_id, None)
await bind_cmd.finish(
f"令牌核验成功!下面将进行第二步操作。\n请在 5 分钟内使用你的账号在目标平台内向机器人发送以下文本:\n/bind {token}\n注意:当前平台是你的原始平台,这里的用户数据将覆盖目标平台的数据。"
)
# 群内绑定的第二步,会在目标平台发送令牌
# 此时 pid 和 platform 为原始平台的信息
# 需要重新获取其用户信息,然后将目标平台绑定至原始平台
elif level is None:
if user.id != user_id:
if session.uid != user_id:
await bind_cmd.finish("请使用最开始要绑定账号进行操作")

user = await get_user(pid, platform)
assert user
await set_user(session.id1, session.platform, user.id)
await set_bind(session.pid, session.platform, user.id)
await bind_cmd.finish("绑定成功")
# 私聊绑定时,会在原始平台发送令牌
# 此时 pid 和 platform 为目标平台的信息
# 直接将目标平台绑定至原始平台
elif level == SessionLevel.LEVEL1:
await set_user(pid, platform, user.id)
await set_bind(pid, platform, session.uid)
await bind_cmd.finish("绑定成功")
else:
await bind_cmd.finish("令牌无效/已过期")
10 changes: 10 additions & 0 deletions src/plugins/user/annotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Annotated

from nonebot.params import Depends

from .depends import UserSession as _UserSession
from .depends import get_or_create_user
from .models import User as _User

User = Annotated[_User, Depends(get_or_create_user)]
UserSession = Annotated[_UserSession, Depends(_UserSession)]
69 changes: 69 additions & 0 deletions src/plugins/user/depends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from dataclasses import dataclass
from datetime import datetime

from nonebot.matcher import Matcher
from nonebot.params import Depends
from nonebot_plugin_session import SessionLevel

from src.utils.annotated import MyUserInfo, Session

from . import utils
from .models import User


async def get_or_create_user(matcher: Matcher, session: Session, user_info: MyUserInfo):
"""获取一个用户,如果不存在则创建"""
if (
session.platform == "unknown"
or session.level == SessionLevel.LEVEL0
or not session.id1
):
await matcher.finish("用户相关功能暂不支持当前平台")
return

try:
user = await utils.get_user(session.id1, session.platform)
except ValueError:
user = await utils.create_user(
session.id1, session.platform, user_info.user_name
)

return user


@dataclass
class UserSession:
session: Session
info: MyUserInfo
user: User = Depends(get_or_create_user)

@property
def uid(self) -> int:
"""用户 ID"""
return self.user.id

@property
def name(self) -> str:
"""用户名"""
return self.user.name

@property
def created_at(self) -> datetime:
"""用户创建日期"""
return self.user.created_at

@property
def pid(self) -> str:
"""用户所在平台 ID"""
assert self.session.id1
return self.session.id1

@property
def platform(self) -> str:
"""用户所在平台"""
return self.session.platform

@property
def level(self) -> SessionLevel:
"""用户会话级别"""
return self.session.level
24 changes: 22 additions & 2 deletions src/plugins/user/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


async def create_user(pid: str, platform: str, nickname: str):
"""创建账号"""
async with create_session() as session:
user = User(name=nickname)
session.add(user)
Expand All @@ -22,6 +23,7 @@ async def create_user(pid: str, platform: str, nickname: str):


async def get_user(pid: str, platform: str):
"""获取账号"""
async with create_session() as session:
bind = (
await session.scalars(
Expand All @@ -33,12 +35,13 @@ async def get_user(pid: str, platform: str):
).one_or_none()

if not bind:
return
raise ValueError("找不到用户信息")

return bind.auser


async def set_user(pid: str, platform: str, aid: int):
async def set_bind(pid: str, platform: str, aid: int):
"""设置账号绑定"""
async with create_session() as session:
bind = (
await session.scalars(
Expand All @@ -51,3 +54,20 @@ async def set_user(pid: str, platform: str, aid: int):

bind.aid = aid
await session.commit()


async def remove_bind(pid: str, platform: str):
"""解除账号绑定"""
async with create_session() as db_session:
bind = (
await db_session.scalars(
select(Bind).where(Bind.pid == pid).where(Bind.platform == platform)
)
).one()

if bind.aid == bind.bid:
return False
else:
bind.aid = bind.bid
await db_session.commit()
return True

0 comments on commit 92c92db

Please sign in to comment.