diff --git a/gemd/json/gemd_encoder.py b/gemd/json/gemd_encoder.py index 631f54e5..f7160446 100644 --- a/gemd/json/gemd_encoder.py +++ b/gemd/json/gemd_encoder.py @@ -1,4 +1,5 @@ from json import JSONEncoder +from uuid import UUID from gemd.entity.dict_serializable import DictSerializable from gemd.enumeration.base_enumeration import BaseEnumeration @@ -13,5 +14,7 @@ def default(self, o): return o.as_dict() elif isinstance(o, BaseEnumeration): return o.value + elif isinstance(o, UUID): + return str(o) else: return JSONEncoder.default(self, o) diff --git a/gemd/json/gemd_json.py b/gemd/json/gemd_json.py index 35bdd2d1..f01276dc 100644 --- a/gemd/json/gemd_json.py +++ b/gemd/json/gemd_json.py @@ -1,5 +1,6 @@ import inspect from deprecation import deprecated +from typing import Dict, Any, Type from gemd.entity.dict_serializable import DictSerializable from gemd.entity.base_entity import BaseEntity @@ -21,7 +22,7 @@ class GEMDJson(object): def __init__(self, scope: str = 'auto'): self._scope = scope - self._clazz_index = DictSerializable.class_mapping + self._clazz_index = dict() @property def scope(self) -> str: @@ -74,8 +75,15 @@ def loads(self, json_str: str, **kwargs): # Create an index to hold the objects by their uid reference # so we can replace links with pointers index = {} + clazz_index = DictSerializable.class_mapping + clazz_index.update(self._clazz_index) raw = json_builtin.loads( - json_str, object_hook=lambda x: self._load_and_index(x, index, True), **kwargs) + json_str, + object_hook=lambda x: self._load_and_index(x, + index, + clazz_index=clazz_index, + substitute=True), + **kwargs) # the return value is in the 2nd position. return raw["object"] @@ -196,8 +204,12 @@ def raw_loads(self, json_str, **kwargs): # Create an index to hold the objects by their uid reference # so we can replace links with pointers index = {} + clazz_index = DictSerializable.class_mapping + clazz_index.update(self._clazz_index) return json_builtin.loads( - json_str, object_hook=lambda x: self._load_and_index(x, index), **kwargs) + json_str, + object_hook=lambda x: self._load_and_index(x, index, clazz_index=clazz_index), + **kwargs) @deprecated(deprecated_in="1.13.0", removed_in="2.0.0", details="Classes are now automatically registered when extending BaseEntity") @@ -229,7 +241,12 @@ def register_classes(self, classes): self._clazz_index.update(classes) - def _load_and_index(self, d, object_index, substitute=False): + @staticmethod + def _load_and_index( + d: Dict[str, Any], + object_index: Dict[str, DictSerializable], + clazz_index: Dict[str, Type], + substitute: bool = False) -> DictSerializable: """ Load the class based on the type string and index it, if a BaseEntity. @@ -254,10 +271,10 @@ def _load_and_index(self, d, object_index, substitute=False): return d typ = d.pop("type") - if typ not in self._clazz_index: + if typ not in clazz_index: raise TypeError("Unexpected base object type: {}".format(typ)) - clz = self._clazz_index[typ] + clz = clazz_index[typ] obj = clz.from_dict(d) if isinstance(obj, BaseEntity): # Add it to the object index diff --git a/gemd/json/tests/test_json.py b/gemd/json/tests/test_json.py index 96696f70..88902288 100644 --- a/gemd/json/tests/test_json.py +++ b/gemd/json/tests/test_json.py @@ -1,6 +1,7 @@ """Test serialization and deserialization of gemd objects.""" import json from copy import deepcopy +from uuid import uuid4 import pytest @@ -10,6 +11,7 @@ from gemd.entity.case_insensitive_dict import CaseInsensitiveDict from gemd.entity.attribute.condition import Condition from gemd.entity.attribute.parameter import Parameter +from gemd.entity.dict_serializable import DictSerializable from gemd.entity.link_by_uid import LinkByUID from gemd.entity.object import MeasurementRun, MaterialRun, ProcessRun from gemd.entity.object import MeasurementSpec, MaterialSpec, ProcessSpec @@ -51,7 +53,9 @@ def test_deserialize(): """Round-trip serde should leave the object unchanged.""" condition = Condition(name="A condition", value=NominalReal(7, '')) parameter = Parameter(name="A parameter", value=NormalReal(mean=17, std=1, units='')) - measurement = MeasurementRun("name", tags="A tag on a measurement", conditions=condition, + measurement = MeasurementRun("name", + tags="A tag on a measurement", + conditions=condition, parameters=parameter) copy_meas = GEMDJson().copy(measurement) assert(copy_meas.conditions[0].value == measurement.conditions[0].value) @@ -59,6 +63,17 @@ def test_deserialize(): assert(copy_meas.uids["auto"] == measurement.uids["auto"]) +def test_uuid_serde(): + """Any UUIDs in uids & LinkByUIDs shouldn't break stuff.""" + process = ProcessSpec(name="A process", uids={"uuid": uuid4(), "word": "turnbuckle"}) + copy_proc = GEMDJson().copy(process) + assert all(copy_proc.uids[scope] == str(process.uids.get(scope)) for scope in copy_proc.uids) + assert len(copy_proc.uids) == len(process.uids) + + link = LinkByUID(id=uuid4(), scope="mine") + assert GEMDJson().copy(link).id == str(link.id) + + def test_scope_control(): """Serializing a nested object should be identical to individually serializing each piece.""" input_material = MaterialSpec("input_material") @@ -232,7 +247,12 @@ def test_pure_substitutions(): ] ''' index = {} - original = json.loads(json_str, object_hook=lambda x: GEMDJson()._load_and_index(x, index)) + clazz_index = DictSerializable.class_mapping + original = json.loads(json_str, + object_hook=lambda x: GEMDJson()._load_and_index(x, + index, + clazz_index) + ) frozen = deepcopy(original) loaded = substitute_objects(original, index) assert original == frozen diff --git a/setup.py b/setup.py index 10377157..95f9e412 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup(name='gemd', - version='1.13.1', + version='1.13.2', url='http://github.com/CitrineInformatics/gemd-python', description="Python binding for Citrine's GEMD data model", author='Citrine Informatics',