def generate_random_data(shape, lazy=False, seed=142):
    """Build a random 2D ``xr.DataArray`` with dims ``("sample", "feature")``.

    Parameters
    ----------
    shape : tuple of int
        ``(n_samples, n_features)``; also used to build the integer
        ``sample`` and ``feature`` coordinates via ``np.arange``.
    lazy : bool, default False
        If True, the values are wrapped in a dask array with ``(5, 5)``
        chunks; otherwise a plain NumPy array backs the result.
    seed : int, default 142
        Seed for ``np.random.default_rng``. It is honoured in both the eager
        and the lazy mode, so the same seed always yields the same values.

    Returns
    -------
    xr.DataArray
        Uniform random values in ``[0, 1)`` of the requested shape.
    """
    rng = np.random.default_rng(seed)
    values = rng.random(shape)
    if lazy:
        # Chunk the seeded NumPy draw instead of calling ``da.random.random``:
        # that call never receives ``seed``, so lazy data would differ between
        # runs even though a seed was supplied.
        values = da.from_array(values, chunks=(5, 5))
    return xr.DataArray(
        values,
        dims=["sample", "feature"],
        coords={"sample": np.arange(shape[0]), "feature": np.arange(shape[1])},
    )
@pytest.mark.parametrize(
    "dim",
    [
        (("time",)),
        (("lat", "lon")),
        (("lon", "lat")),
    ],
)
# Only the zarr engine is exercised: netCDF4 currently cannot store complex
# numbers, so that engine is left out of the parametrization.
@pytest.mark.parametrize("engine", ["zarr"])
def test_save_load_with_data(tmp_path, engine, mca_model):
    """Round-trip a fitted HilbertMCARotator through save/load with its data.

    The model is saved with ``save_data=True``, reloaded, and then compared
    against the in-memory model: parameters, every entry of ``data``, the
    squared covariance fraction, and the amplitude and phase of the components.

    NOTE(review): ``dim`` is not a direct argument of this test; presumably it
    is consumed by the ``mca_model`` fixture -- confirm against the fixture.
    """
    store = tmp_path / "mca"

    fitted = HilbertMCARotator(n_modes=2)
    fitted.fit(mca_model)

    # Persist the rotator together with its data and make sure it hit disk.
    fitted.save(store, engine=engine, save_data=True)
    assert store.exists()

    # Rebuild the rotator from what was written.
    restored = HilbertMCARotator.load(store, engine=engine)

    # Parameters and every stored data entry must survive the round trip.
    assert fitted.get_params() == restored.get_params()
    assert all(name in restored.data for name in fitted.data)
    for name in fitted.data:
        assert restored.data[name].equals(fitted.data[name])

    # The restored model must reproduce the squared covariance fraction.
    scf_fitted = fitted.squared_covariance_fraction()
    scf_restored = restored.squared_covariance_fraction()
    assert np.allclose(scf_fitted, scf_restored)

    # ... and the amplitude of the components of both fields.
    amp_x_fitted, amp_y_fitted = fitted.components_amplitude()
    amp_x_restored, amp_y_restored = restored.components_amplitude()
    assert np.allclose(amp_x_fitted, amp_x_restored)
    assert np.allclose(amp_y_fitted, amp_y_restored)

    # ... and the phase of the components of both fields.
    phase_x_fitted, phase_y_fitted = fitted.components_phase()
    phase_x_restored, phase_y_restored = restored.components_phase()
    assert np.allclose(phase_x_fitted, phase_x_restored)
    assert np.allclose(phase_y_fitted, phase_y_restored)
Px, keep_attrs=True) - Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True) + Px = xr.apply_ufunc(np.angle, Px, keep_attrs=True, dask="allowed") + Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True, dask="allowed") Px.name = "components_phase_X" Py.name = "components_phase_Y" @@ -1288,8 +1288,8 @@ def scores_phase(self, normalized=False) -> tuple[DataArray, DataArray]: Rx = self.whitener1.inverse_transform_scores(Rx) Ry = self.whitener2.inverse_transform_scores(Ry) - Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True) - Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True) + Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True, dask="allowed") + Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True, dask="allowed") Rx.name = "scores_phase_X" Ry.name = "scores_phase_Y" diff --git a/xeofs/cross/cpcca_rotator.py b/xeofs/cross/cpcca_rotator.py index 162157c..1e2fd0d 100644 --- a/xeofs/cross/cpcca_rotator.py +++ b/xeofs/cross/cpcca_rotator.py @@ -111,6 +111,8 @@ def get_serialization_attrs(self) -> dict: whitener2=self.whitener2, model=self.model, sorted=self.sorted, + sample_name=self.sample_name, + feature_name=self.feature_name, ) def _fit_algorithm(self, model) -> Self: