Skip to content

Commit

Permalink
🐛 version 0.19.2
Browse files Browse the repository at this point in the history
fix subclass check in AlconnaParam
  • Loading branch information
RF-Tar-Railt committed Aug 26, 2023
1 parent b6049de commit 5605c4b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/nonebot_plugin_alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from .consts import ALCONNA_EXEC_RESULT as ALCONNA_EXEC_RESULT
from .rule import set_output_converter as set_output_converter

__version__ = "0.19.1"
__version__ = "0.19.2"

_meta_source = {
"name": "Alconna 插件",
Expand Down
49 changes: 36 additions & 13 deletions src/nonebot_plugin_alconna/params.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import inspect
from typing_extensions import Annotated, TypeAlias
from typing import Any, Dict, Type, Tuple, Union, TypeVar, Callable, Optional, overload
from typing_extensions import Annotated, TypeAlias, get_args
from typing import (
Any,
Dict,
Type,
Tuple,
Union,
Literal,
TypeVar,
Callable,
Optional,
overload,
)

from nonebot.typing import T_State
from tarina import run_always_await
from nepattern.util import CUnionType
from pydantic.fields import Undefined
from tarina.generic import get_origin
from nonebot.internal.matcher import Matcher
from nonebot.internal.adapter import Bot, Event
from nonebot.internal.params import Param, Depends
from arclet.alconna.builtin import generate_duplication
from tarina import run_always_await, generic_issubclass
from arclet.alconna import Empty, Alconna, Arparma, Duplication

from .uniseg import UniMessage
Expand All @@ -17,6 +30,7 @@

T_Duplication = TypeVar("T_Duplication", bound=Duplication)
MIDDLEWARE: TypeAlias = Callable[[Bot, T_State, Any], Any]
_Contents = (Union, CUnionType, Literal)


def _alconna_result(state: T_State) -> CommandResult:
Expand Down Expand Up @@ -182,17 +196,20 @@ def __repr__(self) -> str:
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["AlconnaParam"]:
if param.annotation is CommandResult:
annotation = get_origin(param.annotation)
if annotation in _Contents:
annotation = get_args(param.annotation)[0]
if annotation is CommandResult:
return cls(..., type=CommandResult)
if generic_issubclass(get_origin(param.annotation), Arparma):
if annotation is Arparma:
return cls(..., type=Arparma)
if generic_issubclass(param.annotation, Alconna):
if annotation is Alconna:
return cls(..., type=Alconna)
if param.annotation is Duplication:
if annotation is Duplication:
return cls(..., type=Duplication)
if issubclass(get_origin(param.annotation), Duplication):
if inspect.isclass(annotation) and issubclass(annotation, Duplication):
return cls(..., anno=param.annotation, type=Duplication)
if get_origin(param.annotation) is Match:
if annotation is Match:
return cls(param.default, name=param.name, type=Match)
if isinstance(param.default, Query):
return cls(param.default, type=Query)
Expand All @@ -210,7 +227,8 @@ async def _solve(self, state: T_State, **kwargs: Any) -> Any:
if t is Duplication:
if anno := self.extra.get("anno"):
return anno(res.result)
return generate_duplication(res.source)(res.result)
else:
return generate_duplication(res.source)(res.result)
if t is Match:
target = res.result.all_matched_args.get(self.extra["name"], Empty)
return Match(target, target != Empty)
Expand All @@ -223,13 +241,18 @@ async def _solve(self, state: T_State, **kwargs: Any) -> Any:
elif self.default.result != Empty:
q.available = True
return q
if (key := ALCONNA_ARG_KEY.format(key=self.extra["name"])) in state:
return state[key]
if self.extra["name"] in res.result.all_matched_args:
return res.result.all_matched_args[self.extra["name"]]
return state[ALCONNA_ARG_KEY.format(key=self.extra["name"])]
return self.default if self.default not in (..., Empty) else Undefined

async def _check(self, state: T_State, **kwargs: Any) -> Any:
if self.extra["type"] == Any:
return (
if (
self.extra["name"] in _alconna_result(state).result.all_matched_args
or ALCONNA_ARG_KEY.format(key=self.extra["name"]) in state
)
):
return True
if self.default not in (..., Empty):
return True

0 comments on commit 5605c4b

Please sign in to comment.