Skip to content

Commit

Permalink
add cloudpickle support
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Oct 20, 2023
1 parent 52d896c commit 43b775c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 5 deletions.
15 changes: 11 additions & 4 deletions aiokeydb/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import absolute_import

from enum import Enum
from typing import Type
from typing import Type, TypeVar
from aiokeydb.types import BaseSerializer
from aiokeydb.serializers._json import JsonSerializer, OrJsonSerializer
from aiokeydb.serializers._pickle import PickleSerializer, DillSerializer, DillSerializerv2, PickleSerializerv2
from aiokeydb.serializers._pickle import PickleSerializer, DillSerializer, DillSerializerv2, PickleSerializerv2, CloudPickleSerializer
from aiokeydb.serializers._msgpack import MsgPackSerializer


SerializerT = TypeVar('SerializerT', bound = Type[BaseSerializer])

class SerializerType(str, Enum):
"""
Enum for the available serializers
Expand All @@ -21,7 +24,9 @@ class SerializerType(str, Enum):
picklev2 = 'picklev2'
dillv2 = 'dillv2'

def get_serializer(self) -> Type[BaseSerializer]:
cloudpickle = 'cloudpickle'

def get_serializer(self) -> SerializerT:
"""
Default Serializer = Dill
"""
Expand All @@ -38,10 +43,12 @@ def get_serializer(self) -> Type[BaseSerializer]:
return PickleSerializerv2
elif self == SerializerType.dillv2:
return DillSerializerv2
elif self == SerializerType.cloudpickle:
return CloudPickleSerializer
elif self == SerializerType.msgpack:
return MsgPackSerializer
elif self == SerializerType.default:
return DillSerializerv2
return DillSerializer
else:
raise ValueError(f'Invalid serializer type: {self}')

Expand Down
37 changes: 37 additions & 0 deletions aiokeydb/serializers/_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,40 @@ def loads(data: typing.Union[str, typing.Any], *args, **kwargs) -> typing.Any:
DillSerializer = PickleSerializer
DillSerializerv2 = PickleSerializerv2

try:
import cloudpickle
from types import ModuleType

class CloudPickleSerializer(BaseSerializer):

@staticmethod
def dumps(obj: typing.Any, protocol: int = cloudpickle.DEFAULT_PROTOCOL, *args, **kwargs) -> bytes:
"""
Dumps an object to bytes
"""
return cloudpickle.dumps(obj, protocol = protocol, *args, **kwargs)

@staticmethod
def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any:
"""
Loads an object from bytes
"""
return cloudpickle.loads(data, *args, **kwargs)

@staticmethod
def register_module(module: ModuleType):
"""
Registers a module with cloudpickle
"""
cloudpickle.register_pickle_by_value(module)

@staticmethod
def unregister_module(module: ModuleType):
"""
Registers a class with cloudpickle
"""
cloudpickle.unregister_pickle_by_value(module)


except ImportError:
CloudPickleSerializer = PickleSerializer
15 changes: 15 additions & 0 deletions aiokeydb/types/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import typing
from types import ModuleType

class BaseSerializer:

Expand All @@ -21,3 +22,17 @@ def loads(data: typing.Union[str, bytes, typing.Any], **kwargs) -> typing.Any:
"""
raise NotImplementedError


@staticmethod
def register_module(module: ModuleType):
"""
Dummy method that should be overridden by serializers that support
"""
return

@staticmethod
def unregister_module(module: ModuleType):
"""
Dummy method that should be overridden by serializers that support
"""
return
2 changes: 1 addition & 1 deletion aiokeydb/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '0.2.0rc6'
VERSION = '0.2.0rc7'

0 comments on commit 43b775c

Please sign in to comment.