diff --git a/src/omnipy/compute/mixins/serialize.py b/src/omnipy/compute/mixins/serialize.py index 5ae5ac7c..460a4a09 100644 --- a/src/omnipy/compute/mixins/serialize.py +++ b/src/omnipy/compute/mixins/serialize.py @@ -15,27 +15,16 @@ from omnipy.compute.mixins.name import NameJobBaseMixin from omnipy.config.job import JobConfig from omnipy.data.dataset import Dataset -from omnipy.data.serializer import SerializerRegistry -from omnipy.modules import register_serializers +from omnipy.modules import get_serializer_registry PersistOpts = PersistOutputsOptions RestoreOpts = RestoreOutputsOptions ProtocolOpts = OutputStorageProtocolOptions -def _setup_serializer_registry() -> IsSerializerRegistry: - from omnipy.hub.runtime import runtime - if runtime is not None: - return runtime.objects.serializers - else: - registry = SerializerRegistry() - register_serializers(registry) - return registry - - class SerializerFuncJobBaseMixin: - _serializer_registry: IsSerializerRegistry = _setup_serializer_registry() + _serializer_registry: IsSerializerRegistry = get_serializer_registry() def __init__(self, *, diff --git a/src/omnipy/data/dataset.py b/src/omnipy/data/dataset.py index b0907dc6..db840e73 100644 --- a/src/omnipy/data/dataset.py +++ b/src/omnipy/data/dataset.py @@ -398,14 +398,8 @@ def load(self, tar_gz_file_path: str): @staticmethod def _get_serializer_registry(): - from omnipy.data.serializer import SerializerRegistry - from omnipy.hub.runtime import runtime - from omnipy.modules import register_serializers - serializer_registry = SerializerRegistry() if runtime is None else \ - runtime.objects.serializers - if len(serializer_registry.serializers) == 0: - register_serializers(serializer_registry) - return serializer_registry + from omnipy.modules import get_serializer_registry + return get_serializer_registry def as_multi_model_dataset(self) -> 'MultiModelDataset[ModelT]': multi_model_dataset = MultiModelDataset[self.get_model_class()]() diff --git a/src/omnipy/modules/__init__.py b/src/omnipy/modules/__init__.py index 94d3a7c0..4e2cf881 100644 --- a/src/omnipy/modules/__init__.py +++ b/src/omnipy/modules/__init__.py @@ -13,5 +13,16 @@ def register_serializers(registry: IsSerializerRegistry): registry.register(PandasDatasetToTarFileSerializer) +def get_serializer_registry(): + from omnipy.data.serializer import SerializerRegistry + from omnipy.hub.runtime import runtime + + serializer_registry = SerializerRegistry() if runtime is None else \ + runtime.objects.serializers + if len(serializer_registry.serializers) == 0: + register_serializers(serializer_registry) + return serializer_registry + + # TODO: Add module with helper classes/functions/takss to make it simpler to contact REST apis # Augmentation service should have some attempts at this.