diff --git a/tgram/storage/__init__.py b/tgram/storage/__init__.py index 00a4741..e06638e 100644 --- a/tgram/storage/__init__.py +++ b/tgram/storage/__init__.py @@ -1,6 +1,7 @@ from .base import StorageBase from .kvsqlite_storage import KvsqliteStorage +from .redis_storage import RedisStorage from . import utils -__all__ = ["StorageBase", "KvsqliteStorage", "utils"] +__all__ = ["StorageBase", "KvsqliteStorage", "RedisStorage", "utils"] diff --git a/tgram/storage/base.py b/tgram/storage/base.py index ae678bf..df122e5 100644 --- a/tgram/storage/base.py +++ b/tgram/storage/base.py @@ -5,12 +5,17 @@ class StorageBase(ABC): - def __init__(self, bot: "tgram.TgBot") -> None: - from kvsqlite import Client + def __init__(self, bot: "tgram.TgBot", type: "str") -> None: + if type == "kvsqlite": + from kvsqlite import Client - self.client = Client( - "tgram-" + str(bot.me.id), workers=bot.workers, loop=bot.loop - ) + self.client = Client( + "tgram-" + str(bot.me.id), workers=bot.workers, loop=bot.loop + ) + elif type == "redis": + import redis.asyncio as redis + + self.client = redis.Redis(decode_responses=True) self.bot = bot @abstractmethod diff --git a/tgram/storage/kvsqlite_storage.py b/tgram/storage/kvsqlite_storage.py index 4f5dd91..da81c2d 100644 --- a/tgram/storage/kvsqlite_storage.py +++ b/tgram/storage/kvsqlite_storage.py @@ -6,7 +6,7 @@ class KvsqliteStorage(StorageBase): def __init__(self, bot: "tgram.TgBot") -> None: - super().__init__(bot) + super().__init__(bot, "kvsqlite") async def set(self, key: str, value: Any) -> bool: return await self.client.set(key, value) diff --git a/tgram/storage/redis_storage.py b/tgram/storage/redis_storage.py new file mode 100644 index 0000000..9e33cf3 --- /dev/null +++ b/tgram/storage/redis_storage.py @@ -0,0 +1,91 @@ +import tgram +import json +from .base import StorageBase + +from typing import Any, Dict, List, Tuple, Union + + +class RedisStorage(StorageBase): + def __init__(self, bot: "tgram.TgBot") -> None: + super().__init__(bot, "redis") + + async def set(self, key: str, value: Any) -> bool: + return await self.client.hset("tgram-" + str(self.bot.me.id), key, value) + + async def get(self, key: str) -> Any: + return await self.client.hget("tgram-" + str(self.bot.me.id), key) + + async def add_chat(self, chat: "tgram.types.Chat") -> bool: + chat_json = chat.json + chats = await self.get_chats() + if chat.username: + chats.update({chat.username.lower(): chat.id}) + chats.update({chat.id: chat_json}) + return await self.update_chats(chats) + + async def get_chat( + self, chat_id: Union[int, str], parse: bool = False + ) -> Union[dict, "tgram.types.Chat"]: + chats = await self.get_chats() + + if chat := chats.get(chat_id.lower() if isinstance(chat_id, str) else chat_id): + if isinstance(chat, int): + return await self.get_chat(chat, parse) + return tgram.types.Chat._parse(self.bot, chat) if parse else chat + + return {} + + async def get_chats(self) -> Dict[str, dict]: + return json.loads(await self.client.get("chats") or {}) + + async def update_chats(self, chats: Dict[str, dict]) -> bool: + return await self.client.set("chats", json.dumps(chats, ensure_ascii=False)) + + async def add_user(self, user: "tgram.types.User") -> bool: + user_json = user.json + users = await self.get_users() + if user.username: + users.update({user.username.lower(): user.id}) + users.update({user.id: user_json}) + return await self.update_users(users) + + async def get_user( + self, user_id: Union[int, str], parse: bool = False + ) -> Union[dict, "tgram.types.User"]: + users = await self.get_users() + + if user := users.get(user_id.lower() if isinstance(user_id, str) else user_id): + if isinstance(user, int): + return await self.get_user(user, parse) + return tgram.types.User._parse(self.bot, user) if parse else user + + return {} + + async def get_users(self) -> Dict[str, Dict]: + return json.loads(await self.client.get("users") or {}) + + async def update_users(self, users: Dict[str, dict]) -> bool: + return await self.client.set("users", json.dumps(users, ensure_ascii=False)) + + async def mute(self, chat_id: int, user_id: int) -> bool: + mute_list = await self.get_mute_list(True) + packet = [chat_id, user_id] + if packet in mute_list: + return False + mute_list.append(packet) + return await self.update_mute_list(mute_list) + + async def unmute(self, chat_id: int, user_id: int) -> bool: + mute_list = await self.get_mute_list(True) + packet = [chat_id, user_id] + if packet not in mute_list: + return False + mute_list.remove(packet) + return await self.update_mute_list(mute_list) + + async def get_mute_list(self, _: bool = False) -> List[Tuple[int, int]]: + x = json.loads(await self.get("mute")) or [] + return x if _ else [tuple(i) for i in x] + + async def update_mute_list(self, mute_list: List[Tuple[int, int]]) -> bool: + return await self.set("mute", json.dumps(mute_list, ensure_ascii=False)) diff --git a/tgram/tgbot.py b/tgram/tgbot.py index 1d10d08..71134ae 100644 --- a/tgram/tgbot.py +++ b/tgram/tgbot.py @@ -14,7 +14,7 @@ from .errors import APIException, MutedError from .utils import API_URL, get_file_name, ALL_UPDATES from .sync import wrap -from .storage import KvsqliteStorage +from .storage import KvsqliteStorage, RedisStorage from .storage.utils import check_update from .types.type_ import Type_ from concurrent.futures.thread import ThreadPoolExecutor @@ -254,7 +254,7 @@ def __init__( retry_after: Union[int, bool] = None, plugins: Union[Path, str] = None, skip_updates: bool = True, - storage: bool = False, + storage: Literal["kvsqlite", "redis"] = None, ) -> None: self.bot_token = bot_token self.api_url = api_url @@ -289,14 +289,30 @@ def __init__( self._api_url: str = f"{api_url}bot{bot_token}/" if storage: - try: - __import__("kvsqlite") - except ModuleNotFoundError: + if storage.lower() == "kvsqlite": + try: + __import__("kvsqlite") + except ModuleNotFoundError: + raise ValueError( + "Please install kvsqlite module before using storage, see more https://pypi.org/project/Kvsqlite/" + ) + else: + self.storage = KvsqliteStorage(self) + elif storage.lower() == "redis": + try: + __import__("redis") + except ModuleNotFoundError: + raise ValueError( + "Please install redis module before using storage, see more https://pypi.org/project/redis/" + ) + else: + self.storage = RedisStorage(self) + else: raise ValueError( - "Please install kvsqlite module before using storage, see more https://pypi.org/project/Kvsqlite/" + "Unsupported storage engine {}, only {} are supported for now.".format( + storage, " ,".join(i for i in ["redis", "kvsqlite"]) + ) ) - else: - self.storage = KvsqliteStorage(self) def add_handler(self, handler: "tgram.handlers.Handler", group: int = 0) -> None: if handler.type == "all":