Skip to content

Commit

Permalink
chore: merge branch 'develop' into fix-mca-load
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie committed Sep 12, 2024
2 parents 6703bb6 + 5ed9753 commit 4c120c5
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion xeofs/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def serialize(self) -> DataTree:
"""Serialize a complete model with its preprocessor."""
# Create a root node for this object with its params as attrs
ds_root = xr.Dataset(attrs=dict(params=self.get_params()))
dt = DataTree(data=ds_root, name=type(self).__name__)
dt = DataTree(ds_root, name=type(self).__name__)

# Retrieve the tree representation of each attached object, or set basic attrs
for key, attr in self.get_serialization_attrs().items():
Expand Down
2 changes: 1 addition & 1 deletion xeofs/data_container/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def serialize(self) -> DataTree:
for key, data in self.items():
if not data.name:
data.name = key
dt[key] = DataTree(data)
dt[key] = DataTree(data.to_dataset())
dt[key].attrs = {key: "_is_node", "allow_compute": self._allow_compute[key]}

return dt
Expand Down
1 change: 0 additions & 1 deletion xeofs/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ def serialize(self) -> DataTree:
dt_transformer = transformer_obj.serialize()
# Place the serialized transformer in the tree
dt[name] = dt_transformer
dt[name].parent = dt

return dt

Expand Down
4 changes: 2 additions & 2 deletions xeofs/preprocessing/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _serialize(self) -> DataTree:
if isinstance(attr, (xr.DataArray, xr.Dataset)):
# attach data to data_vars or coords
ds = self._serialize_data(key, attr)
dt[key] = DataTree(name=key, data=ds)
dt[key] = DataTree(ds, name=key)
dt.attrs[key] = "_is_node"
elif isinstance(attr, dict) and any(
[isinstance(val, xr.DataArray) for val in attr.values()]
Expand All @@ -149,7 +149,7 @@ def _serialize(self) -> DataTree:
dt_attr = DataTree()
for k, v in attr.items():
ds = self._serialize_data(k, v)
dt_attr[k] = DataTree(name=k, data=ds)
dt_attr[k] = DataTree(ds, name=k)
dt[key] = dt_attr
dt.attrs[key] = "_is_tree"
else:
Expand Down

0 comments on commit 4c120c5

Please sign in to comment.