Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ improve validator for MessageSegment #42

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion nonebot/adapters/discord/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
Union,
overload,
)
from typing_extensions import override
from typing_extensions import Self, override

from nonebot.adapters import (
Message as BaseMessage,
MessageSegment as BaseMessageSegment,
)
from nonebot.compat import type_validate_python

from .api import (
UNSET,
ActionRow,
AttachmentSend,
Button,
Component,
ComponentType,
DirectComponent,
Embed,
File,
Expand All @@ -32,6 +34,7 @@
SelectMenu,
Snowflake,
SnowflakeType,
TextInput,
TimeStampStyle,
)
from .utils import unescape
Expand Down Expand Up @@ -174,6 +177,32 @@ def reference(
def is_text(self) -> bool:
return self.type == "text"

@classmethod
@override
def _validate(cls, value) -> Self:
if isinstance(value, cls):
return value
if isinstance(value, MessageSegment):
raise ValueError(f"Type {type(value)} can not be converted to {cls}")
if not isinstance(value, dict):
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
if "type" not in value:
raise ValueError(
f"Expected dict with 'type' for MessageSegment, got {value}"
)
_type = value["type"]
if _type not in SEGMENT_TYPE_MAP:
raise ValueError(f"Invalid MessageSegment type: {_type}")
segment_type = SEGMENT_TYPE_MAP[_type]

# casting value to subclass of MessageSegment
if cls is MessageSegment:
return type_validate_python(segment_type, value)
# init segment instance directly if type matched
if cls is segment_type:
return segment_type(type=_type, data=value.get("data", {}))
raise ValueError(f"Segment type {_type!r} can not be converted to {cls}")


class StickerData(TypedDict):
id: Snowflake
Expand Down Expand Up @@ -204,6 +233,33 @@ class ComponentSegment(MessageSegment):
def __str__(self) -> str:
return f"<Component:{self.data['component'].type}>"

@classmethod
@override
def _validate(cls, value) -> Self:
instance = super()._validate(value)
if "component" not in instance.data:
raise ValueError(
f"Expected dict with 'component' in 'data' for ComponentSegment, got {value}"
)
if not isinstance(
component := instance.data["component"], (ActionRow, TextInput)
):
if not isinstance(component, dict):
raise ValueError(
f"Expected dict for ComponentData, got {type(component)}"
)
if "type" not in component:
raise ValueError(
f"Expected dict with 'type' for ComponentData, got {component}"
)
if component["type"] == ComponentType.ActionRow:
instance.data["component"] = type_validate_python(ActionRow, component)
elif component["type"] == ComponentType.TextInput:
instance.data["component"] = type_validate_python(TextInput, component)
else:
raise ValueError(f"Invalid ComponentType: {component['type']}")
return instance


class CustomEmojiData(TypedDict):
name: str
Expand Down Expand Up @@ -334,6 +390,18 @@ class EmbedSegment(MessageSegment):
def __str__(self) -> str:
return f"<Embed:{self.data['embed'].type}>"

@classmethod
@override
def _validate(cls, value) -> Self:
instance = super()._validate(value)
if "embed" not in instance.data:
raise ValueError(
f"Expected dict with 'embed' in 'data' for EmbedSegment, got {value}"
)
if not isinstance(embed := instance.data["embed"], Embed):
instance.data["embed"] = type_validate_python(Embed, embed)
return instance


class AttachmentData(TypedDict):
attachment: AttachmentSend
Expand All @@ -350,6 +418,24 @@ class AttachmentSegment(MessageSegment):
def __str__(self) -> str:
return f"<Attachment:{self.data['attachment'].filename}>"

@classmethod
@override
def _validate(cls, value) -> Self:
instance = super()._validate(value)
if "attachment" not in instance.data:
raise ValueError(
f"Expected dict with 'attachment' in 'data' for AttachmentSegment, got {value}"
)
if not isinstance(attachment := instance.data["attachment"], AttachmentSend):
instance.data["attachment"] = type_validate_python(
AttachmentSend, attachment
)
if (file := instance.data.get("file")) is not None and not isinstance(
file, File
):
instance.data["file"] = type_validate_python(File, file)
return instance


class ReferenceData(TypedDict):
reference: MessageReference
Expand All @@ -365,6 +451,36 @@ class ReferenceSegment(MessageSegment):
def __str__(self):
return f"<Reference:{self.data['reference'].message_id}>"

@classmethod
@override
def _validate(cls, value) -> Self:
instance = super()._validate(value)
if "reference" not in instance.data:
raise ValueError(
f"Expected dict with 'reference' in 'data' for ReferenceSegment, got {value}"
)
if not isinstance(reference := instance.data["reference"], MessageReference):
instance.data["reference"] = type_validate_python(
MessageReference, reference
)
return instance


SEGMENT_TYPE_MAP = {
"attachment": AttachmentSegment,
"sticker": StickerSegment,
"embed": EmbedSegment,
"component": ComponentSegment,
"custom_emoji": CustomEmojiSegment,
"mention_user": MentionUserSegment,
"mention_role": MentionRoleSegment,
"mention_channel": MentionChannelSegment,
"mention_everyone": MentionEveryoneSegment,
"text": TextSegment,
"timestamp": TimestampSegment,
"reference": ReferenceSegment,
}


class Message(BaseMessage[MessageSegment]):
@classmethod
Expand Down
Loading