From 9f67d2970f4cf12ac629f84eb31ebcf433708fe1 Mon Sep 17 00:00:00 2001 From: meetwq Date: Wed, 18 Sep 2024 22:09:06 +0800 Subject: [PATCH] :sparkles: improve validator for `MessageSegment` --- nonebot/adapters/discord/message.py | 118 +++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/nonebot/adapters/discord/message.py b/nonebot/adapters/discord/message.py index 3afac43..16846c8 100644 --- a/nonebot/adapters/discord/message.py +++ b/nonebot/adapters/discord/message.py @@ -11,12 +11,13 @@ Union, overload, ) -from typing_extensions import override +from typing_extensions import Self, override from nonebot.adapters import ( Message as BaseMessage, MessageSegment as BaseMessageSegment, ) +from nonebot.compat import type_validate_python from .api import ( UNSET, @@ -24,6 +25,7 @@ AttachmentSend, Button, Component, + ComponentType, DirectComponent, Embed, File, @@ -32,6 +34,7 @@ SelectMenu, Snowflake, SnowflakeType, + TextInput, TimeStampStyle, ) from .utils import unescape @@ -174,6 +177,32 @@ def reference( def is_text(self) -> bool: return self.type == "text" + @classmethod + @override + def _validate(cls, value) -> Self: + if isinstance(value, cls): + return value + if isinstance(value, MessageSegment): + raise ValueError(f"Type {type(value)} can not be converted to {cls}") + if not isinstance(value, dict): + raise ValueError(f"Expected dict for MessageSegment, got {type(value)}") + if "type" not in value: + raise ValueError( + f"Expected dict with 'type' for MessageSegment, got {value}" + ) + _type = value["type"] + if _type not in SEGMENT_TYPE_MAP: + raise ValueError(f"Invalid MessageSegment type: {_type}") + segment_type = SEGMENT_TYPE_MAP[_type] + + # casting value to subclass of MessageSegment + if cls is MessageSegment: + return type_validate_python(segment_type, value) + # init segment instance directly if type matched + if cls is segment_type: + return segment_type(type=_type, data=value.get("data", {})) + raise ValueError(f"Segment type {_type!r} can not be converted to {cls}") + class StickerData(TypedDict): id: Snowflake @@ -204,6 +233,33 @@ class ComponentSegment(MessageSegment): def __str__(self) -> str: return f"" + @classmethod + @override + def _validate(cls, value) -> Self: + instance = super()._validate(value) + if "component" not in instance.data: + raise ValueError( + f"Expected dict with 'component' in 'data' for ComponentSegment, got {value}" + ) + if not isinstance( + component := instance.data["component"], (ActionRow, TextInput) + ): + if not isinstance(component, dict): + raise ValueError( + f"Expected dict for ComponentData, got {type(component)}" + ) + if "type" not in component: + raise ValueError( + f"Expected dict with 'type' for ComponentData, got {component}" + ) + if component["type"] == ComponentType.ActionRow: + instance.data["component"] = type_validate_python(ActionRow, component) + elif component["type"] == ComponentType.TextInput: + instance.data["component"] = type_validate_python(TextInput, component) + else: + raise ValueError(f"Invalid ComponentType: {component['type']}") + return instance + class CustomEmojiData(TypedDict): name: str @@ -334,6 +390,18 @@ class EmbedSegment(MessageSegment): def __str__(self) -> str: return f"" + @classmethod + @override + def _validate(cls, value) -> Self: + instance = super()._validate(value) + if "embed" not in instance.data: + raise ValueError( + f"Expected dict with 'embed' in 'data' for EmbedSegment, got {value}" + ) + if not isinstance(embed := instance.data["embed"], Embed): + instance.data["embed"] = type_validate_python(Embed, embed) + return instance + class AttachmentData(TypedDict): attachment: AttachmentSend @@ -350,6 +418,24 @@ class AttachmentSegment(MessageSegment): def __str__(self) -> str: return f"" + @classmethod + @override + def _validate(cls, value) -> Self: + instance = super()._validate(value) + if "attachment" not in instance.data: + raise ValueError( + f"Expected dict with 'attachment' in 'data' for AttachmentSegment, got {value}" + ) + if not isinstance(attachment := instance.data["attachment"], AttachmentSend): + instance.data["attachment"] = type_validate_python( + AttachmentSend, attachment + ) + if (file := instance.data.get("file")) is not None and not isinstance( + file, File + ): + instance.data["file"] = type_validate_python(File, file) + return instance + class ReferenceData(TypedDict): reference: MessageReference @@ -365,6 +451,36 @@ class ReferenceSegment(MessageSegment): def __str__(self): return f"" + @classmethod + @override + def _validate(cls, value) -> Self: + instance = super()._validate(value) + if "reference" not in instance.data: + raise ValueError( + f"Expected dict with 'reference' in 'data' for ReferenceSegment, got {value}" + ) + if not isinstance(reference := instance.data["reference"], MessageReference): + instance.data["reference"] = type_validate_python( + MessageReference, reference + ) + return instance + + +SEGMENT_TYPE_MAP = { + "attachment": AttachmentSegment, + "sticker": StickerSegment, + "embed": EmbedSegment, + "component": ComponentSegment, + "custom_emoji": CustomEmojiSegment, + "mention_user": MentionUserSegment, + "mention_role": MentionRoleSegment, + "mention_channel": MentionChannelSegment, + "mention_everyone": MentionEveryoneSegment, + "text": TextSegment, + "timestamp": TimestampSegment, + "reference": ReferenceSegment, +} + class Message(BaseMessage[MessageSegment]): @classmethod