Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add POP analysis #215

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions docs/api_reference/_autosummary/xeofs.single.POP.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
POP
===

.. currentmodule:: xeofs.single

.. autoclass:: POP
:members:
:inherited-members:


.. automethod:: __init__


.. rubric:: Methods

.. autosummary::

~POP.__init__
~POP.components
~POP.components_amplitude
~POP.components_phase
~POP.compute
~POP.damping_times
~POP.deserialize
~POP.eigenvalues
~POP.fit
~POP.fit_transform
~POP.get_params
~POP.get_serialization_attrs
~POP.inverse_transform
~POP.load
~POP.periods
~POP.save
~POP.scores
~POP.scores_amplitude
~POP.scores_phase
~POP.serialize
~POP.transform






1 change: 1 addition & 0 deletions docs/api_reference/single_set_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Methods that investigate relationships or patterns between variables within a si
~xeofs.single.ComplexEOF
~xeofs.single.HilbertEOF
~xeofs.single.ExtendedEOF
~xeofs.single.POP
~xeofs.single.OPA
~xeofs.single.GWPCA
~xeofs.single.SparsePCA
Expand Down
Binary file modified docs/auto_examples/auto_examples_jupyter.zip
Binary file not shown.
Binary file modified docs/auto_examples/auto_examples_python.zip
Binary file not shown.
219 changes: 219 additions & 0 deletions tests/models/single/test_pop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import numpy as np
import pytest
import xarray as xr

from xeofs.single import POP


def test_init():
"""Tests the initialization of the POP class"""
pop = POP(n_modes=5, standardize=True, use_coslat=True)

# Assert preprocessor has been initialized
assert hasattr(pop, "_params")
assert hasattr(pop, "preprocessor")
assert hasattr(pop, "whitener")


def test_fit(mock_data_array):
pop = POP()
pop.fit(mock_data_array, "time")


def test_eigenvalues(mock_data_array):
pop = POP()
pop.fit(mock_data_array, "time")

eigenvalues = pop.eigenvalues()
assert isinstance(eigenvalues, xr.DataArray)


def test_damping_times(mock_data_array):
pop = POP()
pop.fit(mock_data_array, "time")

times = pop.damping_times()
assert isinstance(times, xr.DataArray)


def test_periods(mock_data_array):
pop = POP()
pop.fit(mock_data_array, "time")

periods = pop.periods()
assert isinstance(periods, xr.DataArray)


def test_components(mock_data_array):
"""Tests the components method of the POP class"""
sample_dim = ("time",)
pop = POP()
pop.fit(mock_data_array, sample_dim)

# Test components method
components = pop.components()
feature_dims = tuple(set(mock_data_array.dims) - set(sample_dim))
assert isinstance(components, xr.DataArray), "Components is not a DataArray"
assert set(components.dims) == set(
("mode",) + feature_dims
), "Components does not have the right feature dimensions"


def test_scores(mock_data_array):
"""Tests the scores method of the POP class"""
sample_dim = ("time",)
pop = POP()
pop.fit(mock_data_array, sample_dim)

# Test scores method
scores = pop.scores()
assert isinstance(scores, xr.DataArray), "Scores is not a DataArray"
assert set(scores.dims) == set(
(sample_dim + ("mode",))
), "Scores does not have the right dimensions"


def test_transform(mock_data_array):
"""Test projecting new unseen data onto the POPs"""
dim = ("time",)
data = mock_data_array.isel({dim[0]: slice(1, None)})
new_data = mock_data_array.isel({dim[0]: slice(0, 1)})

# Create a xarray DataArray with random data
model = POP(n_modes=2, solver="full")
model.fit(data, dim)
scores = model.scores()

# Project data onto the components
projections = model.transform(data)

# Check that the projection has the right dimensions
assert projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore

# Check that the projection has the right data type
assert isinstance(projections, xr.DataArray), "Projection is not a DataArray"

# Check that the projection has the right name
assert projections.name == "scores", "Projection has wrong name: {}".format(
projections.name
)

# Check that the projection's data is the same as the scores
np.testing.assert_allclose(
scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3
)

# Project unseen data onto the components
new_projections = model.transform(new_data)

# Check that the projection has the right dimensions
assert new_projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore

# Check that the projection has the right data type
assert isinstance(new_projections, xr.DataArray), "Projection is not a DataArray"

# Check that the projection has the right name
assert new_projections.name == "scores", "Projection has wrong name: {}".format(
new_projections.name
)

# Ensure that the new projections are not NaNs
assert np.all(new_projections.notnull().values), "New projections contain NaNs"


def test_inverse_transform(mock_data_array):
"""Test inverse_transform method in POP class."""

dim = ("time",)
# instantiate the POP class with necessary parameters
pop = POP(n_modes=20, standardize=True)

# fit the POP model
pop.fit(mock_data_array, dim=dim)
scores = pop.scores()

# Test with single mode
scores_selection = scores.sel(mode=1)
X_rec_1 = pop.inverse_transform(scores_selection)
assert isinstance(X_rec_1, xr.DataArray)

# Test with single mode as list
scores_selection = scores.sel(mode=[1])
X_rec_1_list = pop.inverse_transform(scores_selection)
assert isinstance(X_rec_1_list, xr.DataArray)

# Single mode and list should be equal
xr.testing.assert_allclose(X_rec_1, X_rec_1_list)

# Test with all modes
X_rec = pop.inverse_transform(scores)
assert isinstance(X_rec, xr.DataArray)

# Check that the reconstructed data has the same dimensions as the original data
assert set(X_rec.dims) == set(mock_data_array.dims)


@pytest.mark.parametrize("engine", ["zarr"])
def test_save_load(mock_data_array, tmp_path, engine):
"""Test save/load methods in POP class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
# NOTE: netcdf4 does not support complex data types, so we use only zarr here
dim = "time"
original = POP()
original.fit(mock_data_array, dim)

# Save the POP model
original.save(tmp_path / "pop", engine=engine)

# Check that the POP model has been saved
assert (tmp_path / "pop").exists()

# Recreate the model from saved file
loaded = POP.load(tmp_path / "pop", engine=engine)

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
for key in original.data:
if original.data._allow_compute[key]:
assert loaded.data[key].equals(original.data[key])
else:
# but ensure that input data is not saved by default
assert loaded.data[key].size <= 1
assert loaded.data[key].attrs["placeholder"] is True

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.transform(mock_data_array), loaded.transform(mock_data_array)
)

# The loaded model should also be able to inverse_transform new data
assert np.allclose(
original.inverse_transform(original.scores()),
loaded.inverse_transform(loaded.scores()),
)


def test_serialize_deserialize_dataarray(mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
dim = "time"
model = POP()
model.fit(mock_data_array, dim)
dt = model.serialize()
rebuilt_model = POP.deserialize(dt)
assert np.allclose(
model.transform(mock_data_array), rebuilt_model.transform(mock_data_array)
)


def test_serialize_deserialize_dataset(mock_dataset):
"""Test roundtrip serialization when the model is fit on a Dataset."""
dim = "time"
model = POP()
model.fit(mock_dataset, dim)
dt = model.serialize()
rebuilt_model = POP.deserialize(dt)
assert np.allclose(
model.transform(mock_dataset), rebuilt_model.transform(mock_dataset)
)
1 change: 1 addition & 0 deletions xeofs/cross/base_model_cross_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def fit(

if self.get_params()["compute"]:
self.data.compute()
self._post_compute()

return self

Expand Down
3 changes: 3 additions & 0 deletions xeofs/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils import total_variance

__all__ = ["total_variance"]
14 changes: 13 additions & 1 deletion xeofs/linalg/_numpy/_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
solver: str = "auto",
random_state: np.random.Generator | int | None = None,
solver_kwargs: dict = {},
is_complex: bool | str = "auto",
):
sanity_check_n_modes(n_modes)
self.is_based_on_variance = True if isinstance(n_modes, float) else False
Expand All @@ -83,6 +84,7 @@ def __init__(
self.solver = solver
self.random_state = random_state
self.solver_kwargs = solver_kwargs
self.is_complex = is_complex

def _get_n_modes_precompute(self, rank: int) -> int:
if self.is_based_on_variance:
Expand Down Expand Up @@ -122,7 +124,17 @@ def fit_transform(self, X):
# Source: https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html

use_dask = True if isinstance(X, DaskArray) else False
use_complex = True if np.iscomplexobj(X) else False

match self.is_complex:
case bool():
use_complex = self.is_complex
case "auto":
use_complex = True if np.iscomplexobj(X) else False
case _:
raise ValueError(
f"Unrecognized value for is_complex '{self.is_complex}'. "
"Valid options are True, False, and 'auto'."
)

is_small_data = max(X.shape) < 500

Expand Down
3 changes: 3 additions & 0 deletions xeofs/linalg/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class SVD:
def __init__(
self,
n_modes: int | float | str,
is_complex: bool | str = "auto",
init_rank_reduction: float = 0.3,
flip_signs: bool = True,
solver: str = "auto",
Expand All @@ -20,6 +21,7 @@ def __init__(
feature_name: str = "feature",
):
self.n_modes = n_modes
self.is_complex = is_complex
self.init_rank_reduction = init_rank_reduction
self.flip_signs = flip_signs
self.solver = solver
Expand Down Expand Up @@ -54,6 +56,7 @@ def fit_transform(self, X: DataArray) -> tuple[DataArray, DataArray, DataArray]:
flip_signs=self.flip_signs,
solver=self.solver,
random_state=self.random_state,
is_complex=self.is_complex,
**self.solver_kwargs,
)
U, s, V = xr.apply_ufunc(
Expand Down
6 changes: 6 additions & 0 deletions xeofs/linalg/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ..utils.data_types import DataArray


def total_variance(X: DataArray, dim: str) -> DataArray:
"""Compute the total variance of the centered data."""
return (X * X.conj()).sum() / (X[dim].size - 1)
2 changes: 2 additions & 0 deletions xeofs/single/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from .eof_rotator import ComplexEOFRotator, EOFRotator, HilbertEOFRotator
from .gwpca import GWPCA
from .opa import OPA
from .pop import POP
from .sparse_pca import SparsePCA

__all__ = [
"EOF",
"ExtendedEOF",
"SparsePCA",
"POP",
"OPA",
"GWPCA",
"ComplexEOF",
Expand Down
1 change: 1 addition & 0 deletions xeofs/single/base_model_single_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit(

if self._params["compute"]:
self.data.compute()
self._post_compute()

return self

Expand Down
Loading
Loading