Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Oct 19, 2023
1 parent 8e0114f commit 244aaaf
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 9 deletions.
11 changes: 9 additions & 2 deletions aiokeydb/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Type
from aiokeydb.types import BaseSerializer
from aiokeydb.serializers._json import JsonSerializer, OrJsonSerializer
from aiokeydb.serializers._pickle import PickleSerializer, DillSerializer
from aiokeydb.serializers._pickle import PickleSerializer, DillSerializer, DillSerializerv2, PickleSerializerv2
from aiokeydb.serializers._msgpack import MsgPackSerializer

class SerializerType(str, Enum):
Expand All @@ -18,6 +18,9 @@ class SerializerType(str, Enum):
msgpack = 'msgpack'
default = 'default'

picklev2 = 'picklev2'
dillv2 = 'dillv2'

def get_serializer(self) -> Type[BaseSerializer]:
"""
Default Serializer = Dill
Expand All @@ -31,10 +34,14 @@ def get_serializer(self) -> Type[BaseSerializer]:
return PickleSerializer
elif self == SerializerType.dill:
return DillSerializer
elif self == SerializerType.picklev2:
return PickleSerializerv2
elif self == SerializerType.dillv2:
return DillSerializerv2
elif self == SerializerType.msgpack:
return MsgPackSerializer
elif self == SerializerType.default:
return DillSerializer
return DillSerializerv2
else:
raise ValueError(f'Invalid serializer type: {self}')

Expand Down
48 changes: 47 additions & 1 deletion aiokeydb/serializers/_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import sys
import pickle
import typing
import binascii
import contextlib
from io import BytesIO
from aiokeydb.types.serializer import BaseSerializer
from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler


if sys.version_info.minor < 8:
Expand All @@ -18,6 +21,7 @@

try:
import dill
from dill import DEFAULT_PROTOCOL as DILL_DEFAULT_PROTOCOL
_dill_avail = True
except ImportError:
dill = object
Expand All @@ -38,8 +42,29 @@ def dumps(obj: typing.Any, protocol: int = DefaultProtocols.pickle, *args, **kwa
@staticmethod
def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any:
return pickle.loads(data, *args, **kwargs)


class PickleSerializerv2(BaseSerializer):

@staticmethod
def dumps(obj: typing.Any, protocol: int = DEFAULT_PROTOCOL, *args, **kwargs) -> str:
"""
v2 Encoding
"""
f = BytesIO()
p = Pickler(f, protocol = protocol)
p.dump(obj)
return f.getvalue().hex()

@staticmethod
def loads(data: typing.Union[str, typing.Any], *args, **kwargs) -> typing.Any:
"""
V2 Decoding
"""
return Unpickler(BytesIO(binascii.unhexlify(data))).load()

if _dill_avail:
from dill import Pickler as DillPickler, Unpickler as DillUnpickler

class DillSerializer(BaseSerializer):

@staticmethod
Expand All @@ -49,6 +74,27 @@ def dumps(obj: typing.Any, protocol: int = DefaultProtocols.dill, *args, **kwarg
@staticmethod
def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any:
return dill.loads(data, *args, **kwargs)

class DillSerializerv2(BaseSerializer):

@staticmethod
def dumps(obj: typing.Any, protocol: int = DILL_DEFAULT_PROTOCOL, *args, **kwargs) -> str:
"""
v2 Encoding
"""
f = BytesIO()
p = DillPickler(f, protocol = protocol)
p.dump(obj)
return f.getvalue().hex()

@staticmethod
def loads(data: typing.Union[str, typing.Any], *args, **kwargs) -> typing.Any:
"""
V2 Decoding
"""
return DillUnpickler(BytesIO(binascii.unhexlify(data))).load()

else:
DillSerializer = PickleSerializer
DillSerializerv2 = PickleSerializerv2

16 changes: 11 additions & 5 deletions aiokeydb/types/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def serialize(self, job: Job):
"""
Dumps a job.
"""
# logger.info(f"Serializing Job: {job}", prefix = f'{self.serializer.__name__}', colored = True)
try:
return self.serializer.dumps(job.to_dict())
except Exception as e:
Expand All @@ -604,6 +605,7 @@ def deserialize(self, job_bytes: bytes):
"""
if not job_bytes: return None
job_dict: typing.Dict = self.serializer.loads(job_bytes)
# logger.info(f"Deserialized Job: {job_dict}", prefix = f'{self.serializer.__name__}', colored = True)
assert (
job_dict.pop("queue") == self.queue_name
), f"Job {job_dict} fetched by wrong queue: {self.queue_name}"
Expand Down Expand Up @@ -855,7 +857,7 @@ async def update(self, job: Job):
Update a job.
"""
job.touched = now()
await self.ctx.async_set(job.id, self.serialize(job))
await self.ctx.async_client.set(job.id, self.serialize(job))
await self.notify(job)
if self.function_tracker_enabled:
await self.track_job_id(job)
Expand Down Expand Up @@ -2033,13 +2035,15 @@ async def sync_queue_info(self):
Syncs the current queue info with keydb
"""
queue_info = await self.info(jobs = True)
await self.ctx.async_set(self.queue_info_key, json.dumps(queue_info, cls = ObjectEncoder), ex = 60)
# await self.ctx.async_set(self.queue_info_key, json.dumps(queue_info, cls = ObjectEncoder), ex = 60)
await self.ctx.async_client.set(self.queue_info_key, json.dumps(queue_info, cls = ObjectEncoder), ex = 60)

async def fetch_queue_info(self):
"""
Fetches the current queue info from keydb
"""
queue_info = await self.ctx.async_get(self.queue_info_key)
# queue_info = await self.ctx.async_get(self.queue_info_key)
queue_info = await self.ctx.async_client.get(self.queue_info_key)
return json.loads(queue_info) if queue_info else {}

async def stats(self, ttl: int = 60):
Expand Down Expand Up @@ -2291,7 +2295,8 @@ async def track_job(self, job: Job):
async with self._fail_ok(verbose = False):
function_tracker = await self._get_function_tracker(job.function, none_ok = False)
function_tracker.track_job(job)
await self.ctx.async_set(f'{self._stats.function_tracker_key}.{job.function}', function_tracker.serialize(), ex = self.function_tracker_ttl)
# await self.ctx.async_set(f'{self._stats.function_tracker_key}.{job.function}', function_tracker.serialize(), ex = self.function_tracker_ttl)
await self.ctx.async_client.set(f'{self._stats.function_tracker_key}.{job.function}', function_tracker.serialize(), ex = self.function_tracker_ttl)
await self.track_job_id(job)

async def get_function_trackers(self) -> typing.Dict[str, FunctionTracker]:
Expand All @@ -2302,7 +2307,8 @@ async def get_function_trackers(self) -> typing.Dict[str, FunctionTracker]:
_keys = await self.ctx.async_keys(f'{self._stats.function_tracker_key}.*')
_function_trackers = {}
for key in _keys:
function_tracker = await self.ctx.async_get(key, default = None)
# function_tracker = await self.ctx.async_get(key, default = None)
function_tracker = await self.ctx.async_client.get(key)
if function_tracker:
function_tracker = FunctionTracker.deserialize(function_tracker)
_function_trackers[function_tracker.function] = function_tracker
Expand Down
2 changes: 1 addition & 1 deletion aiokeydb/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '0.2.0rc2'
VERSION = '0.2.0rc3'

0 comments on commit 244aaaf

Please sign in to comment.