Skip to content

Commit

Permalink
-
Browse files Browse the repository at this point in the history
  • Loading branch information
A.Shpak committed Mar 18, 2024
1 parent ccbf6e6 commit 046155f
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 87 deletions.
26 changes: 19 additions & 7 deletions chatushka/_chatushka.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from collections.abc import AsyncGenerator, Callable, MutableMapping, Sequence
from contextlib import (
AbstractAsyncContextManager,
_AsyncGeneratorContextManager,
asynccontextmanager,
_AsyncGeneratorContextManager,
)
from typing import Any, final

Expand Down Expand Up @@ -58,13 +58,15 @@ def add_cmd(
action: Callable,
case_sensitive: bool = False,
chance_rate: float = 1.0,
):
results_model: type[Any] | None = None,
) -> None:
self.add_matcher(
CommandMatcher(
*commands,
action=action,
case_sensitive=case_sensitive,
chance_rate=chance_rate,
results_model=results_model,
)
)

Expand All @@ -73,6 +75,7 @@ def cmd(
*commands: str,
case_sensitive: bool = False,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> Callable:
def _wrapper(
func,
Expand All @@ -82,6 +85,7 @@ def _wrapper(
action=func,
case_sensitive=case_sensitive,
chance_rate=chance_rate,
results_model=results_model,
)

return _wrapper
Expand All @@ -91,19 +95,22 @@ def add_regex(
*patterns: str,
action: Callable,
chance_rate: float = 1.0,
):
results_model: type[Any] | None = None,
) -> None:
self.add_matcher(
RegExMatcher(
*patterns,
action=action,
chance_rate=chance_rate,
results_model=results_model,
)
)

def regex(
self,
*patterns: str,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> Callable:
def _wrapper(
func,
Expand All @@ -112,36 +119,41 @@ def _wrapper(
*patterns,
action=func,
chance_rate=chance_rate,
results_model=results_model
)

return _wrapper

def add_event(
self,
event: Events,
*events: Events,
action: Callable,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> None:
self.add_matcher(
EventMatcher(
event=event,
*events,
action=action,
chance_rate=chance_rate,
results_model=results_model,
)
)

def event(
self,
event: Events,
*events: Events,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> Callable:
def _wrapper(
func,
) -> None:
self.add_event(
event=event,
*events,
action=func,
chance_rate=chance_rate,
results_model=results_model,
)

return _wrapper
Expand Down
117 changes: 90 additions & 27 deletions chatushka/_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from inspect import iscoroutinefunction, signature
from random import random
from re import Pattern, compile
from typing import TypeVar
from typing import TypeVar, Any

from pydantic import BaseModel

from chatushka._models import Events, Update
from chatushka._transport import TelegramBotAPI
Expand All @@ -18,55 +20,80 @@ def __init__(
self,
action: Callable,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> None:
self._action: Callable = action
self._chance_rate = chance_rate
self._results_model = results_model

@abstractmethod
def _check(
self,
update: Update,
) -> bool:
) -> list[str] | None:
raise NotImplementedError

async def __call__(
self,
api: TelegramBotAPI,
update: Update,
) -> None:
if not self._check(
update=update,
):
if (
results := self._check(
update=update,
)
) is None:
return
await self._call_action(
api=api,
update=update,
results=results,
)

def _get_chance(
self,
) -> bool:
return random() <= self._chance_rate

def _make_results_model(
self,
results: list[str],
) -> Any:
if self._results_model is None:
return None
if issubclass(self._results_model, BaseModel):
params = {}
for i, name in enumerate(self._results_model.model_fields):
if len(results) < i:
break
params[name] = results[i]
return self._results_model.model_validate(params)
return self._results_model(results)

async def _call_action(
self,
api: TelegramBotAPI,
update: Update,
results: Any,
) -> None:
if not self._get_chance():
return
kwargs = {}
results_from_model = self._make_results_model(results)
kwargs.update(
{
"api": api,
"update": update,
"message": update.message,
"chat": update.message.chat if update.message else None,
"user": update.message.user if update.message else None,
"results": results_from_model or results,
}
)
sig = signature(self._action)
kwargs = {param: kwargs.get(param) for param in sig.parameters if param in kwargs}
kwargs = {
param: kwargs.get(param) for param in sig.parameters if param in kwargs
}
if update and update.message is not None and "message" in sig.parameters:
kwargs["message"] = update.message
if iscoroutinefunction(self._action):
Expand All @@ -86,10 +113,12 @@ def __init__(
prefixes: str | Sequence[str] = (),
case_sensitive: bool = False,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> None:
super().__init__(
action=action,
chance_rate=chance_rate,
results_model=results_model,
)
if case_sensitive:
commands = tuple(command.upper() for command in commands)
Expand All @@ -104,21 +133,34 @@ def add_commands_prefixes(
) -> None:
if not prefixes:
return
self._commands = tuple(f"{prefix}{command}" for command in self._commands for prefix in prefixes)
self._commands = tuple(
f"{prefix}{command}" for command in self._commands for prefix in prefixes
)

def _make_args(
self,
text: str,
) -> list[str]:
args = [arg for arg in text.split(" ") if arg]
if len(args) == 1:
return []
return args[1:]

def _check(
self,
update: Update,
) -> bool:
) -> list[str] | None:
if not update.message or not update.message.text:
return False
return None
case_sensitive = self._case_sensitive or False
for command in self._commands:
if case_sensitive and update.message.text.upper().startswith(command.upper()):
return True
if case_sensitive and update.message.text.upper().startswith(
command.upper()
):
return self._make_args(update.message.text)
if update.message.text.startswith(command):
return True
return False
return self._make_args(update.message.text)
return None


class RegExMatcher(
Expand All @@ -130,20 +172,28 @@ def __init__(
*patterns: str | Pattern,
action: Callable,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> None:
super().__init__(
action=action,
chance_rate=chance_rate,
results_model=results_model,
)
self._patterns = [compile(pattern) if isinstance(pattern, str) else pattern for pattern in patterns]
self._patterns = [
compile(pattern) if isinstance(pattern, str) else pattern
for pattern in patterns
]

def _check(
self,
update: Update,
) -> bool:
) -> list[str] | None:
if not update.message or not update.message.text:
return False
return any(pattern.findall(update.message.text) for pattern in self._patterns)
return None
for pattern in self._patterns:
if result := pattern.findall(update.message.text):
return result
return None


class EventMatcher(
Expand All @@ -152,24 +202,37 @@ class EventMatcher(
):
def __init__(
self,
event: Events,
*events: Events,
action: Callable,
chance_rate: float = 1.0,
results_model: type[Any] | None = None,
) -> None:
super().__init__(
action=action,
chance_rate=chance_rate,
results_model=results_model,
)
self._event = event
self._events = events

def _check(
self,
update: Update,
) -> bool:
if update.message and update.message.text and self._event == "on_message":
return True
if update.message and update.message.new_chat_members and self._event == "on_new_chat_members":
return True
if update.message and update.message.new_chat_members and self._event == "on_left_chat_member":
return True
return False
) -> list[str] | None:
results = []
if update.message and update.message.text and "on_message" in self._events:
results.append("on_message")
if (
update.message
and update.message.new_chat_members
and "on_new_chat_members" in self._events
):
results.append("on_new_chat_members")
if (
update.message
and update.message.new_chat_members
and "on_left_chat_member" in self._events
):
results.append("on_left_chat_member")
if not results:
return None
return results
22 changes: 11 additions & 11 deletions chatushka/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ class ChatMemberOwner(
class ChatMemberAdministrator(
ChatMemberBase,
):
can_be_edited: bool | None
can_manage_chat: bool | None
can_delete_messages: bool | None
can_manage_voice_chats: bool | None
can_restrict_members: bool | None
can_promote_members: bool | None
can_change_info: bool | None
can_invite_users: bool | None
can_post_messages: bool | None
can_edit_messages: bool | None
can_pin_messages: bool | None
can_be_edited: bool | None = None
can_manage_chat: bool | None = None
can_delete_messages: bool | None = None
can_manage_voice_chats: bool | None = None
can_restrict_members: bool | None = None
can_promote_members: bool | None = None
can_change_info: bool | None = None
can_invite_users: bool | None = None
can_post_messages: bool | None = None
can_edit_messages: bool | None =None
can_pin_messages: bool | None = None


class Chat(
Expand Down
Loading

0 comments on commit 046155f

Please sign in to comment.