-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.py
79 lines (58 loc) · 2.35 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import logging
from fastapi import WebSocket
from utils.redis import redis_connection
from utils.crypt import jwt_decode
from utils.types import ChannelNames
from utils.json import is_valid_json
from aioredis.pubsub import Receiver
from json import dumps
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError
import asyncio
import aioredis
import os
class Client:
    """Bridge one WebSocket connection to Redis pub/sub.

    Text frames of the form ``subscribe<sep><channel>`` add a Redis channel
    subscription. Messages arriving on subscribed channels are forwarded to the
    websocket as UTF-8 text. Any other frame that parses as a JSON object with a
    string ``channel`` key is re-serialised and published to that channel.
    """

    def __init__(self, websocket: WebSocket):
        self._websocket = websocket
        # Redis pool opened by ``authorize``; stays None until a token is accepted.
        self._redis = None
        # Channel names requested so far, in request order, without duplicates.
        self._channels: ChannelNames = []
        # One receiver shared by every subscribed channel.
        self._receiver = Receiver()
        # Background task forwarding receiver messages to the websocket (started once).
        self._reader_task = None

    async def authorize(self, token: str) -> bool:
        """Validate *token* and, on success, open the Redis pool used for subscriptions.

        :param token: JWT presented by the client.
        :return: True if the token was accepted; False if its signature was
            invalid or it had expired.
        """
        try:
            payload = jwt_decode(token)
        except ExpiredSignatureError:
            # Never log the token itself: it is a bearer credential.
            logging.warning('Invalid token: signature has expired.')
            return False
        except InvalidSignatureError:
            logging.warning('Invalid token: signature verification failed.')
            return False

        logging.info('Client authenticated. User: %s', payload['iss'])
        # A repeated authorize must not orphan an already-open pool.
        if self._redis is None:
            self._redis = await aioredis.create_redis_pool((os.environ['REDIS_HOST'], 6379))
        return True

    async def handle_message(self, message: str):
        """Dispatch one text frame received from the websocket.

        :param message: ``subscribe<sep><channel>`` to subscribe, otherwise a
            JSON document to publish to the channel named in its ``channel`` key.
        """
        if message.startswith('subscribe'):
            # Protocol: the 9-char keyword, one separator character, then the name.
            channel_name = message[10:]
            if not channel_name:
                logging.warning('Ignoring subscribe request without a channel name')
                return
            if channel_name in self._channels:
                return
            self._channels.append(channel_name)
            await self.subscribe()
            # A subscribe command is never a publish payload.
            return

        # NOTE(review): assumes is_valid_json returns the parsed value or a falsy value.
        if not (data := is_valid_json(message)) or not isinstance(data, dict):
            return
        channel_name = data.get('channel')
        if not isinstance(channel_name, str) or not channel_name:
            logging.warning('Ignoring message without a valid "channel" field')
            return
        await self.publish(channel_name, dumps(data))

    async def publish(self, channel_name: str, data: str):
        """Publish *data* to the Redis channel *channel_name*.

        Uses the pool returned by ``redis_connection()``, not this client's own
        pool, so it does not depend on ``authorize`` having succeeded.
        """
        pool = await redis_connection()
        with await pool as conn:
            await conn.publish(channel_name, data)

    async def subscribe(self):
        """Subscribe to every tracked channel and forward its messages to the websocket.

        Does nothing (with a warning) if ``authorize`` has not succeeded, since the
        subscription pool is created there. The forwarding task is started at most
        once and shared by all channels.
        :return:
        """
        if self._redis is None:
            logging.warning('Cannot subscribe before the client is authorized')
            return
        if not self._channels:
            return

        async def reader():
            async for _channel, message in self._receiver.iter():
                logging.info(message)
                await self._websocket.send_text(message.decode('utf-8'))

        # Start one consumer only; restart it if it has finished or failed.
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.ensure_future(reader())
        await self._redis.subscribe(
            *[self._receiver.channel(channel) for channel in self._channels]
        )
        logging.info('Subscribed channels: %s', self._channels)

    async def unsubscribe(self):
        """Stop forwarding, unsubscribe from all channels and close the Redis pool.

        Safe to call when ``authorize`` never succeeded and safe to call twice.
        """
        if self._reader_task is not None:
            self._reader_task.cancel()
            self._reader_task = None

        redis, self._redis = self._redis, None
        if redis is None:
            return
        try:
            # The pool's unsubscribe needs at least one channel name.
            if self._channels:
                await redis.unsubscribe(*self._channels)
        finally:
            # Close the pool even if unsubscribing failed (e.g. broken connection).
            redis.close()
            await redis.wait_closed()
        logging.info('Unsubscribed')