Skip to content

Commit

Permalink
add redis storage
Browse files Browse the repository at this point in the history
  • Loading branch information
z44d committed Sep 20, 2024
1 parent 3f91d1e commit 9144245
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 15 deletions.
3 changes: 2 additions & 1 deletion tgram/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import StorageBase
from .kvsqlite_storage import KvsqliteStorage
from .redis_storage import RedisStorage

from . import utils

__all__ = ["StorageBase", "KvsqliteStorage", "utils"]
__all__ = ["StorageBase", "KvsqliteStorage", "RedisStorage", "utils"]
15 changes: 10 additions & 5 deletions tgram/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@


class StorageBase(ABC):
def __init__(self, bot: "tgram.TgBot") -> None:
from kvsqlite import Client
def __init__(self, bot: "tgram.TgBot", type: "str") -> None:
if type == "kvsqlite":
from kvsqlite import Client

self.client = Client(
"tgram-" + str(bot.me.id), workers=bot.workers, loop=bot.loop
)
self.client = Client(
"tgram-" + str(bot.me.id), workers=bot.workers, loop=bot.loop
)
elif type == "redis":
import redis.asyncio as redis

self.client = redis.Redis(decode_responses=True)
self.bot = bot

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion tgram/storage/kvsqlite_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class KvsqliteStorage(StorageBase):
def __init__(self, bot: "tgram.TgBot") -> None:
super().__init__(bot)
super().__init__(bot, "kvsqlite")

async def set(self, key: str, value: Any) -> bool:
return await self.client.set(key, value)
Expand Down
91 changes: 91 additions & 0 deletions tgram/storage/redis_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import tgram
import json
from .base import StorageBase

from typing import Any, Dict, List, Tuple, Union


class RedisStorage(StorageBase):
def __init__(self, bot: "tgram.TgBot") -> None:
super().__init__(bot, "redis")

async def set(self, key: str, value: Any) -> bool:
return await self.client.hset("tgram-" + str(self.bot.me.id), key, value)

async def get(self, key: str) -> Any:
return await self.client.hget("tgram-" + str(self.bot.me.id), key)

async def add_chat(self, chat: "tgram.types.Chat") -> bool:
chat_json = chat.json
chats = await self.get_chats()
if chat.username:
chats.update({chat.username.lower(): chat.id})
chats.update({chat.id: chat_json})
return await self.update_chats(chats)

async def get_chat(
self, chat_id: Union[int, str], parse: bool = False
) -> Union[dict, "tgram.types.Chat"]:
chats = await self.get_chats()

if chat := chats.get(chat_id.lower() if isinstance(chat_id, str) else chat_id):
if isinstance(chat, int):
return await self.get_chat(chat, parse)
return tgram.types.Chat._parse(self.bot, chat) if parse else chat

return {}

async def get_chats(self) -> Dict[str, dict]:
return json.loads(await self.client.get("chats") or {})

async def update_chats(self, chats: Dict[str, dict]) -> bool:
return await self.client.set("chats", json.dumps(chats, ensure_ascii=False))

async def add_user(self, user: "tgram.types.User") -> bool:
user_json = user.json
users = await self.get_users()
if user.username:
users.update({user.username.lower(): user.id})
users.update({user.id: user_json})
return await self.update_users(users)

async def get_user(
self, user_id: Union[int, str], parse: bool = False
) -> Union[dict, "tgram.types.User"]:
users = await self.get_users()

if user := users.get(user_id.lower() if isinstance(user_id, str) else user_id):
if isinstance(user, int):
return await self.get_user(user, parse)
return tgram.types.User._parse(self.bot, user) if parse else user

return {}

async def get_users(self) -> Dict[str, Dict]:
return json.loads(await self.client.get("users") or {})

async def update_users(self, users: Dict[str, dict]) -> bool:
return await self.client.set("users", json.dumps(users, ensure_ascii=False))

async def mute(self, chat_id: int, user_id: int) -> bool:
mute_list = await self.get_mute_list(True)
packet = [chat_id, user_id]
if packet in mute_list:
return False
mute_list.append(packet)
return await self.update_mute_list(mute_list)

async def unmute(self, chat_id: int, user_id: int) -> bool:
mute_list = await self.get_mute_list(True)
packet = [chat_id, user_id]
if packet not in mute_list:
return False
mute_list.remove(packet)
return await self.update_mute_list(mute_list)

async def get_mute_list(self, _: bool = False) -> List[Tuple[int, int]]:
x = json.loads(await self.get("mute")) or []
return x if _ else [tuple(i) for i in x]

async def update_mute_list(self, mute_list: List[Tuple[int, int]]) -> bool:
return await self.set("mute", json.dumps(mute_list, ensure_ascii=False))
32 changes: 24 additions & 8 deletions tgram/tgbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .errors import APIException, MutedError
from .utils import API_URL, get_file_name, ALL_UPDATES
from .sync import wrap
from .storage import KvsqliteStorage
from .storage import KvsqliteStorage, RedisStorage
from .storage.utils import check_update
from .types.type_ import Type_
from concurrent.futures.thread import ThreadPoolExecutor
Expand Down Expand Up @@ -254,7 +254,7 @@ def __init__(
retry_after: Union[int, bool] = None,
plugins: Union[Path, str] = None,
skip_updates: bool = True,
storage: bool = False,
storage: Literal["kvsqlite", "redis"] = None,
) -> None:
self.bot_token = bot_token
self.api_url = api_url
Expand Down Expand Up @@ -289,14 +289,30 @@ def __init__(
self._api_url: str = f"{api_url}bot{bot_token}/"

if storage:
try:
__import__("kvsqlite")
except ModuleNotFoundError:
if storage.lower() == "kvsqlite":
try:
__import__("kvsqlite")
except ModuleNotFoundError:
raise ValueError(
"Please install kvsqlite module before using storage, see more https://pypi.org/project/Kvsqlite/"
)
else:
self.storage = KvsqliteStorage(self)
elif storage.lower() == "redis":
try:
__import__("redis")
except ModuleNotFoundError:
raise ValueError(
"Please install redis module before using storage, see more https://pypi.org/project/redis/"
)
else:
self.storage = RedisStorage(self)
else:
raise ValueError(
"Please install kvsqlite module before using storage, see more https://pypi.org/project/Kvsqlite/"
"Unsupported storage engine {}, only {} are supported for now.".format(
storage, " ,".join(i for i in ["redis", "kvsqlite"])
)
)
else:
self.storage = KvsqliteStorage(self)

def add_handler(self, handler: "tgram.handlers.Handler", group: int = 0) -> None:
if handler.type == "all":
Expand Down

0 comments on commit 9144245

Please sign in to comment.