Skip to content

Commit

Permalink
better naming
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Oct 7, 2024
1 parent fd8b145 commit 17a5efc
Show file tree
Hide file tree
Showing 12 changed files with 30 additions and 27 deletions.
4 changes: 2 additions & 2 deletions tests/models/cross/test_cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from xeofs.cross import CCA

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def generate_random_data(shape, lazy=False, seed=142):
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_save_load(tmp_path, engine):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/cross/test_cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from xeofs.cross import CPCCA

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def generate_random_data(shape, lazy=False, seed=142):
Expand Down Expand Up @@ -282,7 +282,7 @@ def test_save_load(tmp_path, engine, alpha):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_save_load(tmp_path, engine, alpha):
def test_save_load_with_data(tmp_path, engine, alpha):
"""Test save/load methods in CPCCA class, ensuring that we can
roundtrip the model and get the same results for SCF."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/cross/test_hilbert_cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from xeofs.cross import HilbertCPCCA

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def generate_random_data(shape, lazy=False, seed=142):
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_singular_values(use_pca):
def test_save_load_with_data(tmp_path, engine, alpha):
"""Test save/load methods in CPCCA class, ensuring that we can
roundtrip the model and get the same results."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/cross/test_hilbert_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Import the classes from your modules
from xeofs.cross import HilbertMCA, HilbertMCARotator

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


@pytest.fixture
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_scores_phase(mca_model, mock_data_array, dim):
def test_save_load_with_data(tmp_path, engine, mca_model):
"""Test save/load methods in HilbertMCARotator class, ensuring that we can
roundtrip the model and get the same results."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original = HilbertMCARotator(n_modes=2)
original.fit(mca_model)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/cross/test_mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xeofs.cross import MCA

from ...utilities import data_is_dask, engine_to_module
from ...utilities import data_is_dask, skip_if_missing_engine


@pytest.fixture
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original = MCA()
original.fit(mock_data_array, mock_data_array, dim)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/cross/test_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Import the classes from your modules
from xeofs.cross import MCA, MCARotator

from ...utilities import data_is_dask, engine_to_module
from ...utilities import data_is_dask, skip_if_missing_engine


@pytest.fixture
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original_unrotated = MCA()
original_unrotated.fit(mock_data_array, mock_data_array, dim)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/cross/test_rda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from xeofs.cross import RDA

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def generate_random_data(shape, lazy=False, seed=142):
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_save_load(tmp_path, engine):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/single/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xeofs.single import EOF

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def test_init():
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
"""Test save/load methods in EOF class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original = EOF()
original.fit(mock_data_array, dim)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/single/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from xeofs.data_container import DataContainer
from xeofs.single import EOF, EOFRotator

from ...utilities import data_is_dask, engine_to_module
from ...utilities import data_is_dask, skip_if_missing_engine


@pytest.fixture
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
"""Test save/load methods in EOF class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original_unrotated = EOF()
original_unrotated.fit(mock_data_array, dim)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/single/test_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xeofs.single import POP

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def test_init():
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_save_load(mock_data_array, tmp_path, engine):
roundtrip the model and get the same results when transforming
data."""
# NOTE: netcdf4 does not support complex data types, so we use only zarr here
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

dim = "time"
original = POP()
Expand Down
4 changes: 2 additions & 2 deletions tests/models/single/test_sparse_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xeofs.single import SparsePCA

from ...utilities import engine_to_module
from ...utilities import skip_if_missing_engine


def test_init():
Expand Down Expand Up @@ -490,7 +490,7 @@ def test_save_load(dim, mock_data_array, tmp_path, engine):
"""Test save/load methods in SparsePCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
pytest.importorskip(engine_to_module(engine))
skip_if_missing_engine(engine)

original = SparsePCA()
original.fit(mock_data_array, dim)
Expand Down
11 changes: 7 additions & 4 deletions tests/utilities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import numpy as np
import pandas as pd
from xeofs.utils.data_types import (
Expand Down Expand Up @@ -150,9 +151,11 @@ def assert_expected_coords(data1, data2, policy="all") -> None:
)


def engine_to_module(engine: str) -> str:
def skip_if_missing_engine(engine: str):
"""
Required for import skipping because xarray uses `engine="netcdf4"`
but the package itself is called `netCDF4`."""
Skip save/load tests if missing the i/o backend.
"""
# xarray uses engine="netcdf4" but the package itself is called "netCDF4".
mapping = {"netcdf4": "netCDF4"}
return mapping.get(engine, engine)
module = mapping.get(engine, engine)
pytest.importorskip(module)

0 comments on commit 17a5efc

Please sign in to comment.