From 5ed9753bc7d8ce349ac75f9043e61c2f6a20677e Mon Sep 17 00:00:00 2001 From: Sam Levang <39069044+slevang@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:40:27 -0400 Subject: [PATCH] fix: xarray 2024.09.0 compatility (#227) --- xeofs/base_model.py | 2 +- xeofs/data_container/data_container.py | 2 +- xeofs/preprocessing/preprocessor.py | 1 - xeofs/preprocessing/transformer.py | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xeofs/base_model.py b/xeofs/base_model.py index 7f13899..16aae46 100644 --- a/xeofs/base_model.py +++ b/xeofs/base_model.py @@ -90,7 +90,7 @@ def serialize(self) -> DataTree: """Serialize a complete model with its preprocessor.""" # Create a root node for this object with its params as attrs ds_root = xr.Dataset(attrs=dict(params=self.get_params())) - dt = DataTree(data=ds_root, name=type(self).__name__) + dt = DataTree(ds_root, name=type(self).__name__) # Retrieve the tree representation of each attached object, or set basic attrs for key, attr in self.get_serialization_attrs().items(): diff --git a/xeofs/data_container/data_container.py b/xeofs/data_container/data_container.py index 9b8f948..2729e02 100644 --- a/xeofs/data_container/data_container.py +++ b/xeofs/data_container/data_container.py @@ -36,7 +36,7 @@ def serialize(self) -> DataTree: for key, data in self.items(): if not data.name: data.name = key - dt[key] = DataTree(data) + dt[key] = DataTree(data.to_dataset()) dt[key].attrs = {key: "_is_node", "allow_compute": self._allow_compute[key]} return dt diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index d258323..9d8a9d5 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -391,7 +391,6 @@ def serialize(self) -> DataTree: dt_transformer = transformer_obj.serialize() # Place the serialized transformer in the tree dt[name] = dt_transformer - dt[name].parent = dt return dt diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py index 6c7935a..6666428 100644 --- a/xeofs/preprocessing/transformer.py +++ b/xeofs/preprocessing/transformer.py @@ -140,7 +140,7 @@ def _serialize(self) -> DataTree: if isinstance(attr, (xr.DataArray, xr.Dataset)): # attach data to data_vars or coords ds = self._serialize_data(key, attr) - dt[key] = DataTree(name=key, data=ds) + dt[key] = DataTree(ds, name=key) dt.attrs[key] = "_is_node" elif isinstance(attr, dict) and any( [isinstance(val, xr.DataArray) for val in attr.values()] @@ -149,7 +149,7 @@ def _serialize(self) -> DataTree: dt_attr = DataTree() for k, v in attr.items(): ds = self._serialize_data(k, v) - dt_attr[k] = DataTree(name=k, data=ds) + dt_attr[k] = DataTree(ds, name=k) dt[key] = dt_attr dt.attrs[key] = "_is_tree" else: