Skip to content

Commit

Permalink
fix: 修复无法正确发送启动问候的问题 (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
he0119 authored Sep 30, 2023
1 parent 9f57f34 commit f6c1e26
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ KAIHEILA_BOTS=[]

# 内部插件配置 -----------------------------

# 迁移使用的配置
MIGRATION_BOT_ID

# FF14
FFLOGS_RANGE=14
FFLOGS_CACHE_TIME=4:30:00
Expand Down
35 changes: 12 additions & 23 deletions src/plugins/morning/plugins/hello/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from nonebot.params import CommandArg, Depends
from nonebot.plugin import PluginMetadata, inherit_supported_adapters, on_command
from nonebot_plugin_saa import PlatformTarget, Text, get_target
from nonebot_plugin_saa.utils.auto_select_bot import (
extract_adapter_type,
list_targets_map,
)
from sqlalchemy import or_, select
from sqlalchemy.sql import ColumnElement
from sqlalchemy import select

from src.utils.annotated import AsyncSession
from src.utils.helpers import strtobool
Expand All @@ -25,7 +20,7 @@
__plugin_meta__ = PluginMetadata(
name="启动问候",
description="启动时发送问候",
usage="""开启时会在每天机器人第一次启动时发送问候
usage="""开启时会在机器人第一次启动时发送问候
查看当前群是否开启启动问候
/hello
Expand All @@ -42,20 +37,9 @@
@driver.on_bot_connect
async def hello_on_connect(bot: Bot, session: AsyncSession) -> None:
"""启动时发送问候"""
whereclause: list[ColumnElement[bool]] = []
adapter_name = extract_adapter_type(bot)
if list_targets := list_targets_map.get(adapter_name):
targets = await list_targets(bot)
if not targets:
logger.info(f"没有找到适配器 {adapter_name} 支持的发送目标")
return
for target in targets:
whereclause.append(or_(Hello.target == target.dict()))
else:
logger.info(f"不支持的适配器 {adapter_name}")
return

groups = (await session.scalars(select(Hello).where(*whereclause))).all()
groups = (
await session.scalars(select(Hello).where(Hello.bot_id == bot.self_id))
).all()
if not groups:
return

Expand All @@ -74,20 +58,25 @@ async def hello_on_connect(bot: Bot, session: AsyncSession) -> None:

@hello_cmd.handle()
async def hello_handle(
bot: Bot,
session: AsyncSession,
arg: Message = CommandArg(),
target: PlatformTarget = Depends(get_target),
):
args = arg.extract_plain_text()

group = (
await session.scalars(select(Hello).where(Hello.target == target.dict()))
await session.scalars(
select(Hello)
.where(Hello.target == target.dict())
.where(Hello.bot_id == bot.self_id)
)
).one_or_none()

if args:
if strtobool(args):
if not group:
session.add(Hello(target=target.dict()))
session.add(Hello(target=target.dict(), bot_id=bot.self_id))
await session.commit()
await hello_cmd.finish("已在本群开启启动问候功能")
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add bot_id
Revision ID: 371e102dcdb9
Revises: 5469ed61acff
Create Date: 2023-09-30 08:39:55.899833
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "371e102dcdb9"
down_revision = "5469ed61acff"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("hello_hello", schema=None) as batch_op:
batch_op.add_column(sa.Column("bot_id", sa.String(), nullable=True))

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("hello_hello", schema=None) as batch_op:
batch_op.drop_column("bot_id")

# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""remove bot_id platform et al
"""remove platform et al
Revision ID: 9dbb35122585
Revises: 92c2c4affdce
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""add bot_id data
Revision ID: b683352e0089
Revises: 371e102dcdb9
Create Date: 2023-09-30 08:46:28.829111
"""
import sqlalchemy as sa
from alembic import op
from nonebot import get_driver
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

# revision identifiers, used by Alembic.
revision = "b683352e0089"
down_revision = "371e102dcdb9"
branch_labels = None
depends_on = None


def upgrade() -> None:
Base = automap_base()
Base.prepare(autoload_with=op.get_bind())
Hello = Base.classes.hello_hello
config = get_driver().config
migration_bot_id = getattr(config, "migration_bot_id", None)
with Session(op.get_bind()) as session:
hellos = session.scalars(sa.select(Hello)).all()
if hellos and migration_bot_id is None:
raise ValueError("你需要设置 migration_bot_id 以完成迁移")

for hello in hellos:
hello.bot_id = migration_bot_id
session.commit()


def downgrade() -> None:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""make bot_id non-nullable
Revision ID: e92b0f680c78
Revises: b683352e0089
Create Date: 2023-09-30 08:57:28.404492
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "e92b0f680c78"
down_revision = "b683352e0089"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("hello_hello", schema=None) as batch_op:
batch_op.alter_column("bot_id", existing_type=sa.VARCHAR(), nullable=False)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("hello_hello", schema=None) as batch_op:
batch_op.alter_column("bot_id", existing_type=sa.VARCHAR(), nullable=True)

# ### end Alembic commands ###
1 change: 1 addition & 0 deletions src/plugins/morning/plugins/hello/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class Hello(Model):
id: Mapped[int] = mapped_column(primary_key=True)
target: Mapped[dict] = mapped_column(JSON().with_variant(JSONB, "postgresql"))
bot_id: Mapped[str]

@property
def saa_target(self) -> PlatformTarget:
Expand Down
4 changes: 2 additions & 2 deletions tests/plugins/morning/test_hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def test_hello_enabled(app: App):
from src.plugins.morning.plugins.hello import Hello, hello_cmd

async with create_session() as session:
session.add(Hello(target=TargetQQGroup(group_id=10000).dict()))
session.add(Hello(target=TargetQQGroup(group_id=10000).dict(), bot_id="test"))
await session.commit()

async with app.test_matcher(hello_cmd) as ctx:
Expand Down Expand Up @@ -72,7 +72,7 @@ async def test_hello_disable(app: App):
from src.plugins.morning.plugins.hello import Hello, hello_cmd

async with create_session() as session:
session.add(Hello(target=TargetQQGroup(group_id=10000).dict()))
session.add(Hello(target=TargetQQGroup(group_id=10000).dict(), bot_id="test"))
await session.commit()

async with app.test_matcher(hello_cmd) as ctx:
Expand Down

0 comments on commit f6c1e26

Please sign in to comment.