diff --git a/.gitignore b/.gitignore index 9fb7292..91b8f5c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Personal .vscode/ +# Test related to CCA +tests/models/test_cca_solution.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/_autosummary/xeofs.models.ComplexEOF.rst b/docs/_autosummary/xeofs.models.ComplexEOF.rst index dad48cf..5cdb7ee 100644 --- a/docs/_autosummary/xeofs.models.ComplexEOF.rst +++ b/docs/_autosummary/xeofs.models.ComplexEOF.rst @@ -24,6 +24,7 @@ ~ComplexEOF.explained_variance ~ComplexEOF.explained_variance_ratio ~ComplexEOF.fit + ~ComplexEOF.fit_transform ~ComplexEOF.get_params ~ComplexEOF.inverse_transform ~ComplexEOF.scores diff --git a/docs/_autosummary/xeofs.models.ComplexEOFRotator.rst b/docs/_autosummary/xeofs.models.ComplexEOFRotator.rst index 1f5267a..2fe12b8 100644 --- a/docs/_autosummary/xeofs.models.ComplexEOFRotator.rst +++ b/docs/_autosummary/xeofs.models.ComplexEOFRotator.rst @@ -24,6 +24,7 @@ ~ComplexEOFRotator.explained_variance ~ComplexEOFRotator.explained_variance_ratio ~ComplexEOFRotator.fit + ~ComplexEOFRotator.fit_transform ~ComplexEOFRotator.get_params ~ComplexEOFRotator.inverse_transform ~ComplexEOFRotator.scores diff --git a/docs/_autosummary/xeofs.models.ComplexMCA.rst b/docs/_autosummary/xeofs.models.ComplexMCA.rst index 92206b3..8dbc52d 100644 --- a/docs/_autosummary/xeofs.models.ComplexMCA.rst +++ b/docs/_autosummary/xeofs.models.ComplexMCA.rst @@ -33,6 +33,7 @@ ~ComplexMCA.singular_values ~ComplexMCA.squared_covariance ~ComplexMCA.squared_covariance_fraction + ~ComplexMCA.total_covariance ~ComplexMCA.transform diff --git a/docs/_autosummary/xeofs.models.ComplexMCARotator.rst b/docs/_autosummary/xeofs.models.ComplexMCARotator.rst index 518c069..61b4a59 100644 --- a/docs/_autosummary/xeofs.models.ComplexMCARotator.rst +++ b/docs/_autosummary/xeofs.models.ComplexMCARotator.rst @@ -33,6 +33,7 @@ ~ComplexMCARotator.singular_values ~ComplexMCARotator.squared_covariance ~ComplexMCARotator.squared_covariance_fraction + ~ComplexMCARotator.total_covariance ~ComplexMCARotator.transform diff --git a/docs/_autosummary/xeofs.models.EOF.rst b/docs/_autosummary/xeofs.models.EOF.rst index 07069dd..20b35dc 100644 --- a/docs/_autosummary/xeofs.models.EOF.rst +++ b/docs/_autosummary/xeofs.models.EOF.rst @@ -22,6 +22,7 @@ ~EOF.explained_variance ~EOF.explained_variance_ratio ~EOF.fit + ~EOF.fit_transform ~EOF.get_params ~EOF.inverse_transform ~EOF.scores diff --git a/docs/_autosummary/xeofs.models.EOFRotator.rst b/docs/_autosummary/xeofs.models.EOFRotator.rst index adcca8e..93c0df2 100644 --- a/docs/_autosummary/xeofs.models.EOFRotator.rst +++ b/docs/_autosummary/xeofs.models.EOFRotator.rst @@ -22,6 +22,7 @@ ~EOFRotator.explained_variance ~EOFRotator.explained_variance_ratio ~EOFRotator.fit + ~EOFRotator.fit_transform ~EOFRotator.get_params ~EOFRotator.inverse_transform ~EOFRotator.scores diff --git a/docs/_autosummary/xeofs.models.MCA.rst b/docs/_autosummary/xeofs.models.MCA.rst index 59a40b7..b8de506 100644 --- a/docs/_autosummary/xeofs.models.MCA.rst +++ b/docs/_autosummary/xeofs.models.MCA.rst @@ -29,6 +29,7 @@ ~MCA.singular_values ~MCA.squared_covariance ~MCA.squared_covariance_fraction + ~MCA.total_covariance ~MCA.transform diff --git a/docs/_autosummary/xeofs.models.MCARotator.rst b/docs/_autosummary/xeofs.models.MCARotator.rst index eabd0bc..4c604c0 100644 --- a/docs/_autosummary/xeofs.models.MCARotator.rst +++ b/docs/_autosummary/xeofs.models.MCARotator.rst @@ -29,6 +29,7 @@ ~MCARotator.singular_values 
~MCARotator.squared_covariance ~MCARotator.squared_covariance_fraction + ~MCARotator.total_covariance ~MCARotator.transform diff --git a/docs/_autosummary/xeofs.models.OPA.rst b/docs/_autosummary/xeofs.models.OPA.rst index 478174a..4d029bd 100644 --- a/docs/_autosummary/xeofs.models.OPA.rst +++ b/docs/_autosummary/xeofs.models.OPA.rst @@ -22,6 +22,7 @@ ~OPA.decorrelation_time ~OPA.filter_patterns ~OPA.fit + ~OPA.fit_transform ~OPA.get_params ~OPA.inverse_transform ~OPA.scores diff --git a/docs/_autosummary/xeofs.validation.EOFBootstrapper.rst b/docs/_autosummary/xeofs.validation.EOFBootstrapper.rst index 902cc3a..8a68c2d 100644 --- a/docs/_autosummary/xeofs.validation.EOFBootstrapper.rst +++ b/docs/_autosummary/xeofs.validation.EOFBootstrapper.rst @@ -22,6 +22,7 @@ ~EOFBootstrapper.explained_variance ~EOFBootstrapper.explained_variance_ratio ~EOFBootstrapper.fit + ~EOFBootstrapper.fit_transform ~EOFBootstrapper.get_params ~EOFBootstrapper.inverse_transform ~EOFBootstrapper.scores diff --git a/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_001.png b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_001.png new file mode 100644 index 0000000..a6f4eb7 Binary files /dev/null and b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_001.png differ diff --git a/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_002.png b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_002.png new file mode 100644 index 0000000..e4d1037 Binary files /dev/null and b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_002.png differ diff --git a/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_003.png b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_003.png new file mode 100644 index 0000000..4502176 Binary files /dev/null and b/docs/auto_examples/1eof/images/sphx_glr_plot_eeof_003.png differ diff --git a/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_001.png b/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_001.png new file mode 100644 index 0000000..14d3f63 Binary files /dev/null and b/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_001.png differ diff --git a/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_002.png b/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_002.png new file mode 100644 index 0000000..814f2f6 Binary files /dev/null and b/docs/auto_examples/1eof/images/sphx_glr_plot_gwpca_002.png differ diff --git a/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_eeof_thumb.png b/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_eeof_thumb.png new file mode 100644 index 0000000..8797ba3 Binary files /dev/null and b/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_eeof_thumb.png differ diff --git a/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_gwpca_thumb.png b/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_gwpca_thumb.png new file mode 100644 index 0000000..b7a9a0f Binary files /dev/null and b/docs/auto_examples/1eof/images/thumb/sphx_glr_plot_gwpca_thumb.png differ diff --git a/docs/auto_examples/1eof/index.rst b/docs/auto_examples/1eof/index.rst index 3e89c05..7ecc9cd 100644 --- a/docs/auto_examples/1eof/index.rst +++ b/docs/auto_examples/1eof/index.rst @@ -12,6 +12,23 @@
+.. raw:: html + +
+ +.. only:: html + + .. image:: /auto_examples/1eof/images/thumb/sphx_glr_plot_eeof_thumb.png + :alt: + + :ref:`sphx_glr_auto_examples_1eof_plot_eeof.py` + +.. raw:: html + +
Extended EOF analysis
+
+ + .. raw:: html
@@ -114,6 +131,23 @@
+.. raw:: html + +
+ +.. only:: html + + .. image:: /auto_examples/1eof/images/thumb/sphx_glr_plot_gwpca_thumb.png + :alt: + + :ref:`sphx_glr_auto_examples_1eof_plot_gwpca.py` + +.. raw:: html + +
Geographically weighted PCA
+
+ + .. raw:: html
@@ -122,10 +156,12 @@ .. toctree:: :hidden: + /auto_examples/1eof/plot_eeof /auto_examples/1eof/plot_eof-tmode /auto_examples/1eof/plot_eof-smode /auto_examples/1eof/plot_multivariate-eof /auto_examples/1eof/plot_mreof /auto_examples/1eof/plot_rotated_eof /auto_examples/1eof/plot_weighted-eof + /auto_examples/1eof/plot_gwpca diff --git a/docs/auto_examples/1eof/plot_eeof.ipynb b/docs/auto_examples/1eof/plot_eeof.ipynb new file mode 100644 index 0000000..a3a7b80 --- /dev/null +++ b/docs/auto_examples/1eof/plot_eeof.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Extented EOF analysis\n\nThis example demonstrates Extended EOF (EEOF) analysis on ``xarray`` tutorial \ndata. EEOF analysis, also termed as Multivariate/Multichannel Singular \nSpectrum Analysis, advances traditional EOF analysis to capture propagating \nsignals or oscillations in multivariate datasets. At its core, this \ninvolves the formulation of a lagged covariance matrix that encapsulates \nboth spatial and temporal correlations. Subsequently, this matrix is \ndecomposed to yield its eigenvectors (components) and eigenvalues (explained variance).\n\nLet's begin by setting up the required packages and fetching the data:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import xarray as xr\nimport xeofs as xe\nimport matplotlib.pyplot as plt\n\nxr.set_options(display_expand_data=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the tutorial data.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t2m = xr.tutorial.load_dataset(\"air_temperature\").air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prior to conducting the EEOF analysis, it's essential to determine the\nstructure of the lagged covariance matrix. This entails defining the time\ndelay ``tau`` and the ``embedding`` dimension. The former signifies the\ninterval between the original and lagged time series, while the latter\ndictates the number of time-lagged copies in the delay-coordinate space,\nrepresenting the system's dynamics.\nFor illustration, using ``tau=4`` and ``embedding=40``, we generate 40\ndelayed versions of the time series, each offset by 4 time steps, resulting\nin a maximum shift of ``tau x embedding = 160``. Given our dataset's\n6-hour intervals, tau = 4 translates to a 24-hour shift.\nIt's obvious that this way of constructing the lagged covariance matrix\nand subsequently decomposing it can be computationally expensive. For example,\ngiven our dataset's dimensions,\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t2m.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "the extended dataset would have 40 x 25 x 53 = 53000 features\nwhich is much larger than the original dataset's 1325 features.\nTo mitigate this, we can first preprocess the data using PCA / EOF analysis\nand then perform EEOF analysis on the resulting PCA / EOF scores. 
Here,\nwe'll use ``n_pca_modes=50`` to retain the first 50 PCA modes, so we end\nup with 40 x 50 = 200 (latent) features.\nWith these parameters set, we proceed to instantiate the ``ExtendedEOF``\nmodel and fit our data.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "model = xe.models.ExtendedEOF(\n n_modes=10, tau=4, embedding=40, n_pca_modes=50, use_coslat=True\n)\nmodel.fit(t2m, dim=\"time\")\nscores = model.scores()\ncomponents = model.components()\ncomponents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A notable distinction from standard EOF analysis is the incorporation of an\nextra ``embedding`` dimension in the components. Nonetheless, the\noverarching methodology mirrors traditional EOF practices. The results,\nfor instance, can be assessed by examining the explained variance ratio.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "model.explained_variance_ratio().plot()\nplt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, we can look into the scores; let's spotlight mode 4.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "scores.sel(mode=4).plot()\nplt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In wrapping up, we visualize the corresponding EEOF component of mode 4.\nFor visualization purposes, we'll focus on the component at a specific\nlatitude, in this instance, 60 degrees north.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "components.sel(mode=4, lat=60).plot()\nplt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/docs/auto_examples/1eof/plot_eeof.py b/docs/auto_examples/1eof/plot_eeof.py new file mode 100644 index 0000000..2a7ab1d --- /dev/null +++ b/docs/auto_examples/1eof/plot_eeof.py @@ -0,0 +1,83 @@ +""" +Extented EOF analysis +===================== + +This example demonstrates Extended EOF (EEOF) analysis on ``xarray`` tutorial +data. EEOF analysis, also termed as Multivariate/Multichannel Singular +Spectrum Analysis, advances traditional EOF analysis to capture propagating +signals or oscillations in multivariate datasets. At its core, this +involves the formulation of a lagged covariance matrix that encapsulates +both spatial and temporal correlations. Subsequently, this matrix is +decomposed to yield its eigenvectors (components) and eigenvalues (explained variance). + +Let's begin by setting up the required packages and fetching the data: +""" + +import xarray as xr +import xeofs as xe +import matplotlib.pyplot as plt + +xr.set_options(display_expand_data=False) + +# %% +# Load the tutorial data. 
+t2m = xr.tutorial.load_dataset("air_temperature").air + + +# %% +# Prior to conducting the EEOF analysis, it's essential to determine the +# structure of the lagged covariance matrix. This entails defining the time +# delay ``tau`` and the ``embedding`` dimension. The former signifies the +# interval between the original and lagged time series, while the latter +# dictates the number of time-lagged copies in the delay-coordinate space, +# representing the system's dynamics. +# For illustration, using ``tau=4`` and ``embedding=40``, we generate 40 +# delayed versions of the time series, each offset by 4 time steps, resulting +# in a maximum shift of ``tau x embedding = 160``. Given our dataset's +# 6-hour intervals, tau = 4 translates to a 24-hour shift. +# It's obvious that this way of constructing the lagged covariance matrix +# and subsequently decomposing it can be computationally expensive. For example, +# given our dataset's dimensions, + +t2m.shape + +# %% +# the extended dataset would have 40 x 25 x 53 = 53000 features +# which is much larger than the original dataset's 1325 features. +# To mitigate this, we can first preprocess the data using PCA / EOF analysis +# and then perform EEOF analysis on the resulting PCA / EOF scores. Here, +# we'll use ``n_pca_modes=50`` to retain the first 50 PCA modes, so we end +# up with 40 x 50 = 200 (latent) features. +# With these parameters set, we proceed to instantiate the ``ExtendedEOF`` +# model and fit our data. + +model = xe.models.ExtendedEOF( + n_modes=10, tau=4, embedding=40, n_pca_modes=50, use_coslat=True +) +model.fit(t2m, dim="time") +scores = model.scores() +components = model.components() +components + +# %% +# A notable distinction from standard EOF analysis is the incorporation of an +# extra ``embedding`` dimension in the components. Nonetheless, the +# overarching methodology mirrors traditional EOF practices. The results, +# for instance, can be assessed by examining the explained variance ratio. + +model.explained_variance_ratio().plot() +plt.show() + +# %% +# Additionally, we can look into the scores; let's spotlight mode 4. + +scores.sel(mode=4).plot() +plt.show() + +# %% +# In wrapping up, we visualize the corresponding EEOF component of mode 4. +# For visualization purposes, we'll focus on the component at a specific +# latitude, in this instance, 60 degrees north. + +components.sel(mode=4, lat=60).plot() +plt.show() diff --git a/docs/auto_examples/1eof/plot_eeof.py.md5 b/docs/auto_examples/1eof/plot_eeof.py.md5 new file mode 100644 index 0000000..48d6006 --- /dev/null +++ b/docs/auto_examples/1eof/plot_eeof.py.md5 @@ -0,0 +1 @@ +7f3b66c7aec555c78dde9031213be3ad \ No newline at end of file diff --git a/docs/auto_examples/1eof/plot_eeof.rst b/docs/auto_examples/1eof/plot_eeof.rst new file mode 100644 index 0000000..494c85a --- /dev/null +++ b/docs/auto_examples/1eof/plot_eeof.rst @@ -0,0 +1,693 @@ + +.. DO NOT EDIT. +.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. +.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: +.. "auto_examples/1eof/plot_eeof.py" +.. LINE NUMBERS ARE GIVEN BELOW. + +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + :ref:`Go to the end ` + to download the full example code + +.. rst-class:: sphx-glr-example-title + +.. _sphx_glr_auto_examples_1eof_plot_eeof.py: + + +Extented EOF analysis +===================== + +This example demonstrates Extended EOF (EEOF) analysis on ``xarray`` tutorial +data. 
EEOF analysis, also termed as Multivariate/Multichannel Singular +Spectrum Analysis, advances traditional EOF analysis to capture propagating +signals or oscillations in multivariate datasets. At its core, this +involves the formulation of a lagged covariance matrix that encapsulates +both spatial and temporal correlations. Subsequently, this matrix is +decomposed to yield its eigenvectors (components) and eigenvalues (explained variance). + +Let's begin by setting up the required packages and fetching the data: + +.. GENERATED FROM PYTHON SOURCE LINES 15-22 + +.. code-block:: default + + + import xarray as xr + import xeofs as xe + import matplotlib.pyplot as plt + + xr.set_options(display_expand_data=False) + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 23-24 + +Load the tutorial data. + +.. GENERATED FROM PYTHON SOURCE LINES 24-27 + +.. code-block:: default + + t2m = xr.tutorial.load_dataset("air_temperature").air + + + + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 28-41 + +Prior to conducting the EEOF analysis, it's essential to determine the +structure of the lagged covariance matrix. This entails defining the time +delay ``tau`` and the ``embedding`` dimension. The former signifies the +interval between the original and lagged time series, while the latter +dictates the number of time-lagged copies in the delay-coordinate space, +representing the system's dynamics. +For illustration, using ``tau=4`` and ``embedding=40``, we generate 40 +delayed versions of the time series, each offset by 4 time steps, resulting +in a maximum shift of ``tau x embedding = 160``. Given our dataset's +6-hour intervals, tau = 4 translates to a 24-hour shift. +It's obvious that this way of constructing the lagged covariance matrix +and subsequently decomposing it can be computationally expensive. For example, +given our dataset's dimensions, + +.. GENERATED FROM PYTHON SOURCE LINES 41-44 + +.. code-block:: default + + + t2m.shape + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + + (2920, 25, 53) + + + +.. GENERATED FROM PYTHON SOURCE LINES 45-53 + +the extended dataset would have 40 x 25 x 53 = 53000 features +which is much larger than the original dataset's 1325 features. +To mitigate this, we can first preprocess the data using PCA / EOF analysis +and then perform EEOF analysis on the resulting PCA / EOF scores. Here, +we'll use ``n_pca_modes=50`` to retain the first 50 PCA modes, so we end +up with 40 x 50 = 200 (latent) features. +With these parameters set, we proceed to instantiate the ``ExtendedEOF`` +model and fit our data. + +.. GENERATED FROM PYTHON SOURCE LINES 53-62 + +.. code-block:: default + + + model = xe.models.ExtendedEOF( + n_modes=10, tau=4, embedding=40, n_pca_modes=50, use_coslat=True + ) + model.fit(t2m, dim="time") + scores = model.scores() + components = model.components() + components + + + + + + +.. raw:: html + +
+
+    <xarray.DataArray 'components' (mode: 10, embedding: 40, lat: 25, lon: 53)>
+    0.0003857 0.0003649 0.0003575 0.0003567 ... -0.001347 -0.0009396 -0.0005447
+    Coordinates:
+      * lat        (lat) float32 15.0 17.5 20.0 22.5 25.0 ... 67.5 70.0 72.5 75.0
+      * lon        (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
+      * embedding  (embedding) int64 0 4 8 12 16 20 24 ... 136 140 144 148 152 156
+      * mode       (mode) int64 1 2 3 4 5 6 7 8 9 10
+    Attributes:
+        model:        Extended EOF Analysis
+        n_modes:      10
+        center:       True
+        standardize:  False
+        use_coslat:   True
+        solver:       auto
+        software:     xeofs
+        version:      1.0.3
+        date:         2023-10-23 11:30:31
+
+
+
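To make the role of ``tau`` and ``embedding`` more concrete, the delay embedding that ``ExtendedEOF`` builds can be sketched by hand with plain ``xarray`` operations. The helper below is only an illustration of the idea, not the library's implementation; the function name ``delay_embed`` and the commented-out call are made up for this example.

.. code-block:: python

    import xarray as xr


    def delay_embed(da: xr.DataArray, tau: int = 4, embedding: int = 40) -> xr.DataArray:
        """Stack `embedding` lagged copies of `da`, each shifted by a multiple of `tau`."""
        lags = [k * tau for k in range(embedding)]
        lagged = [da.shift(time=lag) for lag in lags]
        emb = xr.concat(lagged, dim="embedding").assign_coords(embedding=lags)
        # Shifting pads the start of the record with NaNs; drop those time steps.
        return emb.dropna("time")


    # t2m has dims (time, lat, lon); the embedded array gains an `embedding` dimension,
    # so the feature space grows from 25 x 53 to 40 x 25 x 53 before any PCA compression.
    # embedded = delay_embed(t2m, tau=4, embedding=40)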
+ +.. GENERATED FROM PYTHON SOURCE LINES 63-67 + +A notable distinction from standard EOF analysis is the incorporation of an +extra ``embedding`` dimension in the components. Nonetheless, the +overarching methodology mirrors traditional EOF practices. The results, +for instance, can be assessed by examining the explained variance ratio. + +.. GENERATED FROM PYTHON SOURCE LINES 67-71 + +.. code-block:: default + + + model.explained_variance_ratio().plot() + plt.show() + + + + +.. image-sg:: /auto_examples/1eof/images/sphx_glr_plot_eeof_001.png + :alt: plot eeof + :srcset: /auto_examples/1eof/images/sphx_glr_plot_eeof_001.png + :class: sphx-glr-single-img + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 72-73 + +Additionally, we can look into the scores; let's spotlight mode 4. + +.. GENERATED FROM PYTHON SOURCE LINES 73-77 + +.. code-block:: default + + + scores.sel(mode=4).plot() + plt.show() + + + + +.. image-sg:: /auto_examples/1eof/images/sphx_glr_plot_eeof_002.png + :alt: mode = 4 + :srcset: /auto_examples/1eof/images/sphx_glr_plot_eeof_002.png + :class: sphx-glr-single-img + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 78-81 + +In wrapping up, we visualize the corresponding EEOF component of mode 4. +For visualization purposes, we'll focus on the component at a specific +latitude, in this instance, 60 degrees north. + +.. GENERATED FROM PYTHON SOURCE LINES 81-84 + +.. code-block:: default + + + components.sel(mode=4, lat=60).plot() + plt.show() + + + +.. image-sg:: /auto_examples/1eof/images/sphx_glr_plot_eeof_003.png + :alt: lat = 60.0 [degrees_north], mode = 4 + :srcset: /auto_examples/1eof/images/sphx_glr_plot_eeof_003.png + :class: sphx-glr-single-img + + + + + + +.. rst-class:: sphx-glr-timing + + **Total running time of the script:** (0 minutes 3.585 seconds) + + +.. _sphx_glr_download_auto_examples_1eof_plot_eeof.py: + +.. only:: html + + .. container:: sphx-glr-footer sphx-glr-footer-example + + + + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: plot_eeof.py ` + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download Jupyter notebook: plot_eeof.ipynb ` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/docs/auto_examples/1eof/plot_eeof_codeobj.pickle b/docs/auto_examples/1eof/plot_eeof_codeobj.pickle new file mode 100644 index 0000000..9f9912c Binary files /dev/null and b/docs/auto_examples/1eof/plot_eeof_codeobj.pickle differ diff --git a/docs/auto_examples/1eof/plot_gwpca.ipynb b/docs/auto_examples/1eof/plot_gwpca.ipynb new file mode 100644 index 0000000..de4c22a --- /dev/null +++ b/docs/auto_examples/1eof/plot_gwpca.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Geographically weighted PCA\nGeographically Weighted Principal Component Analysis (GWPCA) is a spatial analysis method that identifies and visualizes local spatial patterns and relationships in multivariate datasets across various geographic areas. It operates by applying PCA within a moving window over a geographical region, which enables the extraction of local principal components that can differ across locations.\n\nTIn this demonstration, we'll apply GWPCA to a dataset detailing the chemical compositions of soils from countries around the Baltic Sea [1]_. This example is inspired by a tutorial originally crafted and published by Chris Brunsdon [2]_. 
\nThe dataset comprises 10 variables (chemical elements) and spans 768 samples. \nHere, each sample refers to a pair of latitude and longitude coordinates, representing specific sampling stations.\n\n.. [1] Reimann, C. et al. Baltic soil survey: total concentrations of major and selected trace elements in arable soils from 10 countries around the Baltic Sea. Science of The Total Environment 257, 155\u2013170 (2000).\n.. [2] https://rpubs.com/chrisbrunsdon/99675\n\n\n\n

<div class=\"alert alert-info\"><h4>Note</h4><p>The dataset we're using is found in the R package \n [mvoutlier](https://cran.r-project.org/web/packages/mvoutlier/mvoutlier.pdf). \n To access it, we'll employ the Python package \n [rpy2](https://rpy2.github.io/doc/latest/html/index.html) which facilitates \n interaction with R packages from within Python.</p></div>\n\n
<div class=\"alert alert-info\"><h4>Note</h4><p>Presently, there's no support for ``xarray.Dataset`` lacking an explicit feature dimension. \n As a workaround, ``xarray.DataArray.to_array`` can be used to convert the ``Dataset`` to a ``DataArray``.</p></div>\n\n
<div class=\"alert alert-danger\"><h4>Warning</h4><p>Bear in mind that GWPCA requires significant computational power.\n The ``xeofs`` implementation is optimized for CPU efficiency and is best suited \n for smaller to medium data sets. For more extensive datasets where parallel processing becomes essential,\n it's advisable to turn to the R package [GWmodel](https://cran.r-project.org/web/packages/GWmodel/GWmodel.pdf).\n This package harnesses CUDA to enable GPU-accelerated GWPCA for optimized performance.</p></div>

\n\n\nLet's import the necessary packages.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# For the analysis\nimport numpy as np\nimport xarray as xr\nimport xeofs as xe\n\n# For visualization\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n# For accessing R packages\nimport rpy2.robjects as ro\nfrom rpy2.robjects.packages import importr\nfrom rpy2.robjects import pandas2ri" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll install the R package [mvoutlier](https://cran.r-project.org/web/packages/mvoutlier/mvoutlier.pdf)\nusing the [rpy2](https://rpy2.github.io/doc/latest/html/index.html) package.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "xr.set_options(display_expand_data=False)\nutils = importr(\"utils\")\nutils.chooseCRANmirror(ind=1)\nutils.install_packages(\"mvoutlier\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the dataset and convert it into a ``pandas.DataFrame``.\nAlongside, we'll also load the background data that outlines the borders of countries\nin the Baltic Sea region. This will help us visually represent the GWPCA results.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ro.r(\n \"\"\"\n require(\"mvoutlier\")\n data(bsstop)\n Data <- bsstop[,1:14]\n background <- bss.background\n \"\"\"\n)\nwith (ro.default_converter + pandas2ri.converter).context():\n data_df = ro.conversion.get_conversion().rpy2py(ro.r[\"Data\"])\n background_df = ro.conversion.get_conversion().rpy2py(ro.r[\"background\"])\ndata_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since ``xeofs`` uses ``xarray``, we convert the data into an ``xarray.DataArray``.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "data_df = data_df.rename(columns={\"ID\": \"station\"}).set_index(\"station\")\ndata = data_df.to_xarray()\ndata = data.rename({\"XCOO\": \"x\", \"YCOO\": \"y\"})\ndata = data.set_index(station=(\"x\", \"y\"))\ndata = data.drop_vars(\"CNo\")\nda = data.to_array(dim=\"element\")\nda" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's dive into the GWPCA. First, initialize a ``GWPCA`` instance and fit it to the data.\nThe ``station`` dimension serves as our sample dimension, along which the local PCAs will be applied.\nSince these PCAs need to gauge distances to adjacent stations, we must specify\na distance metric. Our station data includes coordinates in meters, so we'll\nchoose the ``euclidean`` metric. If you have coordinates in degrees (like\nlatitude and longitude), choose the ``haversine`` metric instead.\nWe're also using a ``bisquare`` kernel with a bandwidth of 1000 km. Note that the\nbandwidth unit always follows input data (which is in meters here),\nexcept when using the ``haversine`` metric, which always gives distances in\nkilometers. 
Lastly, we'll standardize the input to ensure consistent scales\nfor the chemical elements.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "gwpca = xe.models.GWPCA(\n n_modes=5,\n standardize=True,\n metric=\"euclidean\",\n kernel=\"bisquare\",\n bandwidth=1000000.0,\n)\ngwpca.fit(da, \"station\")\ngwpca.components()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The ``components`` method returns the local principal components for each station. Note that the\ndimensionality of the returned array is ``[station, element, mode]``, so in practice we don't really have\nreduced the dimensionality of the data set. However, we can\nextract the largest locally weighted components for each station which tells us which chemical elements\ndominate the local PCAs.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "llwc = gwpca.largest_locally_weighted_components()\nllwc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize the spatial patterns of the chemical elements.\nAs the stations are positioned on a irregular grid, we'll transform the\n``llwc`` ``DataArray`` into a ``pandas.DataFrame``. After that, we can easily visualize\nit using the ``scatter`` method.\nFor demonstation, we'll concentrate on the first mode:\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "llwc1_df = llwc.sel(mode=1).to_dataframe()\n\nelements = da.element.values\nn_elements = len(elements)\ncolors = np.arange(n_elements)\ncol_dict = {el: col for el, col in zip(elements, colors)}\n\nllwc1_df[\"colors\"] = llwc1_df[\"largest_locally_weighted_components\"].map(col_dict)\ncmap = sns.color_palette(\"tab10\", n_colors=n_elements, as_cmap=True)\n\n\nfig = plt.figure(figsize=(10, 10))\nax = fig.add_subplot(111)\nbackground_df.plot.scatter(ax=ax, x=\"V1\", y=\"V2\", color=\".3\", marker=\".\", s=1)\ns = ax.scatter(\n x=llwc1_df[\"x\"],\n y=llwc1_df[\"y\"],\n c=llwc1_df[\"colors\"],\n ec=\"w\",\n s=40,\n cmap=cmap,\n vmin=-0.5,\n vmax=n_elements - 0.5,\n)\ncbar = fig.colorbar(mappable=s, ax=ax, label=\"Largest locally weighted component\")\ncbar.set_ticks(colors)\ncbar.set_ticklabels(elements)\nax.set_title(\"Largest locally weighted element\", loc=\"left\", weight=800)\nplt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the final step, let's examine the explained variance. Like standard PCA,\nthis gives us insight into the variance explained by each mode. But with a\nlocal PCA for every station, the explained variance varies spatially. 
Notably,\nthe first mode's explained variance differs across countries, ranging from\nroughly 40% to 70%.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "exp_var_ratio = gwpca.explained_variance_ratio()\nevr1_df = exp_var_ratio.sel(mode=1).to_dataframe()\n\nfig = plt.figure(figsize=(10, 10))\nax = fig.add_subplot(111)\nbackground_df.plot.scatter(ax=ax, x=\"V1\", y=\"V2\", color=\".3\", marker=\".\", s=1)\nevr1_df.plot.scatter(\n ax=ax, x=\"x\", y=\"y\", c=\"explained_variance_ratio\", vmin=0.4, vmax=0.7\n)\nax.set_title(\"Fraction of locally explained variance\", loc=\"left\", weight=800)\nplt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/docs/auto_examples/1eof/plot_gwpca.py b/docs/auto_examples/1eof/plot_gwpca.py new file mode 100644 index 0000000..a084283 --- /dev/null +++ b/docs/auto_examples/1eof/plot_gwpca.py @@ -0,0 +1,174 @@ +""" +Geographically weighted PCA +=========================== +Geographically Weighted Principal Component Analysis (GWPCA) is a spatial analysis method that identifies and visualizes local spatial patterns and relationships in multivariate datasets across various geographic areas. It operates by applying PCA within a moving window over a geographical region, which enables the extraction of local principal components that can differ across locations. + +TIn this demonstration, we'll apply GWPCA to a dataset detailing the chemical compositions of soils from countries around the Baltic Sea [1]_. This example is inspired by a tutorial originally crafted and published by Chris Brunsdon [2]_. +The dataset comprises 10 variables (chemical elements) and spans 768 samples. +Here, each sample refers to a pair of latitude and longitude coordinates, representing specific sampling stations. + +.. [1] Reimann, C. et al. Baltic soil survey: total concentrations of major and selected trace elements in arable soils from 10 countries around the Baltic Sea. Science of The Total Environment 257, 155–170 (2000). +.. [2] https://rpubs.com/chrisbrunsdon/99675 + + + +.. note:: The dataset we're using is found in the R package + `mvoutlier `_. + To access it, we'll employ the Python package + `rpy2 `_ which facilitates + interaction with R packages from within Python. + +.. note:: Presently, there's no support for ``xarray.Dataset`` lacking an explicit feature dimension. + As a workaround, ``xarray.DataArray.to_array`` can be used to convert the ``Dataset`` to an ``DataArray``. + +.. warning:: Bear in mind that GWPCA requires significant computational power. + The ``xeofs`` implementation is optimized for CPU efficiency and is best suited + for smaller to medium data sets. For more extensive datasets where parallel processing becomes essential, + it's advisable to turn to the R package `GWmodel `_. + This package harnesses CUDA to enable GPU-accelerated GWPCA for optimized performance. + + +Let's import the necessary packages. 
+""" +# For the analysis +import numpy as np +import xarray as xr +import xeofs as xe + +# For visualization +import matplotlib.pyplot as plt +import seaborn as sns + +# For accessing R packages +import rpy2.robjects as ro +from rpy2.robjects.packages import importr +from rpy2.robjects import pandas2ri + +# %% +# Next, we'll install the R package `mvoutlier `_ +# using the `rpy2 `_ package. + +xr.set_options(display_expand_data=False) +utils = importr("utils") +utils.chooseCRANmirror(ind=1) +utils.install_packages("mvoutlier") + +# %% +# Let's load the dataset and convert it into a ``pandas.DataFrame``. +# Alongside, we'll also load the background data that outlines the borders of countries +# in the Baltic Sea region. This will help us visually represent the GWPCA results. + +ro.r( + """ + require("mvoutlier") + data(bsstop) + Data <- bsstop[,1:14] + background <- bss.background + """ +) +with (ro.default_converter + pandas2ri.converter).context(): + data_df = ro.conversion.get_conversion().rpy2py(ro.r["Data"]) + background_df = ro.conversion.get_conversion().rpy2py(ro.r["background"]) +data_df.head() + +# %% +# Since ``xeofs`` uses ``xarray``, we convert the data into an ``xarray.DataArray``. + +data_df = data_df.rename(columns={"ID": "station"}).set_index("station") +data = data_df.to_xarray() +data = data.rename({"XCOO": "x", "YCOO": "y"}) +data = data.set_index(station=("x", "y")) +data = data.drop_vars("CNo") +da = data.to_array(dim="element") +da + +# %% +# Let's dive into the GWPCA. First, initialize a ``GWPCA`` instance and fit it to the data. +# The ``station`` dimension serves as our sample dimension, along which the local PCAs will be applied. +# Since these PCAs need to gauge distances to adjacent stations, we must specify +# a distance metric. Our station data includes coordinates in meters, so we'll +# choose the ``euclidean`` metric. If you have coordinates in degrees (like +# latitude and longitude), choose the ``haversine`` metric instead. +# We're also using a ``bisquare`` kernel with a bandwidth of 1000 km. Note that the +# bandwidth unit always follows input data (which is in meters here), +# except when using the ``haversine`` metric, which always gives distances in +# kilometers. Lastly, we'll standardize the input to ensure consistent scales +# for the chemical elements. + +gwpca = xe.models.GWPCA( + n_modes=5, + standardize=True, + metric="euclidean", + kernel="bisquare", + bandwidth=1000000.0, +) +gwpca.fit(da, "station") +gwpca.components() + + +# %% +# The ``components`` method returns the local principal components for each station. Note that the +# dimensionality of the returned array is ``[station, element, mode]``, so in practice we don't really have +# reduced the dimensionality of the data set. However, we can +# extract the largest locally weighted components for each station which tells us which chemical elements +# dominate the local PCAs. + +llwc = gwpca.largest_locally_weighted_components() +llwc + +# %% +# Let's visualize the spatial patterns of the chemical elements. +# As the stations are positioned on a irregular grid, we'll transform the +# ``llwc`` ``DataArray`` into a ``pandas.DataFrame``. After that, we can easily visualize +# it using the ``scatter`` method. 
+# For demonstation, we'll concentrate on the first mode: + +llwc1_df = llwc.sel(mode=1).to_dataframe() + +elements = da.element.values +n_elements = len(elements) +colors = np.arange(n_elements) +col_dict = {el: col for el, col in zip(elements, colors)} + +llwc1_df["colors"] = llwc1_df["largest_locally_weighted_components"].map(col_dict) +cmap = sns.color_palette("tab10", n_colors=n_elements, as_cmap=True) + + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111) +background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1) +s = ax.scatter( + x=llwc1_df["x"], + y=llwc1_df["y"], + c=llwc1_df["colors"], + ec="w", + s=40, + cmap=cmap, + vmin=-0.5, + vmax=n_elements - 0.5, +) +cbar = fig.colorbar(mappable=s, ax=ax, label="Largest locally weighted component") +cbar.set_ticks(colors) +cbar.set_ticklabels(elements) +ax.set_title("Largest locally weighted element", loc="left", weight=800) +plt.show() + +# %% +# In the final step, let's examine the explained variance. Like standard PCA, +# this gives us insight into the variance explained by each mode. But with a +# local PCA for every station, the explained variance varies spatially. Notably, +# the first mode's explained variance differs across countries, ranging from +# roughly 40% to 70%. + + +exp_var_ratio = gwpca.explained_variance_ratio() +evr1_df = exp_var_ratio.sel(mode=1).to_dataframe() + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111) +background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1) +evr1_df.plot.scatter( + ax=ax, x="x", y="y", c="explained_variance_ratio", vmin=0.4, vmax=0.7 +) +ax.set_title("Fraction of locally explained variance", loc="left", weight=800) +plt.show() diff --git a/docs/auto_examples/1eof/plot_gwpca.py.md5 b/docs/auto_examples/1eof/plot_gwpca.py.md5 new file mode 100644 index 0000000..6e37a98 --- /dev/null +++ b/docs/auto_examples/1eof/plot_gwpca.py.md5 @@ -0,0 +1 @@ +958e12c0ca3bfc03fe27e2e22362e165 \ No newline at end of file diff --git a/docs/auto_examples/1eof/plot_gwpca.rst b/docs/auto_examples/1eof/plot_gwpca.rst new file mode 100644 index 0000000..3f6b009 --- /dev/null +++ b/docs/auto_examples/1eof/plot_gwpca.rst @@ -0,0 +1,1812 @@ + +.. DO NOT EDIT. +.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. +.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: +.. "auto_examples/1eof/plot_gwpca.py" +.. LINE NUMBERS ARE GIVEN BELOW. + +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + :ref:`Go to the end ` + to download the full example code + +.. rst-class:: sphx-glr-example-title + +.. _sphx_glr_auto_examples_1eof_plot_gwpca.py: + + +Geographically weighted PCA +=========================== +Geographically Weighted Principal Component Analysis (GWPCA) is a spatial analysis method that identifies and visualizes local spatial patterns and relationships in multivariate datasets across various geographic areas. It operates by applying PCA within a moving window over a geographical region, which enables the extraction of local principal components that can differ across locations. + +TIn this demonstration, we'll apply GWPCA to a dataset detailing the chemical compositions of soils from countries around the Baltic Sea [1]_. This example is inspired by a tutorial originally crafted and published by Chris Brunsdon [2]_. +The dataset comprises 10 variables (chemical elements) and spans 768 samples. +Here, each sample refers to a pair of latitude and longitude coordinates, representing specific sampling stations. + +.. 
[1] Reimann, C. et al. Baltic soil survey: total concentrations of major and selected trace elements in arable soils from 10 countries around the Baltic Sea. Science of The Total Environment 257, 155–170 (2000). +.. [2] https://rpubs.com/chrisbrunsdon/99675 + + + +.. note:: The dataset we're using is found in the R package + `mvoutlier `_. + To access it, we'll employ the Python package + `rpy2 `_ which facilitates + interaction with R packages from within Python. + +.. note:: Presently, there's no support for ``xarray.Dataset`` lacking an explicit feature dimension. + As a workaround, ``xarray.DataArray.to_array`` can be used to convert the ``Dataset`` to an ``DataArray``. + +.. warning:: Bear in mind that GWPCA requires significant computational power. + The ``xeofs`` implementation is optimized for CPU efficiency and is best suited + for smaller to medium data sets. For more extensive datasets where parallel processing becomes essential, + it's advisable to turn to the R package `GWmodel `_. + This package harnesses CUDA to enable GPU-accelerated GWPCA for optimized performance. + + +Let's import the necessary packages. + +.. GENERATED FROM PYTHON SOURCE LINES 33-47 + +.. code-block:: default + + # For the analysis + import numpy as np + import xarray as xr + import xeofs as xe + + # For visualization + import matplotlib.pyplot as plt + import seaborn as sns + + # For accessing R packages + import rpy2.robjects as ro + from rpy2.robjects.packages import importr + from rpy2.robjects import pandas2ri + + + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 48-50 + +Next, we'll install the R package `mvoutlier `_ +using the `rpy2 `_ package. + +.. GENERATED FROM PYTHON SOURCE LINES 50-56 + +.. code-block:: default + + + xr.set_options(display_expand_data=False) + utils = importr("utils") + utils.chooseCRANmirror(ind=1) + utils.install_packages("mvoutlier") + + + + + +.. rst-class:: sphx-glr-script-out + + .. 
code-block:: none + + R[write to console]: trying URL 'https://cloud.r-project.org/src/contrib/mvoutlier_2.1.1.tar.gz' + + R[write to console]: Content type 'application/x-gzip' + R[write to console]: length 476636 bytes (465 KB) + + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: = + R[write to console]: + + R[write to console]: downloaded 465 KB + + + R[write to console]: + + R[write to console]: + R[write to console]: The downloaded source packages are in + ‘/tmp/RtmpZTx0wl/downloaded_packages’ + R[write to console]: + R[write to console]: + + R[write to console]: Updating HTML index of packages in '.Library' + + R[write to console]: Making 'packages.html' ... + R[write to console]: done + + + [0] + + + +.. GENERATED FROM PYTHON SOURCE LINES 57-60 + +Let's load the dataset and convert it into a ``pandas.DataFrame``. +Alongside, we'll also load the background data that outlines the borders of countries +in the Baltic Sea region. This will help us visually represent the GWPCA results. + +.. GENERATED FROM PYTHON SOURCE LINES 60-74 + +.. code-block:: default + + + ro.r( + """ + require("mvoutlier") + data(bsstop) + Data <- bsstop[,1:14] + background <- bss.background + """ + ) + with (ro.default_converter + pandas2ri.converter).context(): + data_df = ro.conversion.get_conversion().rpy2py(ro.r["Data"]) + background_df = ro.conversion.get_conversion().rpy2py(ro.r["background"]) + data_df.head() + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + R[write to console]: Loading required package: mvoutlier + + R[write to console]: Loading required package: sgeostat + + + +.. raw:: html + +
+
+          ID    CNo       XCOO       YCOO  SiO2_T  TiO2_T  Al2O3_T  Fe2O3_T  MnO_T  MgO_T  CaO_T  Na2O_T  K2O_T  P2O5_T
+    1  5001.0   60.0  -619656.5  6805304.1   43.61   1.290    13.07    12.25  0.167   3.22   2.48    1.14   2.01   0.481
+    2  5002.0  120.0   214714.1  7745546.6   58.73   0.913    14.78     6.48  0.105   2.47   3.08    2.19   1.78   0.298
+    3  5003.0   33.0  -368415.5  7065039.2   58.14   0.902    11.89     5.70  0.126   2.44   3.17    2.13   1.16   0.408
+    4  5004.0   39.0   226609.0  6922431.0   43.98   0.524    10.00     4.08  0.052   1.00   1.37    1.60   1.82   0.395
+    5  5005.0  103.0   544050.0  7808760.0   60.90   0.702    13.20     6.37  0.079   2.59   3.13    2.97   1.35   0.139
+
+
+
+
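As an aside, the note at the top of this example points out that an ``xarray.Dataset`` without an explicit feature dimension must first be converted with ``to_array``; that is exactly what the next step does for this dataset. The minimal sketch below, with made-up variable names, shows the shape change in isolation.

.. code-block:: python

    import numpy as np
    import xarray as xr

    # Hypothetical Dataset: two chemical elements measured at five stations.
    ds = xr.Dataset(
        {
            "SiO2": ("station", np.random.rand(5)),
            "TiO2": ("station", np.random.rand(5)),
        }
    )
    # GWPCA needs an explicit feature dimension, so stack the data variables
    # into a single DataArray along a new 'element' dimension.
    da_from_ds = ds.to_array(dim="element")  # dims: ('element', 'station')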
+ +.. GENERATED FROM PYTHON SOURCE LINES 75-76 + +Since ``xeofs`` uses ``xarray``, we convert the data into an ``xarray.DataArray``. + +.. GENERATED FROM PYTHON SOURCE LINES 76-85 + +.. code-block:: default + + + data_df = data_df.rename(columns={"ID": "station"}).set_index("station") + data = data_df.to_xarray() + data = data.rename({"XCOO": "x", "YCOO": "y"}) + data = data.set_index(station=("x", "y")) + data = data.drop_vars("CNo") + da = data.to_array(dim="element") + da + + + + + + +.. raw:: html + +
+
+    <xarray.DataArray (element: 10, station: 768)>
+    43.61 58.73 58.14 43.98 60.9 54.0 82.72 ... 0.196 0.202 0.207 0.109 0.141 0.185
+    Coordinates:
+      * station  (station) object MultiIndex
+      * x        (station) float64 -6.197e+05 2.147e+05 ... -2.82e+05 -1.273e+05
+      * y        (station) float64 6.805e+06 7.746e+06 ... 5.796e+06 6.523e+06
+      * element  (element) object 'SiO2_T' 'TiO2_T' 'Al2O3_T' ... 'K2O_T' 'P2O5_T'
+
+
+
+ +.. GENERATED FROM PYTHON SOURCE LINES 86-97 + +Let's dive into the GWPCA. First, initialize a ``GWPCA`` instance and fit it to the data. +The ``station`` dimension serves as our sample dimension, along which the local PCAs will be applied. +Since these PCAs need to gauge distances to adjacent stations, we must specify +a distance metric. Our station data includes coordinates in meters, so we'll +choose the ``euclidean`` metric. If you have coordinates in degrees (like +latitude and longitude), choose the ``haversine`` metric instead. +We're also using a ``bisquare`` kernel with a bandwidth of 1000 km. Note that the +bandwidth unit always follows input data (which is in meters here), +except when using the ``haversine`` metric, which always gives distances in +kilometers. Lastly, we'll standardize the input to ensure consistent scales +for the chemical elements. + +.. GENERATED FROM PYTHON SOURCE LINES 97-109 + +.. code-block:: default + + + gwpca = xe.models.GWPCA( + n_modes=5, + standardize=True, + metric="euclidean", + kernel="bisquare", + bandwidth=1000000.0, + ) + gwpca.fit(da, "station") + gwpca.components() + + + + + + + +.. raw:: html + +
+
+    <xarray.DataArray 'components' (mode: 5, element: 10, station: 768)>
+    0.1813 -0.3584 0.1243 0.2 -0.3812 ... -0.1229 0.2865 -0.4732 -0.4197 -0.4249
+    Coordinates:
+      * element  (element) object 'SiO2_T' 'TiO2_T' 'Al2O3_T' ... 'K2O_T' 'P2O5_T'
+      * mode     (mode) int64 1 2 3 4 5
+      * station  (station) object MultiIndex
+      * x        (station) float64 -6.197e+05 2.147e+05 ... -2.82e+05 -1.273e+05
+      * y        (station) float64 6.805e+06 7.746e+06 ... 5.796e+06 6.523e+06
+    Attributes:
+        model:        GWPCA
+        n_modes:      5
+        center:       True
+        standardize:  True
+        use_coslat:   False
+        solver:       auto
+        software:     xeofs
+        version:      1.0.3
+        date:         2023-10-21 13:16:29
+
+
+
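For reference, the ``bisquare`` kernel used in the fit above tapers a station's influence to zero at the bandwidth: for distance ``d`` and bandwidth ``b`` the weight is ``(1 - (d/b)**2)**2`` when ``d < b`` and ``0`` otherwise. The snippet below is a sketch of that formula, not necessarily the exact code path inside ``xeofs``.

.. code-block:: python

    import numpy as np


    def bisquare_weights(dist: np.ndarray, bandwidth: float) -> np.ndarray:
        """Bisquare kernel: (1 - (d/b)^2)^2 inside the bandwidth, zero outside."""
        w = (1.0 - (dist / bandwidth) ** 2) ** 2
        return np.where(dist < bandwidth, w, 0.0)


    # Stations 100 km, 500 km and 1200 km away from the target location,
    # evaluated with the 1000 km bandwidth used above (distances in meters).
    bisquare_weights(np.array([1e5, 5e5, 1.2e6]), bandwidth=1e6)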
+ +.. GENERATED FROM PYTHON SOURCE LINES 110-115 + +The ``components`` method returns the local principal components for each station. Note that the +dimensionality of the returned array is ``[station, element, mode]``, so in practice we don't really have +reduced the dimensionality of the data set. However, we can +extract the largest locally weighted components for each station which tells us which chemical elements +dominate the local PCAs. + +.. GENERATED FROM PYTHON SOURCE LINES 115-119 + +.. code-block:: default + + + llwc = gwpca.largest_locally_weighted_components() + llwc + + + + + + +.. raw:: html + +
+
+    <xarray.DataArray 'largest_locally_weighted_components' (mode: 5, station: 768)>
+    'MgO_T' 'Al2O3_T' 'MgO_T' 'TiO2_T' ... 'K2O_T' 'Fe2O3_T' 'Fe2O3_T' 'CaO_T'
+    Coordinates:
+      * mode     (mode) int64 1 2 3 4 5
+      * station  (station) object MultiIndex
+      * x        (station) float64 -6.197e+05 2.147e+05 ... -2.82e+05 -1.273e+05
+      * y        (station) float64 6.805e+06 7.746e+06 ... 5.796e+06 6.523e+06
+
+
+
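To connect this output to the components shown earlier: the largest locally weighted component is essentially the element whose loading has the greatest magnitude at each station and mode. A rough, hand-rolled equivalent is sketched below; it is meant to convey the idea, and the library's own method should be preferred.

.. code-block:: python

    # comps has dims (mode, element, station); pick, for every station and mode,
    # the element with the largest absolute loading.
    comps = gwpca.components()
    dominant_idx = abs(comps).argmax(dim="element")
    llwc_manual = comps["element"].isel(element=dominant_idx)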
+ +.. GENERATED FROM PYTHON SOURCE LINES 120-125 + +Let's visualize the spatial patterns of the chemical elements. +As the stations are positioned on a irregular grid, we'll transform the +``llwc`` ``DataArray`` into a ``pandas.DataFrame``. After that, we can easily visualize +it using the ``scatter`` method. +For demonstation, we'll concentrate on the first mode: + +.. GENERATED FROM PYTHON SOURCE LINES 125-156 + +.. code-block:: default + + + llwc1_df = llwc.sel(mode=1).to_dataframe() + + elements = da.element.values + n_elements = len(elements) + colors = np.arange(n_elements) + col_dict = {el: col for el, col in zip(elements, colors)} + + llwc1_df["colors"] = llwc1_df["largest_locally_weighted_components"].map(col_dict) + cmap = sns.color_palette("tab10", n_colors=n_elements, as_cmap=True) + + + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111) + background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1) + s = ax.scatter( + x=llwc1_df["x"], + y=llwc1_df["y"], + c=llwc1_df["colors"], + ec="w", + s=40, + cmap=cmap, + vmin=-0.5, + vmax=n_elements - 0.5, + ) + cbar = fig.colorbar(mappable=s, ax=ax, label="Largest locally weighted component") + cbar.set_ticks(colors) + cbar.set_ticklabels(elements) + ax.set_title("Largest locally weighted element", loc="left", weight=800) + plt.show() + + + + +.. image-sg:: /auto_examples/1eof/images/sphx_glr_plot_gwpca_001.png + :alt: Largest locally weighted element + :srcset: /auto_examples/1eof/images/sphx_glr_plot_gwpca_001.png + :class: sphx-glr-single-img + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 157-162 + +In the final step, let's examine the explained variance. Like standard PCA, +this gives us insight into the variance explained by each mode. But with a +local PCA for every station, the explained variance varies spatially. Notably, +the first mode's explained variance differs across countries, ranging from +roughly 40% to 70%. + +.. GENERATED FROM PYTHON SOURCE LINES 162-175 + +.. code-block:: default + + + + exp_var_ratio = gwpca.explained_variance_ratio() + evr1_df = exp_var_ratio.sel(mode=1).to_dataframe() + + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111) + background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1) + evr1_df.plot.scatter( + ax=ax, x="x", y="y", c="explained_variance_ratio", vmin=0.4, vmax=0.7 + ) + ax.set_title("Fraction of locally explained variance", loc="left", weight=800) + plt.show() + + + +.. image-sg:: /auto_examples/1eof/images/sphx_glr_plot_gwpca_002.png + :alt: Fraction of locally explained variance + :srcset: /auto_examples/1eof/images/sphx_glr_plot_gwpca_002.png + :class: sphx-glr-single-img + + + + + + +.. rst-class:: sphx-glr-timing + + **Total running time of the script:** (0 minutes 33.235 seconds) + + +.. _sphx_glr_download_auto_examples_1eof_plot_gwpca.py: + +.. only:: html + + .. container:: sphx-glr-footer sphx-glr-footer-example + + + + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: plot_gwpca.py ` + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download Jupyter notebook: plot_gwpca.ipynb ` + + +.. only:: html + + .. 
rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/docs/auto_examples/1eof/plot_gwpca_codeobj.pickle b/docs/auto_examples/1eof/plot_gwpca_codeobj.pickle new file mode 100644 index 0000000..2932321 Binary files /dev/null and b/docs/auto_examples/1eof/plot_gwpca_codeobj.pickle differ diff --git a/docs/auto_examples/1eof/sg_execution_times.rst b/docs/auto_examples/1eof/sg_execution_times.rst index 931745b..7752aa4 100644 --- a/docs/auto_examples/1eof/sg_execution_times.rst +++ b/docs/auto_examples/1eof/sg_execution_times.rst @@ -3,20 +3,25 @@ .. _sphx_glr_auto_examples_1eof_sg_execution_times: + Computation times ================= -**00:17.020** total execution time for **auto_examples_1eof** files: +**00:03.585** total execution time for **auto_examples_1eof** files: +--------------------------------------------------------------------------------------------+-----------+--------+ -| :ref:`sphx_glr_auto_examples_1eof_plot_rotated_eof.py` (``plot_rotated_eof.py``) | 00:17.020 | 0.0 MB | +| :ref:`sphx_glr_auto_examples_1eof_plot_eeof.py` (``plot_eeof.py``) | 00:03.585 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ | :ref:`sphx_glr_auto_examples_1eof_plot_eof-smode.py` (``plot_eof-smode.py``) | 00:00.000 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ | :ref:`sphx_glr_auto_examples_1eof_plot_eof-tmode.py` (``plot_eof-tmode.py``) | 00:00.000 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ +| :ref:`sphx_glr_auto_examples_1eof_plot_gwpca.py` (``plot_gwpca.py``) | 00:00.000 | 0.0 MB | ++--------------------------------------------------------------------------------------------+-----------+--------+ | :ref:`sphx_glr_auto_examples_1eof_plot_mreof.py` (``plot_mreof.py``) | 00:00.000 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ | :ref:`sphx_glr_auto_examples_1eof_plot_multivariate-eof.py` (``plot_multivariate-eof.py``) | 00:00.000 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ +| :ref:`sphx_glr_auto_examples_1eof_plot_rotated_eof.py` (``plot_rotated_eof.py``) | 00:00.000 | 0.0 MB | ++--------------------------------------------------------------------------------------------+-----------+--------+ | :ref:`sphx_glr_auto_examples_1eof_plot_weighted-eof.py` (``plot_weighted-eof.py``) | 00:00.000 | 0.0 MB | +--------------------------------------------------------------------------------------------+-----------+--------+ diff --git a/docs/auto_examples/auto_examples_jupyter.zip b/docs/auto_examples/auto_examples_jupyter.zip index 3030f85..09abd69 100644 Binary files a/docs/auto_examples/auto_examples_jupyter.zip and b/docs/auto_examples/auto_examples_jupyter.zip differ diff --git a/docs/auto_examples/auto_examples_python.zip b/docs/auto_examples/auto_examples_python.zip index 578c681..1074732 100644 Binary files a/docs/auto_examples/auto_examples_python.zip and b/docs/auto_examples/auto_examples_python.zip differ diff --git a/docs/auto_examples/index.rst b/docs/auto_examples/index.rst index f0fcccc..2877785 100644 --- a/docs/auto_examples/index.rst +++ b/docs/auto_examples/index.rst @@ -27,6 +27,23 @@ Examples
+.. raw:: html + +
+ +.. only:: html + + .. image:: /auto_examples/1eof/images/thumb/sphx_glr_plot_eeof_thumb.png + :alt: + + :ref:`sphx_glr_auto_examples_1eof_plot_eeof.py` + +.. raw:: html + +
Extended EOF analysis
+
+ + .. raw:: html
@@ -129,6 +146,23 @@ Examples
+.. raw:: html + +
+ +.. only:: html + + .. image:: /auto_examples/1eof/images/thumb/sphx_glr_plot_gwpca_thumb.png + :alt: + + :ref:`sphx_glr_auto_examples_1eof_plot_gwpca.py` + +.. raw:: html + +
Geographically weighted PCA
+
+ + .. raw:: html
diff --git a/examples/1eof/plot_eeof.png b/examples/1eof/plot_eeof.png
new file mode 100644
index 0000000..71cc545
Binary files /dev/null and b/examples/1eof/plot_eeof.png differ
diff --git a/examples/1eof/plot_eeof.py b/examples/1eof/plot_eeof.py
new file mode 100644
index 0000000..2a7ab1d
--- /dev/null
+++ b/examples/1eof/plot_eeof.py
@@ -0,0 +1,83 @@
+"""
+Extended EOF analysis
+=====================
+
+This example demonstrates Extended EOF (EEOF) analysis on ``xarray`` tutorial
+data. EEOF analysis, also known as Multivariate/Multichannel Singular
+Spectrum Analysis, advances traditional EOF analysis to capture propagating
+signals or oscillations in multivariate datasets. At its core, this
+involves the formulation of a lagged covariance matrix that encapsulates
+both spatial and temporal correlations. Subsequently, this matrix is
+decomposed to yield its eigenvectors (components) and eigenvalues (explained variance).
+
+Let's begin by setting up the required packages and fetching the data:
+"""
+
+import xarray as xr
+import xeofs as xe
+import matplotlib.pyplot as plt
+
+xr.set_options(display_expand_data=False)
+
+# %%
+# Load the tutorial data.
+t2m = xr.tutorial.load_dataset("air_temperature").air
+
+
+# %%
+# Prior to conducting the EEOF analysis, it's essential to determine the
+# structure of the lagged covariance matrix. This entails defining the time
+# delay ``tau`` and the ``embedding`` dimension. The former signifies the
+# interval between the original and lagged time series, while the latter
+# dictates the number of time-lagged copies in the delay-coordinate space,
+# representing the system's dynamics.
+# For illustration, using ``tau=4`` and ``embedding=40``, we generate 40
+# delayed versions of the time series, each offset by 4 time steps, resulting
+# in a maximum shift of ``tau x embedding = 160``. Given our dataset's
+# 6-hour intervals, tau = 4 translates to a 24-hour shift.
+# Constructing the lagged covariance matrix this way and subsequently
+# decomposing it can clearly become computationally expensive. For example,
+# given our dataset's dimensions,
+
+t2m.shape
+
+# %%
+# the extended dataset would have 40 x 25 x 53 = 53000 features
+# which is much larger than the original dataset's 1325 features.
+# To mitigate this, we can first preprocess the data using PCA / EOF analysis
+# and then perform EEOF analysis on the resulting PCA / EOF scores. Here,
+# we'll use ``n_pca_modes=50`` to retain the first 50 PCA modes, so we end
+# up with 40 x 50 = 2000 (latent) features.
+# With these parameters set, we proceed to instantiate the ``ExtendedEOF``
+# model and fit our data.
+
+model = xe.models.ExtendedEOF(
+    n_modes=10, tau=4, embedding=40, n_pca_modes=50, use_coslat=True
+)
+model.fit(t2m, dim="time")
+scores = model.scores()
+components = model.components()
+components
+
+# %%
+# A notable distinction from standard EOF analysis is the incorporation of an
+# extra ``embedding`` dimension in the components. Nonetheless, the
+# overarching methodology mirrors traditional EOF practices. The results,
+# for instance, can be assessed by examining the explained variance ratio.
+
+model.explained_variance_ratio().plot()
+plt.show()
+
+# %%
+# Additionally, we can look into the scores; let's spotlight mode 4.
+
+scores.sel(mode=4).plot()
+plt.show()
+
+# %%
+# In wrapping up, we visualize the corresponding EEOF component of mode 4.
+# For visualization purposes, we'll focus on the component at a specific
+# latitude, in this instance, 60 degrees north.
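+# %%
+# Before that final plot, a brief optional aside: the time-delay embedding
+# described earlier can be sketched by hand by stacking time-shifted copies of
+# the field along a new ``embedding`` dimension. This is only meant to
+# illustrate the idea behind ``tau`` and ``embedding`` -- it is not how
+# ``ExtendedEOF`` builds its lagged matrix internally, and the ``*_demo``
+# names below are invented for this sketch. We use a deliberately small
+# embedding to keep the example cheap.
+
+tau_demo, embedding_demo = 4, 3
+lagged_demo = xr.concat(
+    [t2m.shift(time=-lag * tau_demo) for lag in range(embedding_demo)],
+    dim="embedding",
+)
+# Each step along ``embedding`` is the series shifted by another ``tau_demo``
+# time steps; ``shift`` pads the trailing edge with NaNs.
+lagged_demo
+
+# %%
+# Now back to the mode-4 EEOF component at 60 degrees north: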
+
+components.sel(mode=4, lat=60).plot()
+plt.show()
diff --git a/examples/1eof/plot_gwpca.py b/examples/1eof/plot_gwpca.py
new file mode 100644
index 0000000..a084283
--- /dev/null
+++ b/examples/1eof/plot_gwpca.py
@@ -0,0 +1,174 @@
+"""
+Geographically weighted PCA
+===========================
+Geographically Weighted Principal Component Analysis (GWPCA) is a spatial analysis method that identifies and visualizes local spatial patterns and relationships in multivariate datasets across various geographic areas. It operates by applying PCA within a moving window over a geographical region, which enables the extraction of local principal components that can differ across locations.
+
+In this demonstration, we'll apply GWPCA to a dataset detailing the chemical compositions of soils from countries around the Baltic Sea [1]_. This example is inspired by a tutorial originally crafted and published by Chris Brunsdon [2]_.
+The dataset comprises 10 variables (chemical elements) and spans 768 samples.
+Here, each sample refers to a pair of latitude and longitude coordinates, representing specific sampling stations.
+
+.. [1] Reimann, C. et al. Baltic soil survey: total concentrations of major and selected trace elements in arable soils from 10 countries around the Baltic Sea. Science of The Total Environment 257, 155–170 (2000).
+.. [2] https://rpubs.com/chrisbrunsdon/99675
+
+
+
+.. note:: The dataset we're using is found in the R package
+    `mvoutlier `_.
+    To access it, we'll employ the Python package
+    `rpy2 `_, which facilitates
+    interaction with R packages from within Python.
+
+.. note:: Presently, there's no support for ``xarray.Dataset`` lacking an explicit feature dimension.
+    As a workaround, ``xarray.Dataset.to_array`` can be used to convert the ``Dataset`` to a ``DataArray``.
+
+.. warning:: Bear in mind that GWPCA requires significant computational power.
+    The ``xeofs`` implementation is optimized for CPU efficiency and is best suited
+    for smaller to medium data sets. For more extensive datasets where parallel processing becomes essential,
+    it's advisable to turn to the R package `GWmodel `_.
+    This package harnesses CUDA to enable GPU-accelerated GWPCA for optimized performance.
+
+
+Let's import the necessary packages.
+"""
+# For the analysis
+import numpy as np
+import xarray as xr
+import xeofs as xe
+
+# For visualization
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+# For accessing R packages
+import rpy2.robjects as ro
+from rpy2.robjects.packages import importr
+from rpy2.robjects import pandas2ri
+
+# %%
+# Next, we'll install the R package `mvoutlier `_
+# using the `rpy2 `_ package.
+
+xr.set_options(display_expand_data=False)
+utils = importr("utils")
+utils.chooseCRANmirror(ind=1)
+utils.install_packages("mvoutlier")
+
+# %%
+# Let's load the dataset and convert it into a ``pandas.DataFrame``.
+# Alongside, we'll also load the background data that outlines the borders of countries
+# in the Baltic Sea region. This will help us visually represent the GWPCA results.
+
+ro.r(
+    """
+    require("mvoutlier")
+    data(bsstop)
+    Data <- bsstop[,1:14]
+    background <- bss.background
+    """
+)
+with (ro.default_converter + pandas2ri.converter).context():
+    data_df = ro.conversion.get_conversion().rpy2py(ro.r["Data"])
+    background_df = ro.conversion.get_conversion().rpy2py(ro.r["background"])
+data_df.head()
+
+# %%
+# Since ``xeofs`` uses ``xarray``, we convert the data into an ``xarray.DataArray``.
+
+data_df = data_df.rename(columns={"ID": "station"}).set_index("station")
+data = data_df.to_xarray()
+data = data.rename({"XCOO": "x", "YCOO": "y"})
+data = data.set_index(station=("x", "y"))
+data = data.drop_vars("CNo")
+da = data.to_array(dim="element")
+da
+
+# %%
+# Let's dive into the GWPCA. First, initialize a ``GWPCA`` instance and fit it to the data.
+# The ``station`` dimension serves as our sample dimension, along which the local PCAs will be applied.
+# Since these PCAs need to gauge distances to adjacent stations, we must specify
+# a distance metric. Our station data includes coordinates in meters, so we'll
+# choose the ``euclidean`` metric. If you have coordinates in degrees (like
+# latitude and longitude), choose the ``haversine`` metric instead.
+# We're also using a ``bisquare`` kernel with a bandwidth of 1000 km. Note that the
+# bandwidth unit always follows the input data (which is in meters here),
+# except when using the ``haversine`` metric, which always gives distances in
+# kilometers. Lastly, we'll standardize the input to ensure consistent scales
+# for the chemical elements.
+
+gwpca = xe.models.GWPCA(
+    n_modes=5,
+    standardize=True,
+    metric="euclidean",
+    kernel="bisquare",
+    bandwidth=1000000.0,
+)
+gwpca.fit(da, "station")
+gwpca.components()
+
+
+# %%
+# The ``components`` method returns the local principal components for each station. Note that the
+# dimensionality of the returned array is ``[station, element, mode]``, so in practice we haven't really
+# reduced the dimensionality of the data set. However, we can
+# extract the largest locally weighted components for each station, which tell us which chemical elements
+# dominate the local PCAs.
+
+llwc = gwpca.largest_locally_weighted_components()
+llwc
+
+# %%
+# Let's visualize the spatial patterns of the chemical elements.
+# As the stations are positioned on an irregular grid, we'll transform the
+# ``llwc`` ``DataArray`` into a ``pandas.DataFrame``. After that, we can easily visualize
+# it using the ``scatter`` method.
+# For demonstration, we'll concentrate on the first mode:
+
+llwc1_df = llwc.sel(mode=1).to_dataframe()
+
+elements = da.element.values
+n_elements = len(elements)
+colors = np.arange(n_elements)
+col_dict = {el: col for el, col in zip(elements, colors)}
+
+llwc1_df["colors"] = llwc1_df["largest_locally_weighted_components"].map(col_dict)
+cmap = sns.color_palette("tab10", n_colors=n_elements, as_cmap=True)
+
+
+fig = plt.figure(figsize=(10, 10))
+ax = fig.add_subplot(111)
+background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1)
+s = ax.scatter(
+    x=llwc1_df["x"],
+    y=llwc1_df["y"],
+    c=llwc1_df["colors"],
+    ec="w",
+    s=40,
+    cmap=cmap,
+    vmin=-0.5,
+    vmax=n_elements - 0.5,
+)
+cbar = fig.colorbar(mappable=s, ax=ax, label="Largest locally weighted component")
+cbar.set_ticks(colors)
+cbar.set_ticklabels(elements)
+ax.set_title("Largest locally weighted element", loc="left", weight=800)
+plt.show()
+
+# %%
+# In the final step, let's examine the explained variance. Like standard PCA,
+# this gives us insight into the variance explained by each mode. But with a
+# local PCA for every station, the explained variance varies spatially. Notably,
+# the first mode's explained variance differs across countries, ranging from
+# roughly 40% to 70%.
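+# %%
+# A short optional aside before we do: the ``bisquare`` kernel chosen above
+# weights neighboring stations by distance, with the weight falling to zero at
+# the bandwidth. A commonly used form of the bisquare weight is
+# w(d) = (1 - (d / bandwidth)^2)^2 for d < bandwidth and 0 otherwise. The
+# sketch below only illustrates that shape; it is not taken from the ``xeofs``
+# internals, and the ``*_demo`` names are invented for this illustration.
+
+dist_demo = np.linspace(0, 1.5e6, 200)  # hypothetical distances in meters
+bandwidth_demo = 1.0e6  # 1000 km, matching the bandwidth used above
+weights_demo = np.where(
+    dist_demo < bandwidth_demo,
+    (1 - (dist_demo / bandwidth_demo) ** 2) ** 2,
+    0.0,
+)
+plt.plot(dist_demo / 1e3, weights_demo)
+plt.xlabel("Distance from target station [km]")
+plt.ylabel("Bisquare weight")
+plt.show()
+
+# %%
+# Now, the locally explained variance itself: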
+ + +exp_var_ratio = gwpca.explained_variance_ratio() +evr1_df = exp_var_ratio.sel(mode=1).to_dataframe() + +fig = plt.figure(figsize=(10, 10)) +ax = fig.add_subplot(111) +background_df.plot.scatter(ax=ax, x="V1", y="V2", color=".3", marker=".", s=1) +evr1_df.plot.scatter( + ax=ax, x="x", y="y", c="explained_variance_ratio", vmin=0.4, vmax=0.7 +) +ax.set_title("Fraction of locally explained variance", loc="left", weight=800) +plt.show() diff --git a/poetry.lock b/poetry.lock index 6f4649c..4b83264 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,7 +4,7 @@ name = "accessible-pygments" version = "0.0.4" description = "A collection of accessible pygments styles" -optional = false +optional = true python-versions = "*" files = [ {file = "accessible-pygments-0.0.4.tar.gz", hash = "sha256:e7b57a9b15958e9601c7e9eb07a440c813283545a20973f2574a5f453d0e953e"}, @@ -18,7 +18,7 @@ pygments = ">=1.5" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "alabaster-0.7.13-py3-none-any.whl", hash = "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3"}, @@ -29,7 +29,7 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, @@ -47,7 +47,7 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "babel" version = "2.12.1" description = "Internationalization utilities" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "Babel-2.12.1-py3-none-any.whl", hash = "sha256:b4246fb7677d3b98f501a39d43396d3cafdc8eadb045f4a31be01863f655c610"}, @@ -58,7 +58,7 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" -optional = false +optional = true python-versions = ">=3.6.0" files = [ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, @@ -121,7 +121,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "bleach-6.0.0-py3-none-any.whl", hash = "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4"}, @@ -150,7 +150,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." 
-optional = false +optional = true python-versions = "*" files = [ {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"}, @@ -470,7 +470,7 @@ test = ["pandas[test]", "pre-commit", "pytest", "pytest-cov", "pytest-rerunfailu name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, @@ -481,7 +481,7 @@ files = [ name = "docutils" version = "0.20.1" description = "Docutils -- Python Documentation Utilities" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6"}, @@ -506,7 +506,7 @@ test = ["pytest (>=6)"] name = "fastjsonschema" version = "2.18.0" description = "Fastest Python implementation of JSON schema" -optional = false +optional = true python-versions = "*" files = [ {file = "fastjsonschema-2.18.0-py3-none-any.whl", hash = "sha256:128039912a11a807068a7c87d0da36660afbfd7202780db26c4aa7153cfdc799"}, @@ -582,7 +582,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"}, @@ -623,7 +623,7 @@ files = [ name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, @@ -651,7 +651,7 @@ files = [ name = "jsonschema" version = "4.18.4" description = "An implementation of JSON Schema validation for Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "jsonschema-4.18.4-py3-none-any.whl", hash = "sha256:971be834317c22daaa9132340a51c01b50910724082c2c1a2ac87eeec153a3fe"}, @@ -672,7 +672,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2023.7.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "jsonschema_specifications-2023.7.1-py3-none-any.whl", hash = "sha256:05adf340b659828a004220a9613be00fa3f223f2b82002e273dee62fd50524b1"}, @@ -686,7 +686,7 @@ referencing = ">=0.28.0" name = "jupyter-client" version = "8.3.0" description = "Jupyter protocol implementation and client libraries" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "jupyter_client-8.3.0-py3-none-any.whl", hash = "sha256:7441af0c0672edc5d28035e92ba5e32fadcfa8a4e608a434c228836a89df6158"}, @@ -708,7 +708,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-core" version = "5.3.1" description = "Jupyter core package. A base package on which Jupyter projects rely." 
-optional = false +optional = true python-versions = ">=3.8" files = [ {file = "jupyter_core-5.3.1-py3-none-any.whl", hash = "sha256:ae9036db959a71ec1cac33081eeb040a79e681f08ab68b0883e9a676c7a90dce"}, @@ -728,13 +728,46 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, ] +[[package]] +name = "llvmlite" +version = "0.40.1" +description = "lightweight wrapper around basic LLVM functionality" +optional = false +python-versions = ">=3.8" +files = [ + {file = "llvmlite-0.40.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:84ce9b1c7a59936382ffde7871978cddcda14098e5a76d961e204523e5c372fb"}, + {file = "llvmlite-0.40.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3673c53cb21c65d2ff3704962b5958e967c6fc0bd0cff772998face199e8d87b"}, + {file = "llvmlite-0.40.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bba2747cf5b4954e945c287fe310b3fcc484e2a9d1b0c273e99eb17d103bb0e6"}, + {file = "llvmlite-0.40.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbd5e82cc990e5a3e343a3bf855c26fdfe3bfae55225f00efd01c05bbda79918"}, + {file = "llvmlite-0.40.1-cp310-cp310-win32.whl", hash = "sha256:09f83ea7a54509c285f905d968184bba00fc31ebf12f2b6b1494d677bb7dde9b"}, + {file = "llvmlite-0.40.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b37297f3cbd68d14a97223a30620589d98ad1890e5040c9e5fc181063f4ed49"}, + {file = "llvmlite-0.40.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a66a5bd580951751b4268f4c3bddcef92682814d6bc72f3cd3bb67f335dd7097"}, + {file = "llvmlite-0.40.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:467b43836b388eaedc5a106d76761e388dbc4674b2f2237bc477c6895b15a634"}, + {file = "llvmlite-0.40.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c23edd196bd797dc3a7860799054ea3488d2824ecabc03f9135110c2e39fcbc"}, + {file = "llvmlite-0.40.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a36d9f244b6680cb90bbca66b146dabb2972f4180c64415c96f7c8a2d8b60a36"}, + {file = "llvmlite-0.40.1-cp311-cp311-win_amd64.whl", hash = "sha256:5b3076dc4e9c107d16dc15ecb7f2faf94f7736cd2d5e9f4dc06287fd672452c1"}, + {file = "llvmlite-0.40.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4a7525db121f2e699809b539b5308228854ccab6693ecb01b52c44a2f5647e20"}, + {file = "llvmlite-0.40.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:84747289775d0874e506f907a4513db889471607db19b04de97d144047fec885"}, + {file = "llvmlite-0.40.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e35766e42acef0fe7d1c43169a8ffc327a47808fae6a067b049fe0e9bbf84dd5"}, + {file = "llvmlite-0.40.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cda71de10a1f48416309e408ea83dab5bf36058f83e13b86a2961defed265568"}, + {file = "llvmlite-0.40.1-cp38-cp38-win32.whl", hash = "sha256:96707ebad8b051bbb4fc40c65ef93b7eeee16643bd4d579a14d11578e4b7a647"}, + {file = "llvmlite-0.40.1-cp38-cp38-win_amd64.whl", hash = "sha256:e44f854dc11559795bcdeaf12303759e56213d42dabbf91a5897aa2d8b033810"}, + {file = 
"llvmlite-0.40.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f643d15aacd0b0b0dc8b74b693822ba3f9a53fa63bc6a178c2dba7cc88f42144"}, + {file = "llvmlite-0.40.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:39a0b4d0088c01a469a5860d2e2d7a9b4e6a93c0f07eb26e71a9a872a8cadf8d"}, + {file = "llvmlite-0.40.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9329b930d699699846623054121ed105fd0823ed2180906d3b3235d361645490"}, + {file = "llvmlite-0.40.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2dbbb8424037ca287983b115a29adf37d806baf7e1bf4a67bd2cffb74e085ed"}, + {file = "llvmlite-0.40.1-cp39-cp39-win32.whl", hash = "sha256:e74e7bec3235a1e1c9ad97d897a620c5007d0ed80c32c84c1d787e7daa17e4ec"}, + {file = "llvmlite-0.40.1-cp39-cp39-win_amd64.whl", hash = "sha256:ff8f31111bb99d135ff296757dc81ab36c2dee54ed4bd429158a96da9807c316"}, + {file = "llvmlite-0.40.1.tar.gz", hash = "sha256:5cdb0d45df602099d833d50bd9e81353a5e036242d3c003c5b294fc61d1986b4"}, +] + [[package]] name = "locket" version = "1.0.0" @@ -750,7 +783,7 @@ files = [ name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, @@ -820,7 +853,7 @@ files = [ name = "mistune" version = "3.0.1" description = "A sane and fast Markdown parser with useful plugins and renderers" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "mistune-3.0.1-py3-none-any.whl", hash = "sha256:b9b3e438efbb57c62b5beb5e134dab664800bdf1284a7ee09e8b12b13eb1aac6"}, @@ -842,7 +875,7 @@ files = [ name = "nbclient" version = "0.8.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
-optional = false +optional = true python-versions = ">=3.8.0" files = [ {file = "nbclient-0.8.0-py3-none-any.whl", hash = "sha256:25e861299e5303a0477568557c4045eccc7a34c17fc08e7959558707b9ebe548"}, @@ -864,7 +897,7 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.7.3" description = "Converting Jupyter Notebooks" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "nbconvert-7.7.3-py3-none-any.whl", hash = "sha256:3022adadff3f86578a47fab7c2228bb3ca9c56a24345642a22f917f6168b48fc"}, @@ -901,7 +934,7 @@ webpdf = ["playwright"] name = "nbformat" version = "5.9.1" description = "The Jupyter Notebook format" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "nbformat-5.9.1-py3-none-any.whl", hash = "sha256:b7968ebf4811178a4108ee837eae1442e3f054132100f0359219e9ed1ce3ca45"}, @@ -920,13 +953,13 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] [[package]] name = "nbsphinx" -version = "0.9.2" +version = "0.9.3" description = "Jupyter Notebook Tools for Sphinx" -optional = false +optional = true python-versions = ">=3.6" files = [ - {file = "nbsphinx-0.9.2-py3-none-any.whl", hash = "sha256:2746680ece5ad3b0e980639d717a5041a1c1aafb416846b72dfaeecc306bc351"}, - {file = "nbsphinx-0.9.2.tar.gz", hash = "sha256:540db7f4066347f23d0650c4ae8e7d85334c69adf749e030af64c12e996ff88e"}, + {file = "nbsphinx-0.9.3-py3-none-any.whl", hash = "sha256:6e805e9627f4a358bd5720d5cbf8bf48853989c79af557afd91a5f22e163029f"}, + {file = "nbsphinx-0.9.3.tar.gz", hash = "sha256:ec339c8691b688f8676104a367a4b8cf3ea01fd089dc28d24dec22d563b11562"}, ] [package.dependencies] @@ -978,38 +1011,78 @@ certifi = "*" cftime = "*" numpy = "*" +[[package]] +name = "numba" +version = "0.57.1" +description = "compiling Python code using LLVM" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numba-0.57.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:db8268eb5093cae2288942a8cbd69c9352f6fe6e0bfa0a9a27679436f92e4248"}, + {file = "numba-0.57.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:643cb09a9ba9e1bd8b060e910aeca455e9442361e80fce97690795ff9840e681"}, + {file = "numba-0.57.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:53e9fab973d9e82c9f8449f75994a898daaaf821d84f06fbb0b9de2293dd9306"}, + {file = "numba-0.57.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c0602e4f896e6a6d844517c3ab434bc978e7698a22a733cc8124465898c28fa8"}, + {file = "numba-0.57.1-cp310-cp310-win32.whl", hash = "sha256:3d6483c27520d16cf5d122868b79cad79e48056ecb721b52d70c126bed65431e"}, + {file = "numba-0.57.1-cp310-cp310-win_amd64.whl", hash = "sha256:a32ee263649aa3c3587b833d6311305379529570e6c20deb0c6f4fb5bc7020db"}, + {file = "numba-0.57.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c078f84b5529a7fdb8413bb33d5100f11ec7b44aa705857d9eb4e54a54ff505"}, + {file = "numba-0.57.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e447c4634d1cc99ab50d4faa68f680f1d88b06a2a05acf134aa6fcc0342adeca"}, + {file = "numba-0.57.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4838edef2df5f056cb8974670f3d66562e751040c448eb0b67c7e2fec1726649"}, + {file = "numba-0.57.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9b17fbe4a69dcd9a7cd49916b6463cd9a82af5f84911feeb40793b8bce00dfa7"}, + {file = "numba-0.57.1-cp311-cp311-win_amd64.whl", hash = "sha256:93df62304ada9b351818ba19b1cfbddaf72cd89348e81474326ca0b23bf0bae1"}, + 
{file = "numba-0.57.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8e00ca63c5d0ad2beeb78d77f087b3a88c45ea9b97e7622ab2ec411a868420ee"}, + {file = "numba-0.57.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ff66d5b022af6c7d81ddbefa87768e78ed4f834ab2da6ca2fd0d60a9e69b94f5"}, + {file = "numba-0.57.1-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:60ec56386076e9eed106a87c96626d5686fbb16293b9834f0849cf78c9491779"}, + {file = "numba-0.57.1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c057ccedca95df23802b6ccad86bb318be624af45b5a38bb8412882be57a681"}, + {file = "numba-0.57.1-cp38-cp38-win32.whl", hash = "sha256:5a82bf37444039c732485c072fda21a361790ed990f88db57fd6941cd5e5d307"}, + {file = "numba-0.57.1-cp38-cp38-win_amd64.whl", hash = "sha256:9bcc36478773ce838f38afd9a4dfafc328d4ffb1915381353d657da7f6473282"}, + {file = "numba-0.57.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ae50c8c90c2ce8057f9618b589223e13faa8cbc037d8f15b4aad95a2c33a0582"}, + {file = "numba-0.57.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9a1b2b69448e510d672ff9a6b18d2db9355241d93c6a77677baa14bec67dc2a0"}, + {file = "numba-0.57.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3cf78d74ad9d289fbc1e5b1c9f2680fca7a788311eb620581893ab347ec37a7e"}, + {file = "numba-0.57.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f47dd214adc5dcd040fe9ad2adbd2192133c9075d2189ce1b3d5f9d72863ef05"}, + {file = "numba-0.57.1-cp39-cp39-win32.whl", hash = "sha256:a3eac19529956185677acb7f01864919761bfffbb9ae04bbbe5e84bbc06cfc2b"}, + {file = "numba-0.57.1-cp39-cp39-win_amd64.whl", hash = "sha256:9587ba1bf5f3035575e45562ada17737535c6d612df751e811d702693a72d95e"}, + {file = "numba-0.57.1.tar.gz", hash = "sha256:33c0500170d213e66d90558ad6aca57d3e03e97bb11da82e6d87ab793648cb17"}, +] + +[package.dependencies] +llvmlite = "==0.40.*" +numpy = ">=1.21,<1.25" + [[package]] name = "numpy" -version = "1.25.1" +version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.8" files = [ - {file = "numpy-1.25.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d339465dff3eb33c701430bcb9c325b60354698340229e1dff97745e6b3efa"}, - {file = "numpy-1.25.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d736b75c3f2cb96843a5c7f8d8ccc414768d34b0a75f466c05f3a739b406f10b"}, - {file = "numpy-1.25.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a90725800caeaa160732d6b31f3f843ebd45d6b5f3eec9e8cc287e30f2805bf"}, - {file = "numpy-1.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c6c9261d21e617c6dc5eacba35cb68ec36bb72adcff0dee63f8fbc899362588"}, - {file = "numpy-1.25.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0def91f8af6ec4bb94c370e38c575855bf1d0be8a8fbfba42ef9c073faf2cf19"}, - {file = "numpy-1.25.1-cp310-cp310-win32.whl", hash = "sha256:fd67b306320dcadea700a8f79b9e671e607f8696e98ec255915c0c6d6b818503"}, - {file = "numpy-1.25.1-cp310-cp310-win_amd64.whl", hash = "sha256:c1516db588987450b85595586605742879e50dcce923e8973f79529651545b57"}, - {file = "numpy-1.25.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6b82655dd8efeea69dbf85d00fca40013d7f503212bc5259056244961268b66e"}, - {file = "numpy-1.25.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e8f6049c4878cb16960fbbfb22105e49d13d752d4d8371b55110941fb3b17800"}, - {file = 
"numpy-1.25.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41a56b70e8139884eccb2f733c2f7378af06c82304959e174f8e7370af112e09"}, - {file = "numpy-1.25.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5154b1a25ec796b1aee12ac1b22f414f94752c5f94832f14d8d6c9ac40bcca6"}, - {file = "numpy-1.25.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38eb6548bb91c421261b4805dc44def9ca1a6eef6444ce35ad1669c0f1a3fc5d"}, - {file = "numpy-1.25.1-cp311-cp311-win32.whl", hash = "sha256:791f409064d0a69dd20579345d852c59822c6aa087f23b07b1b4e28ff5880fcb"}, - {file = "numpy-1.25.1-cp311-cp311-win_amd64.whl", hash = "sha256:c40571fe966393b212689aa17e32ed905924120737194b5d5c1b20b9ed0fb171"}, - {file = "numpy-1.25.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d7abcdd85aea3e6cdddb59af2350c7ab1ed764397f8eec97a038ad244d2d105"}, - {file = "numpy-1.25.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a180429394f81c7933634ae49b37b472d343cccb5bb0c4a575ac8bbc433722f"}, - {file = "numpy-1.25.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d412c1697c3853c6fc3cb9751b4915859c7afe6a277c2bf00acf287d56c4e625"}, - {file = "numpy-1.25.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20e1266411120a4f16fad8efa8e0454d21d00b8c7cee5b5ccad7565d95eb42dd"}, - {file = "numpy-1.25.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f76aebc3358ade9eacf9bc2bb8ae589863a4f911611694103af05346637df1b7"}, - {file = "numpy-1.25.1-cp39-cp39-win32.whl", hash = "sha256:247d3ffdd7775bdf191f848be8d49100495114c82c2bd134e8d5d075fb386a1c"}, - {file = "numpy-1.25.1-cp39-cp39-win_amd64.whl", hash = "sha256:1d5d3c68e443c90b38fdf8ef40e60e2538a27548b39b12b73132456847f4b631"}, - {file = "numpy-1.25.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:35a9527c977b924042170a0887de727cd84ff179e478481404c5dc66b4170009"}, - {file = "numpy-1.25.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d3fe3dd0506a28493d82dc3cf254be8cd0d26f4008a417385cbf1ae95b54004"}, - {file = "numpy-1.25.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:012097b5b0d00a11070e8f2e261128c44157a8689f7dedcf35576e525893f4fe"}, - {file = "numpy-1.25.1.tar.gz", hash = "sha256:9a3a9f3a61480cc086117b426a8bd86869c213fc4072e606f01c4e4b66eb92bf"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = 
"numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] [[package]] @@ -1093,7 +1166,7 @@ xml = ["lxml (>=4.6.3)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, @@ -1213,7 +1286,7 @@ files = [ name = "pycparser" version = "2.21" description = "C parser in Python" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = 
"pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, @@ -1222,13 +1295,13 @@ files = [ [[package]] name = "pydata-sphinx-theme" -version = "0.13.3" +version = "0.14.1" description = "Bootstrap-based Sphinx theme from the PyData community" -optional = false -python-versions = ">=3.7" +optional = true +python-versions = ">=3.8" files = [ - {file = "pydata_sphinx_theme-0.13.3-py3-none-any.whl", hash = "sha256:bf41ca6c1c6216e929e28834e404bfc90e080b51915bbe7563b5e6fda70354f0"}, - {file = "pydata_sphinx_theme-0.13.3.tar.gz", hash = "sha256:827f16b065c4fd97e847c11c108bf632b7f2ff53a3bca3272f63f3f3ff782ecc"}, + {file = "pydata_sphinx_theme-0.14.1-py3-none-any.whl", hash = "sha256:c436027bc76ae023df4e70517e3baf90cdda5a88ee46b818b5ef0cc3884aba04"}, + {file = "pydata_sphinx_theme-0.14.1.tar.gz", hash = "sha256:d8d4ac81252c16a002e835d21f0fea6d04cf3608e95045c816e8cc823e79b053"}, ] [package.dependencies] @@ -1238,13 +1311,14 @@ beautifulsoup4 = "*" docutils = "!=0.17.0" packaging = "*" pygments = ">=2.7" -sphinx = ">=4.2" +sphinx = ">=5.0" typing-extensions = "*" [package.extras] +a11y = ["pytest-playwright"] dev = ["nox", "pre-commit", "pydata-sphinx-theme[doc,test]", "pyyaml"] -doc = ["ablog (>=0.11.0rc2)", "colorama", "ipyleaflet", "jupyter_sphinx", "linkify-it-py", "matplotlib", "myst-nb", "nbsphinx", "numpy", "numpydoc", "pandas", "plotly", "rich", "sphinx-copybutton", "sphinx-design", "sphinx-favicon (>=1.0.1)", "sphinx-sitemap", "sphinx-togglebutton", "sphinxcontrib-youtube", "sphinxext-rediraffe", "xarray"] -test = ["codecov", "pytest", "pytest-cov", "pytest-regressions"] +doc = ["ablog (>=0.11.0rc2)", "colorama", "ipyleaflet", "jupyter_sphinx", "jupyterlite-sphinx", "linkify-it-py", "matplotlib", "myst-nb", "nbsphinx", "numpy", "numpydoc", "pandas", "plotly", "rich", "sphinx-autoapi", "sphinx-copybutton", "sphinx-design", "sphinx-favicon (>=1.0.1)", "sphinx-sitemap", "sphinx-togglebutton", "sphinxcontrib-youtube (<1.4)", "sphinxext-rediraffe", "xarray"] +test = ["pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "pyflakes" @@ -1261,7 +1335,7 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." 
-optional = false +optional = true python-versions = ">=3.7" files = [ {file = "Pygments-2.15.1-py3-none-any.whl", hash = "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"}, @@ -1322,7 +1396,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -optional = false +optional = true python-versions = "*" files = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, @@ -1394,7 +1468,7 @@ files = [ name = "pyzmq" version = "25.1.0" description = "Python bindings for 0MQ" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "pyzmq-25.1.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:1a6169e69034eaa06823da6a93a7739ff38716142b3596c180363dee729d713d"}, @@ -1483,7 +1557,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "referencing" version = "0.30.0" description = "JSON Referencing + Python" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "referencing-0.30.0-py3-none-any.whl", hash = "sha256:c257b08a399b6c2f5a3510a50d28ab5dbc7bbde049bcaf954d43c446f83ab548"}, @@ -1519,7 +1593,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rpds-py" version = "0.9.2" description = "Python bindings to Rust's persistent data structures (rpds)" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "rpds_py-0.9.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:ab6919a09c055c9b092798ce18c6c4adf49d24d4d9e43a92b257e3f2548231e7"}, @@ -1621,6 +1695,33 @@ files = [ {file = "rpds_py-0.9.2.tar.gz", hash = "sha256:8d70e8f14900f2657c249ea4def963bed86a29b81f81f5b76b5a9215680de945"}, ] +[[package]] +name = "rpy2" +version = "3.5.14" +description = "Python interface to the R language (embedded R)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "rpy2-3.5.14-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:5cb7398adfcb6ca4faefbe856fb7af95eb11722ad18fbfcb4a79dbea1cf71c7c"}, + {file = "rpy2-3.5.14-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f35910208e5945b5108b7668bcb58127742f47fc0e8df8b2f4889c86be6f6519"}, + {file = "rpy2-3.5.14-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:adbd8e08f67f807fcca8e47473340e233a55c25fffd418081e6719316e03dbd7"}, + {file = "rpy2-3.5.14-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:ca95dee528d0a8032913de5fa85b8252b925f389aa2e2219d5314dcf43beeb1e"}, + {file = "rpy2-3.5.14.tar.gz", hash = "sha256:5f46ae31d36e117be366ad4ae02493c015ac6ba59ebe3b4cd7200075332fc481"}, +] + +[package.dependencies] +cffi = ">=1.10.0" +jinja2 = "*" +packaging = {version = "*", markers = "platform_system == \"Windows\""} +tzlocal = "*" + +[package.extras] +all = ["ipython", "numpy", "pandas (>=1.3.5)", "pytest"] +pandas = ["numpy", "pandas (>=1.3.5)"] +test = ["ipython", "numpy", "pandas (>=1.3.5)", "pytest"] +test-minimal = ["coverage", "pytest", "pytest-cov"] +types = ["mypy", "types-tzlocal"] + [[package]] name = "scikit-learn" version = "1.3.0" @@ -1716,7 +1817,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." 
-optional = false +optional = true python-versions = "*" files = [ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, @@ -1727,7 +1828,7 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "soupsieve-2.4.1-py3-none-any.whl", hash = "sha256:1c1bfee6819544a3447586c889157365a27e10d88cde3ad3da0cf0ddf646feb8"}, @@ -1738,7 +1839,7 @@ files = [ name = "sphinx" version = "7.1.1" description = "Python documentation generator" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "sphinx-7.1.1-py3-none-any.whl", hash = "sha256:4e6c5ea477afa0fb90815210fd1312012e1d7542589ab251ac9b53b7c0751bce"}, @@ -1772,7 +1873,7 @@ test = ["cython", "filelock", "html5lib", "pytest (>=4.6)"] name = "sphinx-copybutton" version = "0.5.2" description = "Add a copy button to each of your code cells." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"}, @@ -1790,7 +1891,7 @@ rtd = ["ipython", "myst-nb", "sphinx", "sphinx-book-theme", "sphinx-examples"] name = "sphinx-design" version = "0.5.0" description = "A sphinx extension for designing beautiful, view size responsive web components." -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "sphinx_design-0.5.0-py3-none-any.whl", hash = "sha256:1af1267b4cea2eedd6724614f19dcc88fe2e15aff65d06b2f6252cee9c4f4c1e"}, @@ -1811,13 +1912,13 @@ theme-sbt = ["sphinx-book-theme (>=1.0,<2.0)"] [[package]] name = "sphinx-gallery" -version = "0.13.0" +version = "0.14.0" description = "A `Sphinx `_ extension that builds an HTML gallery of examples from any set of Python scripts." -optional = false +optional = true python-versions = ">=3.7" files = [ - {file = "sphinx-gallery-0.13.0.tar.gz", hash = "sha256:4756f92e079128b08cbc7a57922cc904b3d442b1abfa73ec6471ad24f3c5b4b2"}, - {file = "sphinx_gallery-0.13.0-py3-none-any.whl", hash = "sha256:5bedfa4998b4158d5affc7d1df6796e4b1e834b16680001dac992af1304d8ed9"}, + {file = "sphinx-gallery-0.14.0.tar.gz", hash = "sha256:2a4a0aaf032955508e1d0f3495199a3c7819ce420e71096bff0bca551a4043c2"}, + {file = "sphinx_gallery-0.14.0-py3-none-any.whl", hash = "sha256:55b3ad1f378abd126232c166192270ac0a3ef615dec10b66c961ed2967be1df6"}, ] [package.dependencies] @@ -1827,7 +1928,7 @@ sphinx = ">=4" name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "sphinxcontrib-applehelp-1.0.4.tar.gz", hash = "sha256:828f867945bbe39817c210a1abfd1bc4895c8b73fcaade56d45357a348a07d7e"}, @@ -1842,7 +1943,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." 
-optional = false +optional = true python-versions = ">=3.5" files = [ {file = "sphinxcontrib-devhelp-1.0.2.tar.gz", hash = "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4"}, @@ -1857,7 +1958,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "sphinxcontrib-htmlhelp-2.0.1.tar.gz", hash = "sha256:0cbdd302815330058422b98a113195c9249825d681e18f11e8b1f78a2f11efff"}, @@ -1872,7 +1973,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" -optional = false +optional = true python-versions = ">=3.5" files = [ {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"}, @@ -1886,7 +1987,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." -optional = false +optional = true python-versions = ">=3.5" files = [ {file = "sphinxcontrib-qthelp-1.0.3.tar.gz", hash = "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72"}, @@ -1901,7 +2002,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." -optional = false +optional = true python-versions = ">=3.5" files = [ {file = "sphinxcontrib-serializinghtml-1.1.5.tar.gz", hash = "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952"}, @@ -1972,7 +2073,7 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"}, @@ -2012,7 +2113,7 @@ files = [ name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
-optional = false +optional = true python-versions = ">= 3.8" files = [ {file = "tornado-6.3.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:c367ab6c0393d71171123ca5515c61ff62fe09024fa6bf299cd1339dc9456829"}, @@ -2052,7 +2153,7 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"}, @@ -2065,13 +2166,13 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.8.0" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, + {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, ] [[package]] @@ -2085,6 +2186,23 @@ files = [ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, ] +[[package]] +name = "tzlocal" +version = "5.1" +description = "tzinfo object for the local timezone" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tzlocal-5.1-py3-none-any.whl", hash = "sha256:2938498395d5f6a898ab8009555cb37a4d360913ad375d4747ef16826b03ef23"}, + {file = "tzlocal-5.1.tar.gz", hash = "sha256:a5ccb2365b295ed964e0a98ad076fe10c495591e75505d34f154d60a7f1ed722"}, +] + +[package.dependencies] +tzdata = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +devenv = ["black", "check-manifest", "flake8", "pyroma", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] + [[package]] name = "urllib3" version = "2.0.4" @@ -2106,7 +2224,7 @@ zstd = ["zstandard (>=0.18.0)"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" -optional = false +optional = true python-versions = "*" files = [ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, @@ -2155,4 +2273,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a52384add5c7f9d836bb6992f977d83826388c587ec0f1fd4992bed704de743c" +content-hash = "7b7b136975a0e65a82eb4ad0de78ee240bd719a40f902a1d60c7ce52467491f2" diff --git a/pyproject.toml b/pyproject.toml index a5c0cd8..f34c14c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ documentation = "https://xeofs.readthedocs.io/en/latest/" [tool.poetry.dependencies] python = "^3.10" -numpy = ">=1.19.2" +numpy = "~1.24" pandas = ">=1.4.1" xarray = ">=0.21.1" scikit-learn = ">=1.0.2" @@ -19,21 +19,23 @@ pooch = "^1.6.0" tqdm = "^4.64.0" dask = ">=2023.0.1" statsmodels = ">=0.14.0" +netCDF4 = "^1.5.7" +numba = "^0.57" +typing-extensions = "^4.8.0" 
-[tool.poetry.dev-dependencies]
+[tool.poetry.group.dev.dependencies]
 flake8 = "^4.0.1"
 pytest = "^7.0.1"
 coverage = "^6.3.1"
-netCDF4 = "^1.5.7"
-sphinx-gallery = "^0"
-sphinx-design = "^0"
-nbsphinx = "^0"
-sphinx-copybutton = "^0"
-pydata-sphinx-theme = "^0"
+black = "~23.7.0"
-
-[tool.poetry.group.dev.dependencies]
-black = "~23.7.0"
+[tool.poetry.group.docs.dependencies]
+rpy2 = {version = ">=3.5", optional = true}
+sphinx-gallery = {version = "^0", optional = true}
+sphinx-design = {version = "^0", optional = true}
+sphinx-copybutton = {version = "^0", optional = true}
+nbsphinx = {version = "^0", optional = true}
+pydata-sphinx-theme = {version = "^0", optional = true}
 
 [build-system]
 requires = ["setuptools", "poetry-core>=1.0.0"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 6bdac1b..7927083 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,11 +4,222 @@ import pytest
 import warnings
 
 import xarray as xr
+import pandas as pd
+
+from xeofs.utils.data_types import DataArray, DataSet, DataList
 
 warnings.filterwarnings("ignore", message="numpy.dtype size changed")
 warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
+
+
+# =============================================================================
+# Synthetic data
+# =============================================================================
+def generate_synthetic_dataarray(
+    n_sample=1,
+    n_feature=1,
+    index_policy="index",
+    nan_policy="no_nan",
+    dask_policy="no_dask",
+    seed=0,
+) -> DataArray:
+    """Create synthetic DataArray.
+
+    Parameters:
+    ------------
+    n_sample: int
+        Number of sample dimensions.
+    n_feature: int
+        Number of feature dimensions.
+    index_policy: ["index", "multiindex"], default="index"
+        If "multiindex", the data will have a multiindex.
+    nan_policy: ["no_nan", "isolated", "fulldim"], default="no_nan"
+        If specified, the data will contain NaNs.
+    dask_policy: ["no_dask", "dask"], default="no_dask"
+        If "dask", the data will be a dask array.
+    seed: int, default=0
+        Seed for the random number generator.
+
+    Returns:
+    ---------
+    data: xr.DataArray
+        Synthetic data.
+ + """ + rng = np.random.default_rng(seed) + + # Create dimensions + sample_dims = [f"sample{i}" for i in range(n_sample)] + feature_dims = [f"feature{i}" for i in range(n_feature)] + all_dims = feature_dims + sample_dims + + # Create coordinates/indices + coords = {} + for i, dim in enumerate(all_dims): + if index_policy == "multiindex": + coords[dim] = pd.MultiIndex.from_arrays( + [np.arange(6 - i), np.arange(6 - i)], + names=[f"index{i}a", f"index{i}b"], + ) + elif index_policy == "index": + coords[dim] = np.arange(6 + i) + else: + raise ValueError(f"Invalid value for index_policy: {index_policy}") + + # Get data shape + shape = tuple([len(coords[dim]) for dim in all_dims]) + + # Create data + noise = rng.normal(5, 3, size=shape) + signal = 2 * np.sin(np.linspace(0, 2 * np.pi, shape[-1])) + signal = np.broadcast_to(signal, shape) + data = signal + noise + data = xr.DataArray(data, dims=all_dims, coords=coords) + + # Add NaNs + if nan_policy == "no_nan": + pass + elif nan_policy == "isolated": + isolated_point = {dim: 0 for dim in all_dims} + data.loc[isolated_point] = np.nan + elif nan_policy == "fulldim": + fulldim_point = {dim: 0 for dim in feature_dims} + data.loc[fulldim_point] = np.nan + else: + raise ValueError(f"Invalid value for nan_policy: {nan_policy}") + + # Convert to dask array + if dask_policy == "no_dask": + pass + elif dask_policy == "dask": + data = data.chunk({"sample0": 1}) + else: + raise ValueError(f"Invalid value for dask_policy: {dask_policy}") + + return data + + +def generate_synthetic_dataset( + n_variables=1, + n_sample=1, + n_feature=1, + index_policy="index", + nan_policy="no_nan", + dask_policy="no_dask", + seed=0, +) -> DataSet: + """Create synthetic Dataset. + + Parameters: + ------------ + n_variables: int + Number of variables. + n_sample: int + Number of sample dimensions. + n_dims_feature: int + Number of feature dimensions. + index_policy: ["index", "multiindex"], default="index" + If "multiindex", the data will have a multiindex. + nan_policy: ["no_nan", "isolated", "fulldim"], default="no_nan" + If specified, the data will contain NaNs. + dask_policy: ["no_dask", "dask"], default="no_dask" + If "dask", the data will be a dask array. + seed: int, default=0 + Seed for the random number generator. + + Returns: + --------- + data: xr.Dataset + Synthetic data. + + """ + data = generate_synthetic_dataarray( + n_sample, n_feature, index_policy, nan_policy, dask_policy, seed + ) + dataset = xr.Dataset({"var0": data}) + seed += 1 + + for n in range(1, n_variables): + data_n = generate_synthetic_dataarray( + n_sample=n_sample, + n_feature=n_feature, + index_policy=index_policy, + nan_policy=nan_policy, + dask_policy=dask_policy, + seed=seed, + ) + dataset[f"var{n}"] = data_n + seed += 1 + return dataset + + +def generate_list_of_synthetic_dataarrays( + n_arrays=1, + n_sample=1, + n_feature=1, + index_policy="index", + nan_policy="no_nan", + dask_policy="no_dask", + seed=0, +) -> DataList: + """Create synthetic Dataset. + + Parameters: + ------------ + n_arrays: int + Number of DataArrays. + n_sample: int + Number of sample dimensions. + n_dims_feature: int + Number of feature dimensions. + index_policy: ["index", "multiindex"], default="index" + If "multiindex", the data will have a multiindex. + nan_policy: ["no_nan", "isolated", "fulldim"], default="no_nan" + If specified, the data will contain NaNs. + dask_policy: ["no_dask", "dask"], default="no_dask" + If "dask", the data will be a dask array. 
+ seed: int, default=0 + Seed for the random number generator. + + Returns: + --------- + data: xr.Dataset + Synthetic data. + + """ + data_arrays = [] + for n in range(n_arrays): + data_n = generate_synthetic_dataarray( + n_sample=n_sample, + n_feature=n_feature, + index_policy=index_policy, + nan_policy=nan_policy, + dask_policy=dask_policy, + seed=seed, + ) + data_arrays.append(data_n) + seed += 1 + return data_arrays + + +@pytest.fixture +def synthetic_dataarray(request) -> DataArray: + data = generate_synthetic_dataarray(*request.param) + return data + + +@pytest.fixture +def synthetic_dataset(request) -> DataSet: + data = generate_synthetic_dataset(*request.param) + return data + + +@pytest.fixture +def synthetic_datalist(request) -> DataList: + data = generate_list_of_synthetic_dataarrays(*request.param) + return data + + # ============================================================================= # Input data # ============================================================================= diff --git a/tests/data_container/test_base_cross_model_data_container.py b/tests/data_container/test_base_cross_model_data_container.py deleted file mode 100644 index 2b9068d..0000000 --- a/tests/data_container/test_base_cross_model_data_container.py +++ /dev/null @@ -1,76 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.data_container._base_cross_model_data_container import ( - _BaseCrossModelDataContainer, -) - - -def test_init(): - """Test the initialization of the BaseCrossModelDataContainer.""" - data_container = _BaseCrossModelDataContainer() - assert data_container._input_data1 is None - assert data_container._input_data2 is None - assert data_container._components1 is None - assert data_container._components2 is None - assert data_container._scores1 is None - assert data_container._scores2 is None - - -def test_set_data(sample_input_data, sample_components, sample_scores): - """Test the set_data() method.""" - data_container = _BaseCrossModelDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - ) - assert data_container._input_data1 is sample_input_data - assert data_container._input_data2 is sample_input_data - assert data_container._components1 is sample_components - assert data_container._components2 is sample_components - assert data_container._scores1 is sample_scores - assert data_container._scores2 is sample_scores - - -def test_no_data(): - """Test the data accessors without data.""" - data_container = _BaseCrossModelDataContainer() - with pytest.raises(ValueError): - data_container.input_data1 - with pytest.raises(ValueError): - data_container.input_data2 - with pytest.raises(ValueError): - data_container.components1 - with pytest.raises(ValueError): - data_container.components2 - with pytest.raises(ValueError): - data_container.scores1 - with pytest.raises(ValueError): - data_container.scores2 - with pytest.raises(ValueError): - data_container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - data_container.compute() - - -def test_set_attrs(sample_input_data, sample_components, sample_scores): - """Test the set_attrs() method.""" - data_container = _BaseCrossModelDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - ) - data_container.set_attrs({"test": 1}) - assert data_container.components1.attrs["test"] == 1 - assert 
data_container.components2.attrs["test"] == 1 - assert data_container.scores1.attrs["test"] == 1 - assert data_container.scores2.attrs["test"] == 1 diff --git a/tests/data_container/test_base_model_data_container.py b/tests/data_container/test_base_model_data_container.py deleted file mode 100644 index 50ee347..0000000 --- a/tests/data_container/test_base_model_data_container.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.data_container._base_model_data_container import _BaseModelDataContainer - - -def test_init(): - """Test the initialization of the BaseModelDataContainer.""" - data_container = _BaseModelDataContainer() - assert data_container._input_data is None - assert data_container._components is None - assert data_container._scores is None - - -def test_set_data(sample_input_data, sample_components, sample_scores): - """Test the set_data() method.""" - data_container = _BaseModelDataContainer() - data_container.set_data(sample_input_data, sample_components, sample_scores) - assert data_container._input_data is sample_input_data - assert data_container._components is sample_components - assert data_container._scores is sample_scores - - -def test_no_data(): - """Test the data accessors without data.""" - data_container = _BaseModelDataContainer() - with pytest.raises(ValueError): - data_container.input_data - with pytest.raises(ValueError): - data_container.components - with pytest.raises(ValueError): - data_container.scores - with pytest.raises(ValueError): - data_container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - data_container.compute() - - -def test_set_attrs(sample_input_data, sample_components, sample_scores): - """Test the set_attrs() method.""" - data_container = _BaseModelDataContainer() - data_container.set_data(sample_input_data, sample_components, sample_scores) - data_container.set_attrs({"test": 1}) - assert data_container.components.attrs["test"] == 1 - assert data_container.scores.attrs["test"] == 1 diff --git a/tests/data_container/test_complex_eof_data_container.py b/tests/data_container/test_complex_eof_data_container.py deleted file mode 100644 index 612436d..0000000 --- a/tests/data_container/test_complex_eof_data_container.py +++ /dev/null @@ -1,131 +0,0 @@ -import pytest -import numpy as np -import xarray as xr -from dask.array import Array as DaskArray # type: ignore - - -from xeofs.data_container.eof_data_container import ComplexEOFDataContainer - - -def test_init(): - """Test the initialization of the ComplexEOFDataContainer.""" - container = ComplexEOFDataContainer() - assert container._input_data is None - assert container._components is None - assert container._scores is None - assert container._explained_variance is None - assert container._total_variance is None - assert container._idx_modes_sorted is None - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - sample_total_variance, - sample_idx_modes_sorted, -): - """Test the set_data() method.""" - - container = ComplexEOFDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - sample_total_variance, - sample_idx_modes_sorted, - ) - total_variance = sample_exp_var.sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container.set_data( - input_data=sample_input_data, - components=sample_components, - scores=sample_scores, - explained_variance=sample_exp_var, - total_variance=total_variance, - 
idx_modes_sorted=idx_modes_sorted, - ) - assert container._input_data is sample_input_data - assert container._components is sample_components - assert container._scores is sample_scores - assert container._explained_variance is sample_exp_var - assert container._total_variance is total_variance - assert container._idx_modes_sorted is idx_modes_sorted - - -def test_no_data(): - """Test the data accessors without data.""" - container = ComplexEOFDataContainer() - with pytest.raises(ValueError): - container.input_data - with pytest.raises(ValueError): - container.components - with pytest.raises(ValueError): - container.scores - with pytest.raises(ValueError): - container.explained_variance - with pytest.raises(ValueError): - container.total_variance - with pytest.raises(ValueError): - container.idx_modes_sorted - with pytest.raises(ValueError): - container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - container.compute() - - -def test_set_attrs(sample_input_data, sample_components, sample_scores, sample_exp_var): - """Test the set_attrs() method.""" - total_variance = sample_exp_var.chunk({"mode": 2}).sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = ComplexEOFDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - total_variance, - idx_modes_sorted, - ) - container.set_attrs({"test": 1}) - assert container.components.attrs["test"] == 1 - assert container.scores.attrs["test"] == 1 - assert container.explained_variance.attrs["test"] == 1 - assert container.explained_variance_ratio.attrs["test"] == 1 - assert container.singular_values.attrs["test"] == 1 - assert container.total_variance.attrs["test"] == 1 - assert container.idx_modes_sorted.attrs["test"] == 1 - - -def test_compute(sample_input_data, sample_components, sample_scores, sample_exp_var): - """Check that dask arrays are computed correctly.""" - total_variance = sample_exp_var.chunk({"mode": 2}).sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = ComplexEOFDataContainer() - container.set_data( - sample_input_data.chunk({"sample": 2}), - sample_components.chunk({"feature": 2}), - sample_scores.chunk({"sample": 2}), - sample_exp_var.chunk({"mode": 2}), - total_variance, - idx_modes_sorted, - ) - # The components and scores are dask arrays - assert isinstance(container.input_data.data, DaskArray) - assert isinstance(container.components.data, DaskArray) - assert isinstance(container.scores.data, DaskArray) - assert isinstance(container.explained_variance.data, DaskArray) - assert isinstance(container.total_variance.data, DaskArray) - - container.compute() - - # The components and scores are computed correctly - assert isinstance( - container.input_data.data, DaskArray - ), "input_data should still be a dask array" - assert isinstance(container.components.data, np.ndarray) - assert isinstance(container.scores.data, np.ndarray) - assert isinstance(container.explained_variance.data, np.ndarray) - assert isinstance(container.total_variance.data, np.ndarray) diff --git a/tests/data_container/test_complex_eof_rotator_data_container.py b/tests/data_container/test_complex_eof_rotator_data_container.py deleted file mode 100644 index a32b91e..0000000 --- a/tests/data_container/test_complex_eof_rotator_data_container.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -import numpy as np -import xarray as xr -from dask.array import Array as DaskArray # type: ignore - -from xeofs.data_container.eof_rotator_data_container import ( - 
ComplexEOFRotatorDataContainer, -) - - -def test_init(): - """Test the initialization of the ComplexEOFRotatorDataContainer.""" - container = ComplexEOFRotatorDataContainer() - assert container._rotation_matrix is None - assert container._phi_matrix is None - assert container._modes_sign is None - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_data() method of ComplexEOFRotatorDataContainer.""" - total_variance = sample_exp_var.sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = ComplexEOFRotatorDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - total_variance, - idx_modes_sorted, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, - ) - assert container._input_data is sample_input_data - assert container._components is sample_components - assert container._scores is sample_scores - assert container._explained_variance is sample_exp_var - assert container._total_variance is total_variance - assert container._idx_modes_sorted is idx_modes_sorted - assert container._modes_sign is sample_modes_sign - assert container._rotation_matrix is sample_rotation_matrix - assert container._phi_matrix is sample_phi_matrix diff --git a/tests/data_container/test_complex_mca_rotator_data_container.py b/tests/data_container/test_complex_mca_rotator_data_container.py deleted file mode 100644 index 93f25a0..0000000 --- a/tests/data_container/test_complex_mca_rotator_data_container.py +++ /dev/null @@ -1,116 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.data_container.mca_rotator_data_container import ( - ComplexMCARotatorDataContainer, -) -from .test_mca_rotator_data_container import test_init as test_mca_rotator_init -from .test_mca_rotator_data_container import test_set_data as test_mca_rotator_set_data -from .test_mca_rotator_data_container import test_no_data as test_mca_rotator_no_data -from .test_mca_rotator_data_container import ( - test_set_attrs as test_mca_rotator_set_attrs, -) - - -def test_init(): - """Test the initialization of the ComplexMCARotatorDataContainer.""" - data_container = ComplexMCARotatorDataContainer() - test_mca_rotator_init() # Re-use the test from MCARotatorDataContainer. - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_data() method of ComplexMCARotatorDataContainer.""" - data_container = ComplexMCARotatorDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_modes_sign, - sample_norm, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - ) - - test_mca_rotator_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, - ) # Re-use the test from MCARotatorDataContainer. 
- - -def test_no_data(): - """Test the data accessors without data in ComplexMCARotatorDataContainer.""" - data_container = ComplexMCARotatorDataContainer() - test_mca_rotator_no_data() # Re-use the test from MCARotatorDataContainer. - - -def test_set_attrs( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_attrs() method of ComplexMCARotatorDataContainer.""" - data_container = ComplexMCARotatorDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_modes_sign, - sample_norm, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - ) - data_container.set_attrs({"test": 1}) - - test_mca_rotator_set_attrs( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, - ) # Re-use the test from MCARotatorDataContainer. diff --git a/tests/data_container/test_eof_data_container.py b/tests/data_container/test_eof_data_container.py deleted file mode 100644 index ca8fd46..0000000 --- a/tests/data_container/test_eof_data_container.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytest -import numpy as np -import xarray as xr -from dask.array import Array as DaskArray # type: ignore - - -from xeofs.data_container.eof_data_container import EOFDataContainer - - -def test_init(): - """Test the initialization of the EOFDataContainer.""" - container = EOFDataContainer() - assert container._input_data is None - assert container._components is None - assert container._scores is None - assert container._explained_variance is None - assert container._total_variance is None - assert container._idx_modes_sorted is None - - -def test_set_data(sample_input_data, sample_components, sample_scores, sample_exp_var): - """Test the set_data() method.""" - total_variance = sample_exp_var.sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = EOFDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - total_variance, - idx_modes_sorted, - ) - total_variance = sample_exp_var.sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container.set_data( - input_data=sample_input_data, - components=sample_components, - scores=sample_scores, - explained_variance=sample_exp_var, - total_variance=total_variance, - idx_modes_sorted=idx_modes_sorted, - ) - assert container._input_data is sample_input_data - assert container._components is sample_components - assert container._scores is sample_scores - assert container._explained_variance is sample_exp_var - assert container._total_variance is total_variance - assert container._idx_modes_sorted is idx_modes_sorted - - -def test_no_data(): - """Test the data accessors without data.""" - container = EOFDataContainer() - with pytest.raises(ValueError): - container.input_data - with pytest.raises(ValueError): - container.components - with pytest.raises(ValueError): - container.scores - with pytest.raises(ValueError): - container.explained_variance - with pytest.raises(ValueError): - container.total_variance - with 
pytest.raises(ValueError): - container.idx_modes_sorted - with pytest.raises(ValueError): - container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - container.compute() - - -def test_set_attrs(sample_input_data, sample_components, sample_scores, sample_exp_var): - """Test the set_attrs() method.""" - total_variance = sample_exp_var.chunk({"mode": 2}).sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = EOFDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - total_variance, - idx_modes_sorted, - ) - container.set_attrs({"test": 1}) - assert container.components.attrs["test"] == 1 - assert container.scores.attrs["test"] == 1 - assert container.explained_variance.attrs["test"] == 1 - assert container.explained_variance_ratio.attrs["test"] == 1 - assert container.singular_values.attrs["test"] == 1 - assert container.total_variance.attrs["test"] == 1 - assert container.idx_modes_sorted.attrs["test"] == 1 - - -def test_compute(sample_input_data, sample_components, sample_scores, sample_exp_var): - """Check that dask arrays are computed correctly.""" - total_variance = sample_exp_var.chunk({"mode": 2}).sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = EOFDataContainer() - container.set_data( - sample_input_data.chunk({"sample": 2}), - sample_components.chunk({"feature": 2}), - sample_scores.chunk({"sample": 2}), - sample_exp_var.chunk({"mode": 2}), - total_variance, - idx_modes_sorted, - ) - # The components and scores are dask arrays - assert isinstance(container.input_data.data, DaskArray) - assert isinstance(container.components.data, DaskArray) - assert isinstance(container.scores.data, DaskArray) - assert isinstance(container.explained_variance.data, DaskArray) - assert isinstance(container.total_variance.data, DaskArray) - - container.compute() - - # The components and scores are computed correctly - assert isinstance( - container.input_data.data, DaskArray - ), "input_data should still be a dask array" - assert isinstance(container.components.data, np.ndarray) - assert isinstance(container.scores.data, np.ndarray) - assert isinstance(container.explained_variance.data, np.ndarray) - assert isinstance(container.total_variance.data, np.ndarray) diff --git a/tests/data_container/test_eof_rotator_data_container.py b/tests/data_container/test_eof_rotator_data_container.py deleted file mode 100644 index 5e55b3a..0000000 --- a/tests/data_container/test_eof_rotator_data_container.py +++ /dev/null @@ -1,76 +0,0 @@ -import pytest -import numpy as np -import xarray as xr -from dask.array import Array as DaskArray # type: ignore - -from xeofs.data_container.eof_rotator_data_container import EOFRotatorDataContainer - - -def test_init(): - """Test the initialization of the EOFRotatorDataContainer.""" - container = EOFRotatorDataContainer() - assert container._rotation_matrix is None - assert container._phi_matrix is None - assert container._modes_sign is None - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_data() method of EOFRotatorDataContainer.""" - total_variance = sample_exp_var.sum() - idx_modes_sorted = sample_exp_var.argsort()[::-1] - container = EOFRotatorDataContainer() - container.set_data( - sample_input_data, - sample_components, - sample_scores, - sample_exp_var, - total_variance, - idx_modes_sorted, - sample_modes_sign, - 
sample_rotation_matrix, - sample_phi_matrix, - ) - assert container._input_data is sample_input_data - assert container._components is sample_components - assert container._scores is sample_scores - assert container._explained_variance is sample_exp_var - assert container._total_variance is total_variance - assert container._idx_modes_sorted is idx_modes_sorted - assert container._modes_sign is sample_modes_sign - assert container._rotation_matrix is sample_rotation_matrix - assert container._phi_matrix is sample_phi_matrix - - -def test_no_data(): - """Test the data accessors without data for EOFRotatorDataContainer.""" - container = EOFRotatorDataContainer() - with pytest.raises(ValueError): - container.input_data - with pytest.raises(ValueError): - container.components - with pytest.raises(ValueError): - container.scores - with pytest.raises(ValueError): - container.explained_variance - with pytest.raises(ValueError): - container.total_variance - with pytest.raises(ValueError): - container.idx_modes_sorted - with pytest.raises(ValueError): - container.modes_sign - with pytest.raises(ValueError): - container.rotation_matrix - with pytest.raises(ValueError): - container.phi_matrix - with pytest.raises(ValueError): - container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - container.compute() diff --git a/tests/data_container/test_mca_data_container.py b/tests/data_container/test_mca_data_container.py deleted file mode 100644 index a0834b0..0000000 --- a/tests/data_container/test_mca_data_container.py +++ /dev/null @@ -1,136 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.data_container.mca_data_container import MCADataContainer - - -def test_init(): - """Test the initialization of the MCADataContainer.""" - data_container = MCADataContainer() - assert data_container._input_data1 is None - assert data_container._input_data2 is None - assert data_container._components1 is None - assert data_container._components2 is None - assert data_container._scores1 is None - assert data_container._scores2 is None - assert data_container._squared_covariance is None - assert data_container._total_squared_covariance is None - assert data_container._idx_modes_sorted is None - assert data_container._norm1 is None - assert data_container._norm2 is None - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, -): - """Test the set_data() method of MCADataContainer.""" - data_container = MCADataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_norm, - ) - - assert data_container._input_data1 is sample_input_data - assert data_container._input_data2 is sample_input_data - assert data_container._components1 is sample_components - assert data_container._components2 is sample_components - assert data_container._scores1 is sample_scores - assert data_container._scores2 is sample_scores - assert data_container._squared_covariance is sample_squared_covariance - assert data_container._total_squared_covariance is sample_total_squared_covariance - assert data_container._idx_modes_sorted is sample_idx_modes_sorted - assert data_container._norm1 is sample_norm - assert data_container._norm2 is sample_norm - - -def test_no_data(): - 
"""Test the data accessors without data in MCADataContainer.""" - data_container = MCADataContainer() - with pytest.raises(ValueError): - data_container.input_data1 - with pytest.raises(ValueError): - data_container.input_data2 - with pytest.raises(ValueError): - data_container.components1 - with pytest.raises(ValueError): - data_container.components2 - with pytest.raises(ValueError): - data_container.scores1 - with pytest.raises(ValueError): - data_container.scores2 - with pytest.raises(ValueError): - data_container.squared_covariance - with pytest.raises(ValueError): - data_container.total_squared_covariance - with pytest.raises(ValueError): - data_container.squared_covariance_fraction - with pytest.raises(ValueError): - data_container.singular_values - with pytest.raises(ValueError): - data_container.covariance_fraction - with pytest.raises(ValueError): - data_container.idx_modes_sorted - with pytest.raises(ValueError): - data_container.norm1 - with pytest.raises(ValueError): - data_container.norm2 - with pytest.raises(ValueError): - data_container.set_attrs({"test": 1}) - with pytest.raises(ValueError): - data_container.compute() - - -def test_set_attrs( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, -): - """Test the set_attrs() method of MCADataContainer.""" - data_container = MCADataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_norm, - ) - data_container.set_attrs({"test": 1}) - - assert data_container.components1.attrs["test"] == 1 - assert data_container.components2.attrs["test"] == 1 - assert data_container.scores1.attrs["test"] == 1 - assert data_container.scores2.attrs["test"] == 1 - assert data_container.squared_covariance.attrs["test"] == 1 - assert data_container.total_squared_covariance.attrs["test"] == 1 - assert data_container.squared_covariance_fraction.attrs["test"] == 1 - assert data_container.singular_values.attrs["test"] == 1 - assert data_container.total_covariance.attrs["test"] == 1 - assert data_container.covariance_fraction.attrs["test"] == 1 - assert data_container.norm1.attrs["test"] == 1 - assert data_container.norm2.attrs["test"] == 1 diff --git a/tests/data_container/test_mca_rotator_data_container.py b/tests/data_container/test_mca_rotator_data_container.py deleted file mode 100644 index f618e83..0000000 --- a/tests/data_container/test_mca_rotator_data_container.py +++ /dev/null @@ -1,129 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.data_container.mca_rotator_data_container import MCARotatorDataContainer -from .test_mca_data_container import test_init as test_mca_init -from .test_mca_data_container import test_set_data as test_mca_set_data -from .test_mca_data_container import test_no_data as test_mca_no_data -from .test_mca_data_container import test_set_attrs as test_mca_set_attrs - -""" -The idea here is to reuse tests from MCADataContainer in MCARotatorDataContainer -and then tests from MCARotatorDataContainer in ComplexMCARotatorDataContainer, -while also testing the new functionality of each class. This way, we ensure that -inherited behavior still works as expected in subclasses. If some new tests fail, -we'll know it's due to the new functionality and not something inherited. 
-""" - - -def test_init(): - """Test the initialization of the MCARotatorDataContainer.""" - data_container = MCARotatorDataContainer() - test_mca_init() # Re-use the test from MCADataContainer. - assert data_container._rotation_matrix is None - assert data_container._phi_matrix is None - assert data_container._modes_sign is None - - -def test_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_data() method of MCARotatorDataContainer.""" - data_container = MCARotatorDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_modes_sign, - sample_norm, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - ) - - test_mca_set_data( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - ) # Re-use the test from MCADataContainer. - assert data_container._rotation_matrix is sample_rotation_matrix - assert data_container._phi_matrix is sample_phi_matrix - assert data_container._modes_sign is sample_modes_sign - - -def test_no_data(): - """Test the data accessors without data in MCARotatorDataContainer.""" - data_container = MCARotatorDataContainer() - test_mca_no_data() # Re-use the test from MCADataContainer. - with pytest.raises(ValueError): - data_container.rotation_matrix - with pytest.raises(ValueError): - data_container.phi_matrix - with pytest.raises(ValueError): - data_container.modes_sign - - -def test_set_attrs( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - sample_modes_sign, -): - """Test the set_attrs() method of MCARotatorDataContainer.""" - data_container = MCARotatorDataContainer() - data_container.set_data( - sample_input_data, - sample_input_data, - sample_components, - sample_components, - sample_scores, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_modes_sign, - sample_norm, - sample_norm, - sample_rotation_matrix, - sample_phi_matrix, - ) - data_container.set_attrs({"test": 1}) - - test_mca_set_attrs( - sample_input_data, - sample_components, - sample_scores, - sample_squared_covariance, - sample_total_squared_covariance, - sample_idx_modes_sorted, - sample_norm, - ) # Re-use the test from MCADataContainer. 
- assert data_container.rotation_matrix.attrs["test"] == 1 - assert data_container.phi_matrix.attrs["test"] == 1 - assert data_container.modes_sign.attrs["test"] == 1 diff --git a/tests/models/test_cca.py b/tests/models/test_cca.py new file mode 100644 index 0000000..620e0a0 --- /dev/null +++ b/tests/models/test_cca.py @@ -0,0 +1,63 @@ +import numpy as np +import xarray as xr +import pytest +import dask.array as da +from numpy.testing import assert_allclose +from ..conftest import generate_list_of_synthetic_dataarrays + +from xeofs.models.cca import CCA + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_fit(dim, mock_data_array_list): + """Tests the fit method of the CCA class""" + + cca = CCA() + cca.fit(mock_data_array_list, dim) + + # Assert the required attributes have been set + assert hasattr(cca, "preprocessors") + assert hasattr(cca, "data") + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components(dim, mock_data_array_list): + """Tests the components method of the CCA class""" + + cca = CCA() + cca.fit(mock_data_array_list, dim) + + comps = cca.components() + assert isinstance(comps, list) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores(dim, mock_data_array_list): + """Tests the components method of the CCA class""" + + cca = CCA() + cca.fit(mock_data_array_list, dim) + + scores = cca.scores() + assert isinstance(scores, list) diff --git a/tests/models/test_complex_eof_rotator.py b/tests/models/test_complex_eof_rotator.py index b711788..19ac1c5 100644 --- a/tests/models/test_complex_eof_rotator.py +++ b/tests/models/test_complex_eof_rotator.py @@ -4,9 +4,7 @@ from dask.array import Array as DaskArray # type: ignore from xeofs.models import ComplexEOF, ComplexEOFRotator -from xeofs.data_container.eof_rotator_data_container import ( - ComplexEOFRotatorDataContainer, -) +from xeofs.data_container import DataContainer @pytest.fixture @@ -52,7 +50,7 @@ def test_fit(ceof_model): ceof_rotator, "data" ), 'The attribute "data" should be populated after fitting.' 
assert type(ceof_rotator.model) == ComplexEOF - assert type(ceof_rotator.data) == ComplexEOFRotatorDataContainer + assert type(ceof_rotator.data) == DataContainer @pytest.mark.parametrize( diff --git a/tests/models/test_complex_mca.py b/tests/models/test_complex_mca.py index 3d0046f..bfe348b 100644 --- a/tests/models/test_complex_mca.py +++ b/tests/models/test_complex_mca.py @@ -269,7 +269,7 @@ def test_fit_with_dataset(mca_model, mock_dataset, dim): (("lon", "lat")), ], ) -def test_fit_with_dataarraylist(mca_model, mock_data_array_list, dim): +def test_fit_with_datalist(mca_model, mock_data_array_list, dim): mca_model.fit(mock_data_array_list, mock_data_array_list, dim) assert hasattr(mca_model, "preprocessor1") assert hasattr(mca_model, "preprocessor2") diff --git a/tests/models/test_decomposer.py b/tests/models/test_decomposer.py index 35ad47b..2cdb86a 100644 --- a/tests/models/test_decomposer.py +++ b/tests/models/test_decomposer.py @@ -6,6 +6,7 @@ from scipy.sparse.linalg import svds as complex_svd # type: ignore from dask.array.linalg import svd_compressed as dask_svd from xeofs.models.decomposer import Decomposer +from ..utilities import data_is_dask @pytest.fixture @@ -37,7 +38,7 @@ def test_complex_dask_data_array(mock_complex_data_array): def test_init(decomposer): assert decomposer.n_modes == 2 - assert decomposer.solver_kwargs["random_state"] == 42 + assert decomposer.random_state == 42 def test_fit_full(mock_data_array): @@ -99,25 +100,36 @@ def test_fit_dask_full(mock_dask_data_array): assert decomposer.V_.shape[1] == 2 -def test_fit_dask_randomized(mock_dask_data_array): +@pytest.mark.parametrize("compute", [True, False]) +def test_fit_dask_randomized(mock_dask_data_array, compute): # The Dask SVD solver has no parameter 'random_state' but 'seed' instead, # so let's create a new decomposer for this case - decomposer = Decomposer(n_modes=2, solver="randomized", seed=42) + decomposer = Decomposer(n_modes=2, solver="randomized", compute=compute, seed=42) decomposer.fit(mock_dask_data_array) assert "U_" in decomposer.__dict__ assert "s_" in decomposer.__dict__ assert "V_" in decomposer.__dict__ - # Check if the Dask SVD solver has been used - assert isinstance(decomposer.U_.data, DaskArray) - assert isinstance(decomposer.s_.data, DaskArray) - assert isinstance(decomposer.V_.data, DaskArray) - # Check that indeed 2 modes are returned assert decomposer.U_.shape[1] == 2 assert decomposer.s_.shape[0] == 2 assert decomposer.V_.shape[1] == 2 + is_dask_before = data_is_dask(mock_dask_data_array) + U_is_dask_after = data_is_dask(decomposer.U_) + s_is_dask_after = data_is_dask(decomposer.s_) + V_is_dask_after = data_is_dask(decomposer.V_) + # Check if the Dask SVD solver has been used + assert is_dask_before + if compute: + assert not U_is_dask_after + assert not s_is_dask_after + assert not V_is_dask_after + else: + assert U_is_dask_after + assert s_is_dask_after + assert V_is_dask_after + def test_fit_complex(mock_complex_data_array): decomposer = Decomposer(n_modes=2, solver="randomized", random_state=42) @@ -134,3 +146,34 @@ def test_fit_complex(mock_complex_data_array): # Check that U and V are complex assert np.iscomplexobj(decomposer.U_.data) assert np.iscomplexobj(decomposer.V_.data) + + +@pytest.mark.parametrize( + "data", + ["real", "complex", "dask_real"], +) +def test_random_state( + data, mock_data_array, mock_complex_data_array, mock_dask_data_array +): + match data: + case "real": + X = mock_data_array + case "complex": + X = mock_complex_data_array + case "dask_real": + X 
= mock_dask_data_array + case _: + raise ValueError(f"Unrecognized data type '{data}'.") + + decomposer = Decomposer( + n_modes=2, solver="randomized", random_state=42, compute=True + ) + decomposer.fit(X) + U1 = decomposer.U_.data + + # Refit + decomposer.fit(X) + U2 = decomposer.U_.data + + # Check that the results are the same + assert np.alltrue(U1 == U2) diff --git a/tests/models/test_eeof.py b/tests/models/test_eeof.py new file mode 100644 index 0000000..b9b2b98 --- /dev/null +++ b/tests/models/test_eeof.py @@ -0,0 +1,439 @@ +import numpy as np +import xarray as xr +import pytest +import dask.array as da +from numpy.testing import assert_allclose + +from xeofs.models.eeof import ExtendedEOF + + +def test_init(): + """Tests the initialization of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + + # Assert preprocessor has been initialized + assert hasattr(eof, "_params") + assert hasattr(eof, "preprocessor") + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_fit(dim, mock_data_array): + """Tests the fit method of the ExtendedEOF class""" + + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Assert the required attributes have been set + assert hasattr(eof, "preprocessor") + assert hasattr(eof, "data") + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_singular_values(dim, mock_data_array): + """Tests the singular_values method of the ExtendedEOF class""" + + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Test singular_values method + singular_values = eof.singular_values() + assert isinstance(singular_values, xr.DataArray) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_explained_variance(dim, mock_data_array): + """Tests the explained_variance method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Test explained_variance method + explained_variance = eof.explained_variance() + assert isinstance(explained_variance, xr.DataArray) + # Explained variance must be positive + assert (explained_variance > 0).all() + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_explained_variance_ratio(dim, mock_data_array): + """Tests the explained_variance_ratio method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Test explained_variance_ratio method + explained_variance_ratio = eof.explained_variance_ratio() + assert isinstance(explained_variance_ratio, xr.DataArray) + # Explained variance ratio must be positive + assert ( + explained_variance_ratio > 0 + ).all(), "The explained variance ratio must be positive" + # The sum of the explained variance ratio must be <= 1 + assert ( + explained_variance_ratio.sum() <= 1 + 1e-5 + ), "The sum of the explained variance ratio must be <= 1" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_isolated_nans(dim, mock_data_array_isolated_nans): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + with pytest.raises(ValueError): + eof.fit(mock_data_array_isolated_nans, dim) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + 
(("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components(dim, mock_data_array): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_data_array.dims) - set(dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + given_dims = set(components.dims) + expected_dims = set(feature_dims + ("mode", "embedding")) + assert ( + given_dims == expected_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_fulldim_nans(dim, mock_data_array_full_dimensional_nans): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_full_dimensional_nans, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_data_array_full_dimensional_nans.dims) - set(dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + given_dims = set(components.dims) + expected_dims = set(feature_dims + ("mode", "embedding")) + assert ( + given_dims == expected_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_boundary_nans(dim, mock_data_array_boundary_nans): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_boundary_nans, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_data_array_boundary_nans.dims) - set(dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + given_dims = set(components.dims) + expected_dims = set(feature_dims + ("mode", "embedding")) + assert ( + given_dims == expected_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_dataset(dim, mock_dataset): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_dataset, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_dataset.dims) - set(dim)) + assert isinstance(components, xr.Dataset), "Components is not a Dataset" + assert set(components.data_vars) == set( + mock_dataset.data_vars + ), "Components does not have the same data variables as the input Dataset" + given_dims = set(components.dims) + expected_dims = set(feature_dims + ("mode", "embedding")) + assert ( + given_dims == expected_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_dataarray_list(dim, mock_data_array_list): + """Tests the components method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_list, dim) + + # Test components method + components = eof.components() + feature_dims = [tuple(set(data.dims) - set(dim)) for data in mock_data_array_list] + assert isinstance(components, list), "Components is not a list" + assert len(components) == len( 
+ mock_data_array_list + ), "Components does not have the same length as the input list" + assert isinstance( + components[0], xr.DataArray + ), "Components is not a list of DataArrays" + for comp, feat_dims in zip(components, feature_dims): + given_dims = set(comp.dims) + expected_dims = set(feat_dims + ("mode", "embedding")) + assert ( + given_dims == expected_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores(dim, mock_data_array): + """Tests the scores method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_fulldim_nans(dim, mock_data_array_full_dimensional_nans): + """Tests the scores method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_full_dimensional_nans, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_boundary_nans(dim, mock_data_array_boundary_nans): + """Tests the scores method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_boundary_nans, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_dataset(dim, mock_dataset): + """Tests the scores method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_dataset, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray) + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_dataarray_list(dim, mock_data_array_list): + """Tests the scores method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + eof.fit(mock_data_array_list, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray) + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +def test_get_params(): + """Tests the get_params method of the ExtendedEOF class""" + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + + # Test get_params method + params = eof.get_params() + assert isinstance(params, dict) + assert params.get("n_modes") == 5 + assert params.get("tau") == 2 + assert params.get("embedding") == 2 + assert params.get("solver") == "auto" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def 
test_transform(dim, mock_data_array): + """Test projecting new unseen data onto the components (EOFs/eigenvectors)""" + + # Create a xarray DataArray with random data + model = ExtendedEOF(n_modes=5, tau=2, embedding=2, solver="full") + model.fit(mock_data_array, dim) + scores = model.scores() + + # Create a new xarray DataArray with random data + new_data = mock_data_array + + with pytest.raises(NotImplementedError): + projections = model.transform(new_data) + + # # Check that the projection has the right dimensions + # assert projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore + + # # Check that the projection has the right data type + # assert isinstance(projections, xr.DataArray), "Projection is not a DataArray" + + # # Check that the projection has the right name + # assert projections.name == "scores", "Projection has wrong name: {}".format( + # projections.name + # ) + + # # Check that the projection's data is the same as the scores + # np.testing.assert_allclose( + # scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3 + # ) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_inverse_transform(dim, mock_data_array): + """Test inverse_transform method in ExtendedEOF class.""" + + # instantiate the ExtendedEOF class with necessary parameters + eof = ExtendedEOF(n_modes=5, tau=2, embedding=2) + + # fit the ExtendedEOF model + eof.fit(mock_data_array, dim=dim) + + # Test with scalar + mode = 1 + with pytest.raises(NotImplementedError): + reconstructed_data = eof.inverse_transform(mode) + # assert isinstance(reconstructed_data, xr.DataArray) + + # # Test with slice + # mode = slice(1, 2) + # reconstructed_data = eof.inverse_transform(mode) + # assert isinstance(reconstructed_data, xr.DataArray) + + # # Test with array of tick labels + # mode = np.array([1, 3]) + # reconstructed_data = eof.inverse_transform(mode) + # assert isinstance(reconstructed_data, xr.DataArray) + + # # Check that the reconstructed data has the same dimensions as the original data + # assert set(reconstructed_data.dims) == set(mock_data_array.dims) diff --git a/tests/models/test_eof.py b/tests/models/test_eof.py index c822e4e..16d4ef2 100644 --- a/tests/models/test_eof.py +++ b/tests/models/test_eof.py @@ -11,16 +11,8 @@ def test_init(): """Tests the initialization of the EOF class""" eof = EOF(n_modes=5, standardize=True, use_coslat=True) - # Assert parameters are correctly stored in the _params attribute - assert eof._params == { - "n_modes": 5, - "standardize": True, - "use_coslat": True, - "use_weights": False, - "solver": "auto", - } - # Assert preprocessor has been initialized + assert hasattr(eof, "_params") assert hasattr(eof, "preprocessor") @@ -108,6 +100,21 @@ def test_explained_variance_ratio(dim, mock_data_array): ), "The sum of the explained variance ratio must be <= 1" +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_isolated_nans(dim, mock_data_array_isolated_nans): + """Tests the components method of the EOF class""" + eof = EOF() + with pytest.raises(ValueError): + eof.fit(mock_data_array_isolated_nans, dim) + + @pytest.mark.parametrize( "dim", [ @@ -130,6 +137,50 @@ def test_components(dim, mock_data_array): ), "Components does not have the right feature dimensions" +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_fulldim_nans(dim, 
mock_data_array_full_dimensional_nans): + """Tests the components method of the EOF class""" + eof = EOF() + eof.fit(mock_data_array_full_dimensional_nans, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_data_array_full_dimensional_nans.dims) - set(dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + assert set(components.dims) == set( + ("mode",) + feature_dims + ), "Components does not have the right feature dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_components_boundary_nans(dim, mock_data_array_boundary_nans): + """Tests the components method of the EOF class""" + eof = EOF() + eof.fit(mock_data_array_boundary_nans, dim) + + # Test components method + components = eof.components() + feature_dims = tuple(set(mock_data_array_boundary_nans.dims) - set(dim)) + assert isinstance(components, xr.DataArray), "Components is not a DataArray" + assert set(components.dims) == set( + ("mode",) + feature_dims + ), "Components does not have the right feature dimensions" + + @pytest.mark.parametrize( "dim", [ @@ -205,6 +256,48 @@ def test_scores(dim, mock_data_array): ), "Scores does not have the right dimensions" +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_fulldim_nans(dim, mock_data_array_full_dimensional_nans): + """Tests the scores method of the EOF class""" + eof = EOF() + eof.fit(mock_data_array_full_dimensional_nans, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_scores_boundary_nans(dim, mock_data_array_boundary_nans): + """Tests the scores method of the EOF class""" + eof = EOF() + eof.fit(mock_data_array_boundary_nans, dim) + + # Test scores method + scores = eof.scores() + assert isinstance(scores, xr.DataArray), "Scores is not a DataArray" + assert set(scores.dims) == set( + (dim + ("mode",)) + ), "Scores does not have the right dimensions" + + @pytest.mark.parametrize( "dim", [ @@ -254,13 +347,10 @@ def test_get_params(): # Test get_params method params = eof.get_params() assert isinstance(params, dict) - assert params == { - "n_modes": 5, - "standardize": True, - "use_coslat": True, - "use_weights": False, - "solver": "auto", - } + assert params.get("n_modes") == 5 + assert params.get("standardize") is True + assert params.get("use_coslat") is True + assert params.get("solver") == "auto" @pytest.mark.parametrize( @@ -291,7 +381,9 @@ def test_transform(dim, mock_data_array): assert isinstance(projections, xr.DataArray), "Projection is not a DataArray" # Check that the projection has the right name - assert projections.name == "scores", "Projection has wrong name" + assert projections.name == "scores", "Projection has wrong name: {}".format( + projections.name + ) # Check that the projection's data is the same as the scores np.testing.assert_allclose( diff --git a/tests/models/test_eof_rotator.py b/tests/models/test_eof_rotator.py index 38374c2..3fb26eb 100644 --- a/tests/models/test_eof_rotator.py +++ b/tests/models/test_eof_rotator.py @@ -4,9 +4,8 @@ from dask.array import Array as DaskArray # type: ignore from xeofs.models import EOF, EOFRotator -from 
xeofs.data_container.eof_rotator_data_container import ( - EOFRotatorDataContainer, -) +from xeofs.data_container import DataContainer +from ..utilities import data_is_dask @pytest.fixture @@ -18,7 +17,7 @@ def eof_model(mock_data_array, dim): @pytest.fixture def eof_model_delayed(mock_dask_data_array, dim): - eof = EOF(n_modes=5) + eof = EOF(n_modes=5, compute=False) eof.fit(mock_dask_data_array, dim) return eof @@ -52,7 +51,7 @@ def test_fit(eof_model): eof_rotator, "data" ), 'The attribute "data" should be populated after fitting.' assert type(eof_rotator.model) == EOF - assert type(eof_rotator.data) == EOFRotatorDataContainer + assert type(eof_rotator.data) == DataContainer @pytest.mark.parametrize( @@ -170,49 +169,26 @@ def test_scores(eof_model): @pytest.mark.parametrize( - "dim", + "dim, compute", [ - (("time",)), - (("lat", "lon")), - (("lon", "lat")), + (("time",), True), + (("lat", "lon"), True), + (("lon", "lat"), True), + (("time",), False), + (("lat", "lon"), False), + (("lon", "lat"), False), ], ) -def test_compute(eof_model_delayed): - eof_rotator = EOFRotator(n_modes=5) +def test_compute(eof_model_delayed, compute): + eof_rotator = EOFRotator(n_modes=5, compute=compute) eof_rotator.fit(eof_model_delayed) - # before computation, the attributes should be dask arrays - assert isinstance( - eof_rotator.data.explained_variance.data, DaskArray - ), "The attribute _explained_variance should be a dask array." - assert isinstance( - eof_rotator.data.explained_variance_ratio.data, DaskArray - ), "The attribute _explained_variance_ratio should be a dask array." - assert isinstance( - eof_rotator.data.components.data, DaskArray - ), "The attribute _components should be a dask array." - assert isinstance( - eof_rotator.data.rotation_matrix.data, DaskArray - ), "The attribute _rotation_matrix should be a dask array." - assert isinstance( - eof_rotator.data.scores.data, DaskArray - ), "The attribute _scores should be a dask array." - - eof_rotator.compute() - - # after computation, the attributes should be numpy ndarrays - assert isinstance( - eof_rotator.data.explained_variance.data, np.ndarray - ), "The attribute _explained_variance should be a numpy ndarray." - assert isinstance( - eof_rotator.data.explained_variance_ratio.data, np.ndarray - ), "The attribute _explained_variance_ratio should be a numpy ndarray." - assert isinstance( - eof_rotator.data.components.data, np.ndarray - ), "The attribute _components should be a numpy ndarray." - assert isinstance( - eof_rotator.data.rotation_matrix.data, np.ndarray - ), "The attribute _rotation_matrix should be a numpy ndarray." - assert isinstance( - eof_rotator.data.scores.data, np.ndarray - ), "The attribute _scores should be a numpy ndarray." 
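The rewritten assertions below rely on the data_is_dask helper imported from ..utilities, whose implementation is not shown in this diff. A minimal sketch of such a helper, assuming it only needs to distinguish lazy dask-backed objects from eagerly computed ones (the actual code in tests/utilities.py may differ):

import xarray as xr
from dask.array import Array as DaskArray


def data_is_dask(data) -> bool:
    # Report whether the underlying array(s) are lazy dask arrays.
    if isinstance(data, xr.DataArray):
        return isinstance(data.data, DaskArray)
    if isinstance(data, xr.Dataset):
        return all(isinstance(var.data, DaskArray) for var in data.data_vars.values())
    # Otherwise treat the input as a list of DataArrays/Datasets.
    return all(data_is_dask(d) for d in data)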
+ if compute: + assert not data_is_dask(eof_rotator.data["explained_variance"]) + assert not data_is_dask(eof_rotator.data["components"]) + assert not data_is_dask(eof_rotator.data["rotation_matrix"]) + + else: + assert data_is_dask(eof_rotator.data["explained_variance"]) + assert data_is_dask(eof_rotator.data["components"]) + assert data_is_dask(eof_rotator.data["rotation_matrix"]) diff --git a/tests/models/test_gwpca.py b/tests/models/test_gwpca.py new file mode 100644 index 0000000..c3e87d5 --- /dev/null +++ b/tests/models/test_gwpca.py @@ -0,0 +1,55 @@ +import pytest +import xeofs as xe + +from ..utilities import assert_expected_dims, data_is_dask, data_has_multiindex + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_ARRAYS = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (na, ns, nf, index, nan, dask) + for na in N_ARRAYS + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "kernel", + [("bisquare"), ("gaussian"), ("exponential")], +) +def test_fit(mock_data_array, kernel): + gwpca = xe.models.GWPCA( + n_modes=2, metric="haversine", kernel=kernel, bandwidth=5000 + ) + gwpca.fit(mock_data_array, dim=("lat", "lon")) + comps = gwpca.components() + llwc = gwpca.largest_locally_weighted_components() + + +@pytest.mark.parametrize( + "metric, kernel, bandwidth", + [ + ("haversine", "invalid_kernel", 5000), + ("invalid_metric", "gaussian", 5000), + ("haversine", "exponential", 0), + ], +) +def test_fit_invalid(mock_data_array, metric, kernel, bandwidth): + with pytest.raises(ValueError): + gwpca = xe.models.GWPCA( + n_modes=2, metric=metric, kernel=kernel, bandwidth=bandwidth + ) diff --git a/tests/models/test_mca.py b/tests/models/test_mca.py index 4bf3ca8..afa4c51 100644 --- a/tests/models/test_mca.py +++ b/tests/models/test_mca.py @@ -5,6 +5,7 @@ from numpy.testing import assert_allclose from xeofs.models.mca import MCA +from ..utilities import data_is_dask @pytest.fixture @@ -357,26 +358,26 @@ def test_heterogeneous_patterns(mca_model, mock_data_array, dim): @pytest.mark.parametrize( - "dim", + "dim, compute", [ - (("time",)), - (("lat", "lon")), - (("lon", "lat")), + (("time",), True), + (("lat", "lon"), True), + (("lon", "lat"), True), + (("time",), False), + (("lat", "lon"), False), + (("lon", "lat"), False), ], ) -def test_compute(mca_model, mock_dask_data_array, dim): +def test_compute(mock_dask_data_array, dim, compute): + mca_model = MCA(n_modes=10, compute=compute) mca_model.fit(mock_dask_data_array, mock_dask_data_array, (dim)) - assert isinstance(mca_model.data.squared_covariance.data, DaskArray) - assert isinstance(mca_model.data.components1.data, DaskArray) - assert isinstance(mca_model.data.components2.data, DaskArray) - assert isinstance(mca_model.data.scores1.data, DaskArray) - assert isinstance(mca_model.data.scores2.data, DaskArray) - - mca_model.compute() + if compute: + assert not data_is_dask(mca_model.data["squared_covariance"]) + assert not data_is_dask(mca_model.data["components1"]) + assert not data_is_dask(mca_model.data["components2"]) - assert 
isinstance(mca_model.data.squared_covariance.data, np.ndarray) - assert isinstance(mca_model.data.components1.data, np.ndarray) - assert isinstance(mca_model.data.components2.data, np.ndarray) - assert isinstance(mca_model.data.scores1.data, np.ndarray) - assert isinstance(mca_model.data.scores2.data, np.ndarray) + else: + assert data_is_dask(mca_model.data["squared_covariance"]) + assert data_is_dask(mca_model.data["components1"]) + assert data_is_dask(mca_model.data["components2"]) diff --git a/tests/models/test_mca_rotator.py b/tests/models/test_mca_rotator.py index 1e1c4df..0cd9c59 100644 --- a/tests/models/test_mca_rotator.py +++ b/tests/models/test_mca_rotator.py @@ -5,6 +5,7 @@ # Import the classes from your modules from xeofs.models import MCA, MCARotator +from ..utilities import data_is_dask @pytest.fixture @@ -16,7 +17,7 @@ def mca_model(mock_data_array, dim): @pytest.fixture def mca_model_delayed(mock_dask_data_array, dim): - mca = MCA(n_modes=5) + mca = MCA(n_modes=5, compute=False) mca.fit(mock_dask_data_array, mock_dask_data_array, dim) return mca @@ -213,78 +214,38 @@ def test_heterogeneous_patterns(mca_model, mock_data_array, dim): @pytest.mark.parametrize( - "dim", + "dim, compute", [ - (("time",)), - (("lat", "lon")), - (("lon", "lat")), + (("time",), True), + (("lat", "lon"), True), + (("lon", "lat"), True), + (("time",), False), + (("lat", "lon"), False), + (("lon", "lat"), False), ], ) -def test_compute(mca_model_delayed): +def test_compute(mca_model_delayed, compute): """Test the compute method of the MCARotator class.""" - mca_rotator = MCARotator(n_modes=4, rtol=1e-5) + mca_rotator = MCARotator(n_modes=4, compute=compute, rtol=1e-5) mca_rotator.fit(mca_model_delayed) - assert isinstance( - mca_rotator.data.squared_covariance.data, DaskArray - ), "squared_covariance is not a delayed object" - assert isinstance( - mca_rotator.data.components1.data, DaskArray - ), "components1 is not a delayed object" - assert isinstance( - mca_rotator.data.components2.data, DaskArray - ), "components2 is not a delayed object" - assert isinstance( - mca_rotator.data.scores1.data, DaskArray - ), "scores1 is not a delayed object" - assert isinstance( - mca_rotator.data.scores2.data, DaskArray - ), "scores2 is not a delayed object" - assert isinstance( - mca_rotator.data.rotation_matrix.data, DaskArray - ), "rotation_matrix is not a delayed object" - assert isinstance( - mca_rotator.data.phi_matrix.data, DaskArray - ), "phi_matrix is not a delayed object" - assert isinstance( - mca_rotator.data.norm1.data, DaskArray - ), "norm1 is not a delayed object" - assert isinstance( - mca_rotator.data.norm2.data, DaskArray - ), "norm2 is not a delayed object" - assert isinstance( - mca_rotator.data.modes_sign.data, DaskArray - ), "modes_sign is not a delayed object" - - mca_rotator.compute() - - assert isinstance( - mca_rotator.data.squared_covariance.data, np.ndarray - ), "squared_covariance is not computed" - assert isinstance( - mca_rotator.data.total_squared_covariance.data, np.ndarray - ), "total_squared_covariance is not computed" - assert isinstance( - mca_rotator.data.components1.data, np.ndarray - ), "components1 is not computed" - assert isinstance( - mca_rotator.data.components2.data, np.ndarray - ), "components2 is not computed" - assert isinstance( - mca_rotator.data.scores1.data, np.ndarray - ), "scores1 is not computed" - assert isinstance( - mca_rotator.data.scores2.data, np.ndarray - ), "scores2 is not computed" - assert isinstance( - mca_rotator.data.rotation_matrix.data, 
np.ndarray - ), "rotation_matrix is not computed" - assert isinstance( - mca_rotator.data.phi_matrix.data, np.ndarray - ), "phi_matrix is not computed" - assert isinstance(mca_rotator.data.norm1.data, np.ndarray), "norm1 is not computed" - assert isinstance(mca_rotator.data.norm2.data, np.ndarray), "norm2 is not computed" - assert isinstance( - mca_rotator.data.modes_sign.data, np.ndarray - ), "modes_sign is not computed" + if compute: + assert not data_is_dask(mca_rotator.data["squared_covariance"]) + assert not data_is_dask(mca_rotator.data["components1"]) + assert not data_is_dask(mca_rotator.data["components2"]) + assert not data_is_dask(mca_rotator.data["rotation_matrix"]) + assert not data_is_dask(mca_rotator.data["phi_matrix"]) + assert not data_is_dask(mca_rotator.data["norm1"]) + assert not data_is_dask(mca_rotator.data["norm2"]) + assert not data_is_dask(mca_rotator.data["modes_sign"]) + + else: + assert data_is_dask(mca_rotator.data["squared_covariance"]) + assert data_is_dask(mca_rotator.data["components1"]) + assert data_is_dask(mca_rotator.data["components2"]) + assert data_is_dask(mca_rotator.data["rotation_matrix"]) + assert data_is_dask(mca_rotator.data["phi_matrix"]) + assert data_is_dask(mca_rotator.data["norm1"]) + assert data_is_dask(mca_rotator.data["norm2"]) + assert data_is_dask(mca_rotator.data["modes_sign"]) diff --git a/tests/models/test_opa.py b/tests/models/test_opa.py index 8f3d38c..d587c78 100644 --- a/tests/models/test_opa.py +++ b/tests/models/test_opa.py @@ -16,18 +16,8 @@ def test_init(): """Tests the initialization of the OPA class""" opa = OPA(n_modes=3, tau_max=3, n_pca_modes=19, use_coslat=True) - # Assert parameters are correctly stored in the _params attribute - assert opa._params == { - "n_modes": 3, - "tau_max": 3, - "n_pca_modes": 19, - "standardize": False, - "use_coslat": True, - "use_weights": False, - "solver": "auto", - } - # Assert preprocessor has been initialized + assert hasattr(opa, "_params") assert hasattr(opa, "preprocessor") @@ -229,15 +219,12 @@ def test_get_params(opa_model): # Test get_params method params = opa_model.get_params() assert isinstance(params, dict) - assert params == { - "n_modes": 3, - "tau_max": 3, - "n_pca_modes": 19, - "standardize": False, - "use_coslat": False, - "use_weights": False, - "solver": "auto", - } + assert params.get("n_modes") == 3 + assert params.get("tau_max") == 3 + assert params.get("n_pca_modes") == 19 + assert params.get("standardize") is False + assert params.get("use_coslat") is False + assert params.get("solver") == "auto" @pytest.mark.parametrize( @@ -385,7 +372,7 @@ def test_scores_uncorrelated(dim, use_coslat, mock_data_array): use_coslat=use_coslat, ) model.fit(mock_data_array, dim=dim) - scores = model.data.scores.values + scores = model.data["scores"].values check = scores.T @ scores / (scores.shape[0] - 1) assert np.allclose( check, np.eye(check.shape[1]), atol=1e-5 diff --git a/tests/models/test_orthogonality.py b/tests/models/test_orthogonality.py index 98425e2..0fe3296 100644 --- a/tests/models/test_orthogonality.py +++ b/tests/models/test_orthogonality.py @@ -23,7 +23,7 @@ def test_eof_components(dim, use_coslat, mock_data_array): """Components are orthogonal""" model = EOF(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(mock_data_array, dim=dim) - V = model.data.components.values + V = model.data["components"].values assert np.allclose( V.T @ V, np.eye(V.shape[1]), atol=1e-5 ), "Components are not orthogonal" @@ -41,7 +41,7 @@ def test_eof_scores(dim, use_coslat, 
mock_data_array): """Scores are orthogonal""" model = EOF(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(mock_data_array, dim=dim) - U = model.data.scores.values + U = model.data["scores"].values / model.data["norms"].values assert np.allclose( U.T @ U, np.eye(U.shape[1]), atol=1e-5 ), "Scores are not orthogonal" @@ -60,7 +60,7 @@ def test_ceof_components(dim, use_coslat, mock_data_array): """Components are unitary""" model = ComplexEOF(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(mock_data_array, dim=dim) - V = model.data.components.values + V = model.data["components"].values assert np.allclose( V.conj().T @ V, np.eye(V.shape[1]), atol=1e-5 ), "Components are not unitary" @@ -78,7 +78,7 @@ def test_ceof_scores(dim, use_coslat, mock_data_array): """Scores are unitary""" model = ComplexEOF(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(mock_data_array, dim=dim) - U = model.data.scores.values + U = model.data["scores"].values / model.data["norms"].values assert np.allclose( U.conj().T @ U, np.eye(U.shape[1]), atol=1e-5 ), "Scores are not unitary" @@ -102,7 +102,7 @@ def test_reof_components(dim, use_coslat, power, mock_data_array): model.fit(mock_data_array, dim=dim) rot = EOFRotator(n_modes=5, power=power) rot.fit(model) - V = rot.data.components.values + V = rot.data["components"].values K = V.conj().T @ V assert np.allclose( np.diag(K), np.ones(V.shape[1]), atol=1e-5 @@ -128,7 +128,7 @@ def test_reof_scores(dim, use_coslat, power, mock_data_array): model.fit(mock_data_array, dim=dim) rot = EOFRotator(n_modes=5, power=power) rot.fit(model) - U = rot.data.scores.values + U = rot.data["scores"].values / rot.data["norms"].values K = U.conj().T @ U if power == 1: # Varimax rotation does guarantee orthogonality @@ -157,7 +157,7 @@ def test_creof_components(dim, use_coslat, power, mock_data_array): model.fit(mock_data_array, dim=dim) rot = ComplexEOFRotator(n_modes=5, power=power) rot.fit(model) - V = rot.data.components.values + V = rot.data["components"].values K = V.conj().T @ V assert np.allclose( np.diag(K), np.ones(V.shape[1]), atol=1e-5 @@ -183,7 +183,7 @@ def test_creof_scores(dim, use_coslat, power, mock_data_array): model.fit(mock_data_array, dim=dim) rot = ComplexEOFRotator(n_modes=5, power=power) rot.fit(model) - U = rot.data.scores.values + U = rot.data["scores"].values / rot.data["norms"].values K = U.conj().T @ U if power == 1: # Varimax rotation does guarantee unitarity @@ -209,8 +209,8 @@ def test_mca_components(dim, use_coslat, mock_data_array): data2 = data1.copy() ** 2 model = MCA(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(data1, data2, dim=dim) - V1 = model.data.components1.values - V2 = model.data.components2.values + V1 = model.data["components1"].values + V2 = model.data["components2"].values K1 = V1.T @ V1 K2 = V2.T @ V2 assert np.allclose( @@ -235,10 +235,10 @@ def test_mca_scores(dim, use_coslat, mock_data_array): data2 = data1.copy() ** 2 model = MCA(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(data1, data2, dim=dim) - U1 = model.data.scores1.values - U2 = model.data.scores2.values + U1 = model.data["scores1"].values + U2 = model.data["scores2"].values K = U1.T @ U2 - target = np.eye(K.shape[0]) * (model.data.input_data1.sample.size - 1) + target = np.eye(K.shape[0]) * (model.data["input_data1"].sample.size - 1) assert np.allclose(K, target, atol=1e-5), "Scores are not orthogonal" @@ -257,8 +257,8 @@ def test_cmca_components(dim, use_coslat, mock_data_array): data2 = data1.copy() ** 2 
model = ComplexMCA(n_modes=5, standardize=True, use_coslat=use_coslat) model.fit(data1, data2, dim=dim) - V1 = model.data.components1.values - V2 = model.data.components2.values + V1 = model.data["components1"].values + V2 = model.data["components2"].values K1 = V1.conj().T @ V1 K2 = V2.conj().T @ V2 assert np.allclose( @@ -283,10 +283,10 @@ def test_cmca_scores(dim, use_coslat, mock_data_array): data2 = data1.copy() ** 2 model = ComplexMCA(n_modes=10, standardize=True, use_coslat=use_coslat) model.fit(data1, data2, dim=dim) - U1 = model.data.scores1.values - U2 = model.data.scores2.values + U1 = model.data["scores1"].values + U2 = model.data["scores2"].values K = U1.conj().T @ U2 - target = np.eye(K.shape[0]) * (model.data.input_data1.sample.size - 1) + target = np.eye(K.shape[0]) * (model.data["input_data1"].sample.size - 1) assert np.allclose(K, target, atol=1e-5), "Scores are not unitary" @@ -316,8 +316,8 @@ def test_rmca_components(dim, use_coslat, power, squared_loadings, mock_data_arr model.fit(data1, data2, dim=dim) rot = MCARotator(n_modes=5, power=power, squared_loadings=squared_loadings) rot.fit(model) - V1 = rot.data.components1.values - V2 = rot.data.components2.values + V1 = rot.data["components1"].values + V2 = rot.data["components2"].values K1 = V1.conj().T @ V1 K2 = V2.conj().T @ V2 assert np.allclose( @@ -356,10 +356,10 @@ def test_rmca_scores(dim, use_coslat, power, squared_loadings, mock_data_array): model.fit(data1, data2, dim=dim) rot = MCARotator(n_modes=5, power=power, squared_loadings=squared_loadings) rot.fit(model) - U1 = rot.data.scores1.values - U2 = rot.data.scores2.values + U1 = rot.data["scores1"].values + U2 = rot.data["scores2"].values K = U1.conj().T @ U2 - target = np.eye(K.shape[0]) * (model.data.input_data1.sample.size - 1) + target = np.eye(K.shape[0]) * (model.data["input_data1"].sample.size - 1) if power == 1: # Varimax rotation does guarantee orthogonality assert np.allclose(K, target, atol=1e-5), "Components are not orthogonal" @@ -393,8 +393,8 @@ def test_crmca_components(dim, use_coslat, power, squared_loadings, mock_data_ar model.fit(data1, data2, dim=dim) rot = ComplexMCARotator(n_modes=5, power=power, squared_loadings=squared_loadings) rot.fit(model) - V1 = rot.data.components1.values - V2 = rot.data.components2.values + V1 = rot.data["components1"].values + V2 = rot.data["components2"].values K1 = V1.conj().T @ V1 K2 = V2.conj().T @ V2 assert np.allclose( @@ -433,10 +433,10 @@ def test_crmca_scores(dim, use_coslat, power, squared_loadings, mock_data_array) model.fit(data1, data2, dim=dim) rot = ComplexMCARotator(n_modes=5, power=power, squared_loadings=squared_loadings) rot.fit(model) - U1 = rot.data.scores1.values - U2 = rot.data.scores2.values + U1 = rot.data["scores1"].values + U2 = rot.data["scores2"].values K = U1.conj().T @ U2 - target = np.eye(K.shape[0]) * (model.data.input_data1.sample.size - 1) + target = np.eye(K.shape[0]) * (model.data["input_data1"].sample.size - 1) if power == 1: # Varimax rotation does guarantee orthogonality assert np.allclose(K, target, atol=1e-5), "Components are not orthogonal" @@ -457,7 +457,12 @@ def test_crmca_scores(dim, use_coslat, power, squared_loadings, mock_data_array) ) def test_eof_transform(dim, use_coslat, mock_data_array): """Transforming the original data results in the model scores""" - model = EOF(n_modes=5, standardize=True, use_coslat=use_coslat) + model = EOF( + n_modes=5, + standardize=True, + use_coslat=use_coslat, + solver_kwargs={"random_state": 5}, + ) model.fit(mock_data_array, 
dim=dim) scores = model.scores() pseudo_scores = model.transform(mock_data_array) diff --git a/tests/preprocessing/test_dataarray_list_stacker.py b/tests/preprocessing/test_dataarray_list_stacker.py deleted file mode 100644 index a9c4eaa..0000000 --- a/tests/preprocessing/test_dataarray_list_stacker.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.preprocessing.stacker import ListDataArrayStacker - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_transform(dim_sample, dim_feature, mock_data_array_list): - """ - Test that ListDataArrayStacker correctly stacks a list of DataArrays and - fit_transform returns DataArray with 'sample' and 'feature' dimensions. - """ - stacker = ListDataArrayStacker() - feature_dims_list = [dim_feature] * len( - mock_data_array_list - ) # Assume that all DataArrays have the same feature dimensions - stacked_data = stacker.fit_transform( - mock_data_array_list, dim_sample, feature_dims_list - ) - - # Check if the output is a DataArray - assert isinstance(stacked_data, xr.DataArray) - # Check if the dimensions are correct - assert set(stacked_data.dims) == set(("sample", "feature")) - # Check if the data is preserved - assert stacked_data.size == sum([da.size for da in mock_data_array_list]) - - # Check if the transform function returns the same result - transformed_data = stacker.transform(mock_data_array_list) - [ - xr.testing.assert_equal(stacked, transformed) - for stacked, transformed in zip(stacked_data, transformed_data) - ] - - # Check if the stacker dimensions are correct - for stckr, da in zip(stacker.stackers, mock_data_array_list): - assert set(stckr.dims_in_) == set(da.dims) - assert set(stckr.dims_out_) == set(("sample", "feature")) - # test that coordinates are preserved - for dim, coords in da.coords.items(): - assert ( - stckr.coords_in_[dim].size == coords.size - ), "Dimension {} has different size.".format(dim) - assert stckr.coords_out_["sample"].size == np.prod( - [coords.size for dim, coords in da.coords.items() if dim in dim_sample] - ), "Sample dimension has different size." - assert stckr.coords_out_["feature"].size == np.prod( - [coords.size for dim, coords in da.coords.items() if dim in dim_feature] - ), "Feature dimension has different size." 
- - # Check that invalid input raises an error in transform - with pytest.raises(ValueError): - stacker.transform( - [ - xr.DataArray(np.random.rand(2, 4, 5), dims=("a", "y", "x")) - for _ in range(3) - ] - ) - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_unstack_data(dim_sample, dim_feature, mock_data_array_list): - """Test if the inverse transformed DataArrays are identical to the original DataArrays.""" - stacker_list = ListDataArrayStacker() - feature_dims_list = [dim_feature] * len( - mock_data_array_list - ) # Assume that all DataArrays have the same feature dimensions - stacked = stacker_list.fit_transform(mock_data_array_list, dim_sample, feature_dims_list) # type: ignore - unstacked = stacker_list.inverse_transform_data(stacked) - - for da_test, da_ref in zip(unstacked, mock_data_array_list): - xr.testing.assert_equal(da_test, da_ref) - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_unstack_components(dim_sample, dim_feature, mock_data_array_list): - """Test if the inverse transformed components are identical to the original components.""" - stacker_list = ListDataArrayStacker() - feature_dims_list = [dim_feature] * len(mock_data_array_list) - stacked = stacker_list.fit_transform( - mock_data_array_list, dim_sample, feature_dims_list - ) - - components = xr.DataArray( - np.random.normal(size=(stacker_list.coords_out_["feature"].size, 10)), - dims=("feature", "mode"), - coords={"feature": stacker_list.coords_out_["feature"]}, - ) - unstacked = stacker_list.inverse_transform_components(components) - - for da_test, da_ref in zip(unstacked, mock_data_array_list): - # Check if the dimensions are correct - assert set(da_test.dims) == set(dim_feature + ("mode",)) - # Check if the coordinates are preserved - for dim, coords in da_ref.coords.items(): - if dim in dim_feature: - assert ( - da_test.coords[dim].size == coords.size - ), "Dimension {} has different size.".format(dim) diff --git a/tests/preprocessing/test_dataarray_multiindex_converter.py b/tests/preprocessing/test_dataarray_multiindex_converter.py new file mode 100644 index 0000000..876ec5f --- /dev/null +++ b/tests/preprocessing/test_dataarray_multiindex_converter.py @@ -0,0 +1,81 @@ +import pytest +import pandas as pd + +from xeofs.preprocessing.multi_index_converter import ( + MultiIndexConverter, +) +from ..conftest import generate_synthetic_dataarray +from xeofs.utils.data_types import DataArray +from ..utilities import assert_expected_dims, data_is_dask, data_has_multiindex + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (ns, nf, index, nan, dask) + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_transform(synthetic_dataarray): + converter = MultiIndexConverter() + 
converter.fit(synthetic_dataarray) + transformed_data = converter.transform(synthetic_dataarray) + + is_dask_before = data_is_dask(synthetic_dataarray) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + # Transforming removes MultiIndex + assert data_has_multiindex(transformed_data) is False + + # Result is robust to calling the method multiple times + transformed_data = converter.transform(synthetic_dataarray) + assert data_has_multiindex(transformed_data) is False + + # Transforming data twice won't change the data + transformed_data2 = converter.transform(transformed_data) + assert data_has_multiindex(transformed_data2) is False + assert transformed_data.identical(transformed_data2) + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_inverse_transform_data(synthetic_dataarray): + converter = MultiIndexConverter() + converter.fit(synthetic_dataarray) + transformed_data = converter.transform(synthetic_dataarray) + inverse_transformed_data = converter.inverse_transform_data(transformed_data) + + is_dask_before = data_is_dask(synthetic_dataarray) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + has_multiindex_before = data_has_multiindex(synthetic_dataarray) + has_multiindex_after = data_has_multiindex(inverse_transformed_data) + + assert inverse_transformed_data.identical(synthetic_dataarray) + assert has_multiindex_before == has_multiindex_after diff --git a/tests/preprocessing/test_dataarray_renamer.py b/tests/preprocessing/test_dataarray_renamer.py new file mode 100644 index 0000000..efb5d25 --- /dev/null +++ b/tests/preprocessing/test_dataarray_renamer.py @@ -0,0 +1,86 @@ +import pytest + +from xeofs.preprocessing.dimension_renamer import DimensionRenamer +from ..utilities import ( + data_is_dask, + get_dims_from_data, +) + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (ns, nf, index, nan, dask) + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_transform(synthetic_dataarray): + all_dims, sample_dims, feature_dims = get_dims_from_data(synthetic_dataarray) + + n_dims = len(all_dims) + + base = "new" + start = 10 + expected_dims = set(base + str(i) for i in range(start, start + n_dims)) + + renamer = DimensionRenamer(base=base, start=start) + renamer.fit(synthetic_dataarray, sample_dims, feature_dims) + transformed_data = renamer.transform(synthetic_dataarray) + + is_dask_before = data_is_dask(synthetic_dataarray) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + # Transforming converts dimension names + given_dims = set(transformed_data.dims) + assert given_dims == expected_dims + + # Result is robust to calling the method multiple times + transformed_data = 
renamer.transform(synthetic_dataarray) + given_dims = set(transformed_data.dims) + assert given_dims == expected_dims + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_inverse_transform_data(synthetic_dataarray): + all_dims, sample_dims, feature_dims = get_dims_from_data(synthetic_dataarray) + + base = "new" + start = 10 + + renamer = DimensionRenamer(base=base, start=start) + renamer.fit(synthetic_dataarray, sample_dims, feature_dims) + transformed_data = renamer.transform(synthetic_dataarray) + inverse_transformed_data = renamer.inverse_transform_data(transformed_data) + + is_dask_before = data_is_dask(synthetic_dataarray) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + assert inverse_transformed_data.identical(synthetic_dataarray) + assert set(inverse_transformed_data.dims) == set(synthetic_dataarray.dims) diff --git a/tests/preprocessing/test_dataarray_sanitizer.py b/tests/preprocessing/test_dataarray_sanitizer.py new file mode 100644 index 0000000..ac46af2 --- /dev/null +++ b/tests/preprocessing/test_dataarray_sanitizer.py @@ -0,0 +1,213 @@ +import pytest +import numpy as np +import xarray as xr + +from xeofs.preprocessing.sanitizer import Sanitizer +from xeofs.utils.data_types import DataArray +from ..conftest import generate_synthetic_dataarray +from ..utilities import ( + data_is_dask, + assert_expected_dims, + assert_expected_coords, +) + +# ============================================================================= +# VALID TEST CASES +# ============================================================================= +N_SAMPLE_DIMS = [1] +N_FEATURE_DIMS = [1] +INDEX_POLICY = ["index"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (ns, nf, index, nan, dask) + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + [ + ("sample", "feature", (1, 1)), + ("another_sample", "another_feature", (1, 1)), + ], +) +def test_fit_valid_dimension_names(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + data = data.rename({"sample0": sample_name, "feature0": feature_name}) + + sanitizer = Sanitizer(sample_name=sample_name, feature_name=feature_name) + sanitizer.fit(data) + data_clean = sanitizer.transform(data) + reconstructed_data = sanitizer.inverse_transform_data(data_clean) + + assert data_clean.ndim == 2 + assert set(data_clean.dims) == set((sample_name, feature_name)) + assert set(reconstructed_data.dims) == set(data.dims) + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + [ + ("sample1", "feature", (1, 1)), + ("sample", "feature1", (1, 1)), + ("sample1", "feature1", (1, 1)), + ], +) +def test_fit_invalid_dimension_names(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + + sanitizer = Sanitizer(sample_name=sample_name, feature_name=feature_name) + + with pytest.raises(ValueError): + sanitizer.fit(data) + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_transform(synthetic_dataarray): + data = synthetic_dataarray + data = 
data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + sanitizer.fit(data) + transformed_data = sanitizer.transform(data) + transformed_data2 = sanitizer.transform(data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) + + assert transformed_data.notnull().all() + assert isinstance(transformed_data, DataArray) + assert transformed_data.ndim == 2 + assert transformed_data.dims == data.dims + assert is_dask_before == is_dask_after + assert transformed_data.identical(transformed_data2) + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_transform_invalid(synthetic_dataarray): + data = synthetic_dataarray + data = data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + sanitizer.fit(data) + with pytest.raises(ValueError): + sanitizer.transform(data.isel(feature0=slice(0, 2))) + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_fit_transform(synthetic_dataarray): + data = synthetic_dataarray + data = data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + transformed_data = sanitizer.fit_transform(data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) + + assert isinstance(transformed_data, DataArray) + assert transformed_data.notnull().all() + assert transformed_data.ndim == 2 + assert transformed_data.dims == data.dims + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_invserse_transform_data(synthetic_dataarray): + data = synthetic_dataarray + data = data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + sanitizer.fit(data) + cleaned_data = sanitizer.transform(data) + uncleaned_data = sanitizer.inverse_transform_data(cleaned_data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(uncleaned_data) + + # inverse transform is only identical if nan_policy={"no_nan", "fulldim"} + # in case of "isolated" the inverse transform will set the entire feature column + # to NaNs, which is not identical to the original data + # assert data.identical(uncleaned_data) + + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_invserse_transform_components(synthetic_dataarray): + data: DataArray = synthetic_dataarray + data = data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + sanitizer.fit(data) + + stacked_data = sanitizer.transform(data) + components = stacked_data.rename({"sample": "mode"}) + unstacked_data = sanitizer.inverse_transform_components(components) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked components has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="feature") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="feature") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_invserse_transform_scores(synthetic_dataarray): 
+ data: DataArray = synthetic_dataarray + data = data.rename({"sample0": "sample", "feature0": "feature"}) + + sanitizer = Sanitizer() + sanitizer.fit(data) + + stacked_data = sanitizer.transform(data) + components = stacked_data.rename({"feature": "mode"}) + unstacked_data = sanitizer.inverse_transform_scores(components) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked components has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="sample") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="sample") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after diff --git a/tests/preprocessing/test_dataarray_scaler.py b/tests/preprocessing/test_dataarray_scaler.py new file mode 100644 index 0000000..a0c7b6a --- /dev/null +++ b/tests/preprocessing/test_dataarray_scaler.py @@ -0,0 +1,193 @@ +import pytest +import xarray as xr +import numpy as np + +from xeofs.preprocessing.scaler import Scaler + + +@pytest.mark.parametrize( + "with_std, with_coslat", + [ + (True, True), + (True, True), + (True, False), + (True, False), + (False, True), + (False, True), + (False, False), + (False, False), + (True, True), + (True, True), + (True, False), + (True, False), + (False, True), + (False, True), + (False, False), + (False, False), + ], +) +def test_init_params(with_std, with_coslat): + s = Scaler(with_std=with_std, with_coslat=with_coslat) + assert s.get_params()["with_std"] == with_std + assert s.get_params()["with_coslat"] == with_coslat + + +@pytest.mark.parametrize( + "with_std, with_coslat", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +def test_fit_params(with_std, with_coslat, mock_data_array): + s = Scaler(with_std=with_std, with_coslat=with_coslat) + sample_dims = ["time"] + feature_dims = ["lat", "lon"] + size_lats = mock_data_array.lat.size + weights = xr.DataArray(np.random.rand(size_lats), dims=["lat"]) + s.fit(mock_data_array, sample_dims, feature_dims, weights) + assert hasattr(s, "mean_"), "Scaler has no mean attribute." + if with_std: + assert hasattr(s, "std_"), "Scaler has no std attribute." + if with_coslat: + assert hasattr(s, "coslat_weights_"), "Scaler has no coslat_weights attribute." + assert s.mean_ is not None, "Scaler mean is None." + if with_std: + assert s.std_ is not None, "Scaler std is None." + if with_coslat: + assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." + + +@pytest.mark.parametrize( + "with_std, with_coslat, with_weights", + [ + (True, True, True), + (True, False, True), + (False, True, True), + (False, False, True), + (True, True, False), + (True, False, False), + (False, True, False), + (False, False, False), + ], +) +def test_transform_params(with_std, with_coslat, with_weights, mock_data_array): + s = Scaler(with_std=with_std, with_coslat=with_coslat) + sample_dims = ["time"] + feature_dims = ["lat", "lon"] + size_lats = mock_data_array.lat.size + if with_weights: + weights = xr.DataArray( + np.random.rand(size_lats), dims=["lat"], coords={"lat": mock_data_array.lat} + ) + else: + weights = None + s.fit(mock_data_array, sample_dims, feature_dims, weights) + transformed = s.transform(mock_data_array) + assert transformed is not None, "Transformed data is None." + + transformed_mean = transformed.mean(sample_dims, skipna=False) + assert np.allclose(transformed_mean, 0), "Mean of the transformed data is not zero." 
+ + if with_std and not (with_coslat or with_weights): + transformed_std = transformed.std(sample_dims, skipna=False) + + assert np.allclose( + transformed_std, 1 + ), "Standard deviation of the transformed data is not one." + + if with_coslat: + assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." + assert not np.array_equal( + transformed, mock_data_array + ), "Data has not been transformed." + + transformed2 = s.fit_transform(mock_data_array, sample_dims, feature_dims, weights) + xr.testing.assert_allclose(transformed, transformed2) + + +@pytest.mark.parametrize( + "with_std, with_coslat", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +def test_inverse_transform_params(with_std, with_coslat, mock_data_array): + s = Scaler( + with_std=with_std, + with_coslat=with_coslat, + ) + sample_dims = ["time"] + feature_dims = ["lat", "lon"] + size_lats = mock_data_array.lat.size + weights = xr.DataArray( + np.random.rand(size_lats), dims=["lat"], coords={"lat": mock_data_array.lat} + ) + s.fit(mock_data_array, sample_dims, feature_dims, weights) + transformed = s.transform(mock_data_array) + inverted = s.inverse_transform_data(transformed) + xr.testing.assert_allclose(inverted, mock_data_array) + + +@pytest.mark.parametrize( + "dim_sample, dim_feature", + [ + (("time",), ("lat", "lon")), + (("time",), ("lon", "lat")), + (("lat", "lon"), ("time",)), + (("lon", "lat"), ("time",)), + ], +) +def test_fit_dims(dim_sample, dim_feature, mock_data_array): + s = Scaler(with_std=True) + s.fit(mock_data_array, dim_sample, dim_feature) + assert hasattr(s, "mean_"), "Scaler has no mean attribute." + assert s.mean_ is not None, "Scaler mean is None." + assert hasattr(s, "std_"), "Scaler has no std attribute." + assert s.std_ is not None, "Scaler std is None." + # check that all dimensions are present except the sample dimensions + assert set(s.mean_.dims) == set(mock_data_array.dims) - set( + dim_sample + ), "Mean has wrong dimensions." + assert set(s.std_.dims) == set(mock_data_array.dims) - set( + dim_sample + ), "Standard deviation has wrong dimensions." + + +@pytest.mark.parametrize( + "dim_sample, dim_feature", + [ + (("time",), ("lat", "lon")), + (("time",), ("lon", "lat")), + (("lat", "lon"), ("time",)), + (("lon", "lat"), ("time",)), + ], +) +def test_fit_transform_dims(dim_sample, dim_feature, mock_data_array): + s = Scaler() + transformed = s.fit_transform(mock_data_array, dim_sample, dim_feature) + # check that all dimensions are present + assert set(transformed.dims) == set( + mock_data_array.dims + ), "Transformed data has wrong dimensions." 
+ # check that the coordinates are the same + for dim in mock_data_array.dims: + xr.testing.assert_allclose(transformed[dim], mock_data_array[dim]) + + +# Test input types +def test_fit_input_type(mock_data_array, mock_dataset, mock_data_array_list): + s = Scaler() + + with pytest.raises(TypeError): + s.fit(mock_data_array_list, ["time"], ["lon", "lat"]) + + s.fit(mock_data_array, ["time"], ["lon", "lat"]) + + with pytest.raises(TypeError): + s.transform(mock_data_array_list) diff --git a/tests/preprocessing/test_dataarray_stacker.py b/tests/preprocessing/test_dataarray_stacker.py index 14f697a..5882f13 100644 --- a/tests/preprocessing/test_dataarray_stacker.py +++ b/tests/preprocessing/test_dataarray_stacker.py @@ -1,204 +1,230 @@ import pytest -import xarray as xr import numpy as np +import xarray as xr -from xeofs.preprocessing.stacker import SingleDataArrayStacker - +from xeofs.preprocessing.stacker import DataArrayStacker +from xeofs.utils.data_types import DataArray +from ..conftest import generate_synthetic_dataarray +from ..utilities import ( + get_dims_from_data, + data_is_dask, + assert_expected_dims, + assert_expected_coords, +) +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index"] +NAN_POLICY = ["no_nan"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (ns, nf, index, nan, dask) + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= @pytest.mark.parametrize( - "dim_sample, dim_feature", + "sample_name, feature_name, data_params", [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), + ("sample", "feature", (1, 1)), + ("sample0", "feature0", (1, 1)), + ("sample0", "feature", (1, 2)), + ("sample", "feature0", (2, 1)), + ("sample", "feature", (2, 2)), + ("another_sample", "another_feature", (1, 1)), + ("another_sample", "another_feature", (2, 2)), ], ) -def test_fit_transform( - dim_sample, - dim_feature, - mock_data_array, - mock_data_array_isolated_nans, - mock_data_array_full_dimensional_nans, - mock_data_array_boundary_nans, -): - # Test basic functionality - stacker = SingleDataArrayStacker() - stacked = stacker.fit_transform(mock_data_array, dim_sample, dim_feature) - assert stacked.ndim == 2 - assert set(stacked.dims) == {"sample", "feature"} - assert not stacked.isnull().any() - - # Test that the operation is reversible - unstacked = stacker.inverse_transform_data(stacked) - xr.testing.assert_equal(unstacked, mock_data_array) - - # Test that isolated NaNs raise an error - with pytest.raises(ValueError): - stacker.fit_transform(mock_data_array_isolated_nans, dim_sample, dim_feature) - - # Test that NaNs across a full dimension are handled correctly - stacked = stacker.fit_transform( - mock_data_array_full_dimensional_nans, dim_sample, dim_feature - ) - unstacked = stacker.inverse_transform_data(stacked) - xr.testing.assert_equal(unstacked, mock_data_array_full_dimensional_nans) - - # Test that NaNs on the boundary are handled correctly - stacked = stacker.fit_transform( - mock_data_array_boundary_nans, dim_sample, dim_feature - ) - unstacked = stacker.inverse_transform_data(stacked) - 
xr.testing.assert_equal(unstacked, mock_data_array_boundary_nans) - - # Test that the same stacker cannot be used with data of different shapes - with pytest.raises(ValueError): - other_data = mock_data_array.isel(time=slice(None, -1), lon=slice(None, -1)) - stacker.transform(other_data) +def test_fit_valid_dimension_names(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker(sample_name=sample_name, feature_name=feature_name) + stacker.fit(data, sample_dims, feature_dims) + stacked_data = stacker.transform(data) + reconstructed_data = stacker.inverse_transform_data(stacked_data) + + assert stacked_data.ndim == 2 + assert set(stacked_data.dims) == set((sample_name, feature_name)) + assert set(reconstructed_data.dims) == set(data.dims) @pytest.mark.parametrize( - "dim_sample, dim_feature", + "sample_name, feature_name, data_params", [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), + ("sample1", "feature", (2, 1)), + ("sample", "feature1", (1, 2)), + ("sample1", "feature1", (3, 3)), ], ) -def test_transform(mock_data_array, dim_sample, dim_feature): - # Test basic functionality - stacker = SingleDataArrayStacker() - stacker.fit_transform(mock_data_array, dim_sample, dim_feature) - other_data = mock_data_array.copy(deep=True) - transformed = stacker.transform(other_data) - - # Test that transformed data has the correct dimensions - assert transformed.ndim == 2 - assert set(transformed.dims) == {"sample", "feature"} - assert not transformed.isnull().any() - - # Invalid data raises an error +def test_fit_invalid_dimension_names(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker(sample_name=sample_name, feature_name=feature_name) + with pytest.raises(ValueError): - stacker.transform(mock_data_array.isel(lon=slice(None, 2), time=slice(None, 2))) + stacker.fit(data, sample_dims, feature_dims) @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], ) -def test_inverse_transform_data(mock_data_array, dim_sample, dim_feature): - # Test inverse transform - stacker = SingleDataArrayStacker() - stacker.fit_transform(mock_data_array, dim_sample, dim_feature) - stacked = stacker.transform(mock_data_array) - unstacked = stacker.inverse_transform_data(stacked) - xr.testing.assert_equal(unstacked, mock_data_array) +def test_fit(synthetic_dataarray): + data = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - # Test that the operation is reversible - restacked = stacker.transform(unstacked) - xr.testing.assert_equal(restacked, stacked) + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], ) -def test_inverse_transform_components(mock_data_array, dim_sample, dim_feature): - # Test basic functionality - stacker = SingleDataArrayStacker() 
- stacker.fit_transform(mock_data_array, dim_sample, dim_feature) - components = xr.DataArray( - np.random.normal(size=(len(stacker.coords_out_["feature"]), 10)), - dims=("feature", "mode"), - coords={"feature": stacker.coords_out_["feature"]}, - ) - unstacked = stacker.inverse_transform_components(components) - - # Test that feature dimensions are preserved - assert set(unstacked.dims) == set(dim_feature + ("mode",)) - - # Test that feature coordinates are preserved - for dim, coords in mock_data_array.coords.items(): - if dim in dim_feature: - assert ( - unstacked.coords[dim].size == coords.size - ), "Dimension {} has different size.".format(dim) +def test_transform(synthetic_dataarray): + data = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) + transformed_data = stacker.transform(data) + transformed_data2 = stacker.transform(data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) + + assert isinstance(transformed_data, DataArray) + assert transformed_data.ndim == 2 + assert transformed_data.dims == ("sample", "feature") + assert is_dask_before == is_dask_after + assert transformed_data.identical(transformed_data2) @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_transform_invalid(synthetic_dataarray): + data = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) + with pytest.raises(ValueError): + stacker.transform(data.isel(feature0=slice(0, 2))) + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_fit_transform(synthetic_dataarray): + data = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + transformed_data = stacker.fit_transform(data, sample_dims, feature_dims) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) + + assert isinstance(transformed_data, DataArray) + assert transformed_data.ndim == 2 + assert transformed_data.dims == ("sample", "feature") + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_invserse_transform_data(synthetic_dataarray): + data = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) + stacked_data = stacker.transform(data) + unstacked_data = stacker.inverse_transform_data(stacked_data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked data has dimensions of original data + assert_expected_dims(data, unstacked_data, policy="all") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="all") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], +) +def test_invserse_transform_components(synthetic_dataarray): + data: 
DataArray = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) + + stacked_data = stacker.transform(data) + components = stacked_data.rename({"sample": "mode"}) + unstacked_data = stacker.inverse_transform_components(components) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked components has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="feature") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="feature") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after + + +@pytest.mark.parametrize( + "synthetic_dataarray", + VALID_TEST_DATA, + indirect=["synthetic_dataarray"], ) -def test_inverse_transform_scores(mock_data_array, dim_sample, dim_feature): - # Test basic functionality - stacker = SingleDataArrayStacker() - stacker.fit_transform(mock_data_array, dim_sample, dim_feature) - scores = xr.DataArray( - np.random.rand(len(stacker.coords_out_["sample"]), 10), - dims=("sample", "mode"), - coords={"sample": stacker.coords_out_["sample"]}, - ) - unstacked = stacker.inverse_transform_scores(scores) - - # Test that sample dimensions are preserved - assert set(unstacked.dims) == set(dim_sample + ("mode",)) - - # Test that sample coordinates are preserved - for dim, coords in mock_data_array.coords.items(): - if dim in dim_sample: - assert ( - unstacked.coords[dim].size == coords.size - ), "Dimension {} has different size.".format(dim) - - -def test_fit_transform_sample_feature_data(): - """Test fit_transform with sample and feature data.""" - # Create sample and feature data - np.random.seed(5) - simple_data = xr.DataArray( - np.random.rand(10, 5), - dims=("sample", "feature"), - coords={"sample": np.arange(10), "feature": np.arange(5)}, - ) - np.random.seed(5) - more_simple_data = xr.DataArray( - np.random.rand(10, 5), - dims=("sample", "feature"), - coords={"sample": np.arange(10), "feature": np.arange(5)}, - ) - - # Create stacker and fit_transform - stacker = SingleDataArrayStacker() - stacked = stacker.fit_transform(simple_data, ("sample",), ("feature")) - - # Test that the dimensions are correct - assert stacked.ndim == 2 - assert set(stacked.dims) == {"sample", "feature"} - assert not stacked.isnull().any() - - # Test that fitting new data yields the same results - more_stacked = stacker.transform(more_simple_data) - xr.testing.assert_equal(more_stacked, stacked) - - # Test that the operation is reversible - unstacked = stacker.inverse_transform_data(stacked) - xr.testing.assert_equal(unstacked, simple_data) - - more_unstacked = stacker.inverse_transform_data(more_stacked) - xr.testing.assert_equal(more_unstacked, more_simple_data) +def test_invserse_transform_scores(synthetic_dataarray): + data: DataArray = synthetic_dataarray + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataArrayStacker() + stacker.fit(data, sample_dims, feature_dims) + + stacked_data = stacker.transform(data) + components = stacked_data.rename({"feature": "mode"}) + unstacked_data = stacker.inverse_transform_scores(components) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked components has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="sample") + # Unstacked data has coordinates of original data + 
assert_expected_coords(data, unstacked_data, policy="sample") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after diff --git a/tests/preprocessing/test_dataarray_stacker_stack.py b/tests/preprocessing/test_dataarray_stacker_stack.py deleted file mode 100644 index 7389559..0000000 --- a/tests/preprocessing/test_dataarray_stacker_stack.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.preprocessing.stacker import SingleDataArrayStacker - - -def create_da(dim_sample, dim_feature, seed=None): - n_dims = len(dim_sample) + len(dim_feature) - size = n_dims * [3] - rng = np.random.default_rng(seed) - dims = dim_sample + dim_feature - coords = {d: np.arange(i, i + 3) for i, d in enumerate(dims)} - return xr.DataArray(rng.normal(0, 1, size=size), dims=dims, coords=coords) - - -# Valid input -# ============================================================================= -valid_input_dims = [ - (("year", "month"), ("lon", "lat")), - (("year",), ("lat", "lon")), - (("year", "month"), ("lon",)), - (("year",), ("lon",)), - (("sample",), ("feature",)), -] - -valid_input = [] -for dim_sample, dim_feature in valid_input_dims: - da = create_da(dim_sample, dim_feature) - valid_input.append((da, dim_sample, dim_feature)) - - -# Invalid input -# ============================================================================= -invalid_input_dims = [ - (("sample",), ("feature", "lat")), - (("sample",), ("month", "feature")), - (("sample", "month"), ("lon", "lat")), - (("sample",), ("lon", "lat")), - (("year",), ("month", "sample")), - (("year",), ("sample",)), - (("sample",), ("lon",)), - (("year", "month"), ("lon", "feature")), - (("year", "month"), ("feature",)), - (("year",), ("feature",)), - (("feature",), ("lon", "lat")), - (("feature",), ("lon",)), - (("feature",), ("sample",)), -] -invalid_input = [] -for dim_sample, dim_feature in invalid_input_dims: - da = create_da(dim_sample, dim_feature) - invalid_input.append((da, dim_sample, dim_feature)) - - -# Test stacking -# ============================================================================= -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_fit_transform(da, dim_sample, dim_feature): - """Test fit_transform with valid input.""" - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - - # Stacked data has dimensions (sample, feature) - err_msg = f"In: {da.dims}; Out: {da_stacked.dims}" - assert set(da_stacked.dims) == { - "sample", - "feature", - }, err_msg - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", invalid_input) -def test_fit_transform_invalid_input(da, dim_sample, dim_feature): - """Test fit_transform with invalid input.""" - stacker = SingleDataArrayStacker() - with pytest.raises(ValueError): - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_data(da, dim_sample, dim_feature): - """Test inverse transform with valid input.""" - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - da_unstacked = stacker.inverse_transform_data(da_stacked) - - # Unstacked data has dimensions of original data - err_msg = f"Original: {da.dims}; Recovered: {da_unstacked.dims}" - assert set(da_unstacked.dims) == set(da.dims), err_msg - # Unstacked data has coordinates of original data - for d in da.dims: - assert 
np.all(da_unstacked.coords[d].values == da.coords[d].values) - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_components(da, dim_sample, dim_feature): - """Test inverse transform components with valid input.""" - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - # Mock components by dropping sampling dim from data - comps_stacked = da_stacked.drop_vars("sample").rename({"sample": "mode"}) - comps_stacked.coords.update({"mode": range(comps_stacked.mode.size)}) - - comps_unstacked = stacker.inverse_transform_components(comps_stacked) - - # Unstacked components has correct feature dimensions - expected_dims = dim_feature + ("mode",) - err_msg = f"Expected: {expected_dims}; Recovered: {comps_unstacked.dims}" - assert set(comps_unstacked.dims) == set(expected_dims), err_msg - # Unstacked data has coordinates of original data - for d in dim_feature: - assert np.all(comps_unstacked.coords[d].values == da.coords[d].values) - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_scores(da, dim_sample, dim_feature): - """Test inverse transform scores with valid input.""" - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - # Mock scores by dropping feature dim from data - scores_stacked = da_stacked.drop_vars("feature").rename({"feature": "mode"}) - scores_stacked.coords.update({"mode": range(scores_stacked.mode.size)}) - - scores_unstacked = stacker.inverse_transform_scores(scores_stacked) - - # Unstacked components has correct feature dimensions - expected_dims = dim_sample + ("mode",) - err_msg = f"Expected: {expected_dims}; Recovered: {scores_unstacked.dims}" - assert set(scores_unstacked.dims) == set(expected_dims), err_msg - # Unstacked data has coordinates of original data - for d in dim_sample: - assert np.all(scores_unstacked.coords[d].values == da.coords[d].values) diff --git a/tests/preprocessing/test_datalist_multiindex_converter.py b/tests/preprocessing/test_datalist_multiindex_converter.py new file mode 100644 index 0000000..1e8003c --- /dev/null +++ b/tests/preprocessing/test_datalist_multiindex_converter.py @@ -0,0 +1,93 @@ +# import pytest +# import pandas as pd + +# from xeofs.preprocessing.multi_index_converter import ( +# DataListMultiIndexConverter, +# ) +# from xeofs.utils.data_types import DataArray +# from ..utilities import assert_expected_dims, data_is_dask, data_has_multiindex + +# # ============================================================================= +# # GENERALLY VALID TEST CASES +# # ============================================================================= +# N_ARRAYS = [1, 2] +# N_SAMPLE_DIMS = [1, 2] +# N_FEATURE_DIMS = [1, 2] +# INDEX_POLICY = ["index"] +# NAN_POLICY = ["no_nan"] +# DASK_POLICY = ["no_dask", "dask"] +# SEED = [0] + +# VALID_TEST_DATA = [ +# (na, ns, nf, index, nan, dask) +# for na in N_ARRAYS +# for ns in N_SAMPLE_DIMS +# for nf in N_FEATURE_DIMS +# for index in INDEX_POLICY +# for nan in NAN_POLICY +# for dask in DASK_POLICY +# ] + + +# # TESTS +# # ============================================================================= +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_transform(synthetic_datalist): +# converter = DataListMultiIndexConverter() +# converter.fit(synthetic_datalist) +# transformed_data = converter.transform(synthetic_datalist) + +# 
is_dask_before = data_is_dask(synthetic_datalist) +# is_dask_after = data_is_dask(transformed_data) + +# # Transforming does not affect dimensions +# assert_expected_dims(transformed_data, synthetic_datalist, policy="all") + +# # Transforming doesn't change the dask-ness of the data +# assert is_dask_before == is_dask_after + +# # Transforming removes MultiIndex +# assert data_has_multiindex(transformed_data) is False + +# # Result is robust to calling the method multiple times +# transformed_data = converter.transform(synthetic_datalist) +# assert data_has_multiindex(transformed_data) is False + +# # Transforming data twice won't change the data +# transformed_data2 = converter.transform(transformed_data) +# assert data_has_multiindex(transformed_data2) is False +# assert all( +# trans.identical(data) +# for trans, data in zip(transformed_data, transformed_data2) +# ) + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_inverse_transform(synthetic_datalist): +# converter = DataListMultiIndexConverter() +# converter.fit(synthetic_datalist) +# transformed_data = converter.transform(synthetic_datalist) +# inverse_transformed_data = converter.inverse_transform_data(transformed_data) + +# is_dask_before = data_is_dask(synthetic_datalist) +# is_dask_after = data_is_dask(transformed_data) + +# # Transforming doesn't change the dask-ness of the data +# assert is_dask_before == is_dask_after + +# has_multiindex_before = data_has_multiindex(synthetic_datalist) +# has_multiindex_after = data_has_multiindex(inverse_transformed_data) + +# assert all( +# trans.identical(data) +# for trans, data in zip(inverse_transformed_data, synthetic_datalist) +# ) +# assert has_multiindex_before == has_multiindex_after diff --git a/tests/preprocessing/test_datalist_scaler.py b/tests/preprocessing/test_datalist_scaler.py new file mode 100644 index 0000000..3aa64f4 --- /dev/null +++ b/tests/preprocessing/test_datalist_scaler.py @@ -0,0 +1,215 @@ +# import pytest +# import xarray as xr +# import numpy as np + +# from xeofs.preprocessing.scaler import DataListScaler +# from xeofs.utils.data_types import DimsList + + +# @pytest.mark.parametrize( +# "with_std, with_coslat", +# [ +# (True, True), +# (True, False), +# (False, True), +# (False, False), +# ], +# ) +# def test_fit_params(with_std, with_coslat, mock_data_array_list): +# listscalers = DataListScaler(with_std=with_std, with_coslat=with_coslat) +# data = mock_data_array_list.copy() +# sample_dims = ["time"] +# feature_dims: DimsList = [["lat", "lon"]] * 3 +# size_lats_list = [da.lat.size for da in data] +# weights = [ +# xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list +# ] +# listscalers.fit(mock_data_array_list, sample_dims, feature_dims, weights) + +# for s in listscalers.scalers: +# assert hasattr(s, "mean_"), "Scaler has no mean attribute." +# if with_std: +# assert hasattr(s, "std_"), "Scaler has no std attribute." +# if with_coslat: +# assert hasattr( +# s, "coslat_weights_" +# ), "Scaler has no coslat_weights attribute." + +# assert s.mean_ is not None, "Scaler mean is None." +# if with_std: +# assert s.std_ is not None, "Scaler std is None." +# if with_coslat: +# assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." 
+ + +# @pytest.mark.parametrize( +# "with_std, with_coslat, with_weights", +# [ +# (True, True, True), +# (True, False, True), +# (False, True, True), +# (False, False, True), +# (True, True, False), +# (True, False, False), +# (False, True, False), +# (False, False, False), +# ], +# ) +# def test_transform_params(with_std, with_coslat, with_weights, mock_data_array_list): +# listscalers = DataListScaler(with_std=with_std, with_coslat=with_coslat) +# data = mock_data_array_list.copy() +# sample_dims = ["time"] +# feature_dims: DimsList = [("lat", "lon")] * 3 +# size_lats_list = [da.lat.size for da in data] +# if with_weights: +# weights = [ +# xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list +# ] +# else: +# weights = None +# listscalers.fit( +# mock_data_array_list, +# sample_dims, +# feature_dims, +# weights, +# ) + +# transformed = listscalers.transform(mock_data_array_list) +# transformed2 = listscalers.fit_transform( +# mock_data_array_list, sample_dims, feature_dims, weights +# ) + +# for t, t2, s, ref in zip(transformed, transformed2, listscalers.scalers, data): +# assert t is not None, "Transformed data is None." + +# t_mean = t.mean(sample_dims, skipna=False) +# assert np.allclose(t_mean, 0), "Mean of the transformed data is not zero." + +# if with_std: +# t_std = t.std(sample_dims, skipna=False) +# if with_coslat or with_weights: +# assert ( +# t_std <= 1 +# ).all(), "Standard deviation of the transformed data is larger one." +# else: +# assert np.allclose( +# t_std, 1 +# ), "Standard deviation of the transformed data is not one." + +# if with_coslat: +# assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." +# assert not np.array_equal( +# t, mock_data_array_list +# ), "Data has not been transformed." + +# xr.testing.assert_allclose(t, t2) + + +# @pytest.mark.parametrize( +# "with_std, with_coslat", +# [ +# (True, True), +# (True, False), +# (False, True), +# (False, False), +# ], +# ) +# def test_inverse_transform_params(with_std, with_coslat, mock_data_array_list): +# listscalers = DataListScaler( +# with_std=with_std, +# with_coslat=with_coslat, +# ) +# data = mock_data_array_list.copy() +# sample_dims = ["time"] +# feature_dims: DimsList = [["lat", "lon"]] * 3 +# size_lats_list = [da.lat.size for da in data] +# weights = [ +# xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list +# ] +# listscalers.fit(mock_data_array_list, sample_dims, feature_dims, weights) +# transformed = listscalers.transform(mock_data_array_list) +# inverted = listscalers.inverse_transform_data(transformed) + +# # check that inverse transform is the same as the original data +# for inv, ref in zip(inverted, mock_data_array_list): +# xr.testing.assert_allclose(inv, ref) + + +# @pytest.mark.parametrize( +# "dim_sample, dim_feature", +# [ +# (("time",), ("lat", "lon")), +# (("time",), ("lon", "lat")), +# (("lat", "lon"), ("time",)), +# (("lon", "lat"), ("time",)), +# ], +# ) +# def test_fit_dims(dim_sample, dim_feature, mock_data_array_list): +# listscalers = DataListScaler(with_std=True) +# data = mock_data_array_list.copy() +# dim_feature = [dim_feature] * 3 + +# for s in listscalers.scalers: +# assert hasattr(s, "mean"), "Scaler has no mean attribute." +# assert s.mean is not None, "Scaler mean is None." +# assert hasattr(s, "std"), "Scaler has no std attribute." +# assert s.std is not None, "Scaler std is None." 
+# # check that all dimensions are present except the sample dimensions +# assert set(s.mean.dims) == set(mock_data_array_list.dims) - set( +# dim_sample +# ), "Mean has wrong dimensions." +# assert set(s.std.dims) == set(mock_data_array_list.dims) - set( +# dim_sample +# ), "Standard deviation has wrong dimensions." + + +# @pytest.mark.parametrize( +# "dim_sample, dim_feature", +# [ +# (("time",), ("lat", "lon")), +# (("time",), ("lon", "lat")), +# (("lat", "lon"), ("time",)), +# (("lon", "lat"), ("time",)), +# ], +# ) +# def test_fit_transform_dims(dim_sample, dim_feature, mock_data_array_list): +# listscalers = DataListScaler(with_std=True) +# data = mock_data_array_list.copy() +# dim_feature = [dim_feature] * 3 +# transformed = listscalers.fit_transform( +# mock_data_array_list, dim_sample, dim_feature +# ) + +# for trns, ref in zip(transformed, mock_data_array_list): +# # check that all dimensions are present +# assert set(trns.dims) == set(ref.dims), "Transformed data has wrong dimensions." +# # check that the coordinates are the same +# for dim in ref.dims: +# xr.testing.assert_allclose(trns[dim], ref[dim]) + + +# # Test input types +# @pytest.mark.parametrize( +# "dim_sample, dim_feature", +# [ +# (("time",), ("lat", "lon")), +# (("time",), ("lon", "lat")), +# (("lat", "lon"), ("time",)), +# (("lon", "lat"), ("time",)), +# ], +# ) +# def test_fit_input_type( +# dim_sample, dim_feature, mock_data_array, mock_dataset, mock_data_array_list +# ): +# s = DataListScaler() +# dim_feature = [dim_feature] * 3 +# with pytest.raises(TypeError): +# s.fit(mock_dataset, dim_sample, dim_feature) +# with pytest.raises(TypeError): +# s.fit(mock_data_array, dim_sample, dim_feature) + +# s.fit(mock_data_array_list, dim_sample, dim_feature) +# with pytest.raises(TypeError): +# s.transform(mock_dataset) +# with pytest.raises(TypeError): +# s.transform(mock_data_array) diff --git a/tests/preprocessing/test_datalist_stacker.py b/tests/preprocessing/test_datalist_stacker.py new file mode 100644 index 0000000..8b0b5be --- /dev/null +++ b/tests/preprocessing/test_datalist_stacker.py @@ -0,0 +1,236 @@ +# import pytest +# import numpy as np +# import xarray as xr + +# from xeofs.preprocessing.stacker import DataListStacker +# from xeofs.utils.data_types import DataArray, DataList +# from ..conftest import generate_list_of_synthetic_dataarrays +# from ..utilities import ( +# get_dims_from_data_list, +# data_is_dask, +# assert_expected_dims, +# assert_expected_coords, +# ) + +# # ============================================================================= +# # GENERALLY VALID TEST CASES +# # ============================================================================= +# N_ARRAYS = [1, 2] +# N_SAMPLE_DIMS = [1, 2] +# N_FEATURE_DIMS = [1, 2] +# INDEX_POLICY = ["index"] +# NAN_POLICY = ["no_nan"] +# DASK_POLICY = ["no_dask", "dask"] +# SEED = [0] + +# VALID_TEST_DATA = [ +# (na, ns, nf, index, nan, dask) +# for na in N_ARRAYS +# for ns in N_SAMPLE_DIMS +# for nf in N_FEATURE_DIMS +# for index in INDEX_POLICY +# for nan in NAN_POLICY +# for dask in DASK_POLICY +# ] + + +# # TESTS +# # ============================================================================= +# @pytest.mark.parametrize( +# "sample_name, feature_name, data_params", +# [ +# ("sample", "feature", (2, 1, 1)), +# ("sample0", "feature0", (2, 1, 1)), +# ("sample0", "feature", (2, 1, 2)), +# ("sample", "feature0", (2, 2, 1)), +# ("sample", "feature", (2, 2, 2)), +# ("another_sample", "another_feature", (2, 1, 1)), +# ("another_sample", 
"another_feature", (2, 2, 2)), +# ], +# ) +# def test_fit_valid_dimension_names(sample_name, feature_name, data_params): +# data_list = generate_list_of_synthetic_dataarrays(*data_params) +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker(sample_name=sample_name, feature_name=feature_name) +# stacker.fit(data_list, sample_dims[0], feature_dims) +# stacked_data = stacker.transform(data_list) +# reconstructed_data_list = stacker.inverse_transform_data(stacked_data) + +# assert stacked_data.ndim == 2 +# assert set(stacked_data.dims) == set((sample_name, feature_name)) +# for reconstructed_data, data in zip(reconstructed_data_list, data_list): +# assert set(reconstructed_data.dims) == set(data.dims) + + +# @pytest.mark.parametrize( +# "sample_name, feature_name, data_params", +# [ +# ("sample1", "feature", (2, 2, 1)), +# ("sample", "feature1", (2, 1, 2)), +# ("sample1", "feature1", (2, 3, 3)), +# ], +# ) +# def test_fit_invalid_dimension_names(sample_name, feature_name, data_params): +# data_list = generate_list_of_synthetic_dataarrays(*data_params) +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker(sample_name=sample_name, feature_name=feature_name) + +# with pytest.raises(ValueError): +# stacker.fit(data_list, sample_dims[0], feature_dims) + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_fit(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_transform(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) +# transformed_data = stacker.transform(data_list) +# transformed_data2 = stacker.transform(data_list) + +# is_dask_before = data_is_dask(data_list) +# is_dask_after = data_is_dask(transformed_data) + +# assert isinstance(transformed_data, DataArray) +# assert transformed_data.ndim == 2 +# assert transformed_data.dims == ("sample", "feature") +# assert is_dask_before == is_dask_after +# assert transformed_data.identical(transformed_data2) + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_transform_invalid(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) + +# data_list = [da.isel(feature0=slice(0, 2)) for da in data_list] +# with pytest.raises(ValueError): +# stacker.transform(data_list) + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_fit_transform(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# transformed_data = stacker.fit_transform(data_list, sample_dims[0], feature_dims) + +# is_dask_before = data_is_dask(data_list) +# is_dask_after = data_is_dask(transformed_data) + +# assert 
isinstance(transformed_data, DataArray) +# assert transformed_data.ndim == 2 +# assert transformed_data.dims == ("sample", "feature") +# assert is_dask_before == is_dask_after + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_invserse_transform_data(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) +# stacked_data = stacker.transform(data_list) +# unstacked_data = stacker.inverse_transform_data(stacked_data) + +# is_dask_before = data_is_dask(data_list) +# is_dask_after = data_is_dask(unstacked_data) + +# # Unstacked data has dimensions of original data +# assert_expected_dims(data_list, unstacked_data, policy="all") +# # Unstacked data has coordinates of original data +# assert_expected_coords(data_list, unstacked_data, policy="all") +# # inverse transform should not change dask-ness +# assert is_dask_before == is_dask_after + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_invserse_transform_components(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) + +# stacked_data = stacker.transform(data_list) +# components = stacked_data.rename({"sample": "mode"}) +# components.coords.update({"mode": range(components.mode.size)}) +# unstacked_data = stacker.inverse_transform_components(components) + +# is_dask_before = data_is_dask(data_list) +# is_dask_after = data_is_dask(unstacked_data) + +# # Unstacked components has correct feature dimensions +# assert_expected_dims(data_list, unstacked_data, policy="feature") +# # Unstacked data has feature coordinates of original data +# assert_expected_coords(data_list, unstacked_data, policy="feature") +# # inverse transform should not change dask-ness +# assert is_dask_before == is_dask_after + + +# @pytest.mark.parametrize( +# "synthetic_datalist", +# VALID_TEST_DATA, +# indirect=["synthetic_datalist"], +# ) +# def test_invserse_transform_scores(synthetic_datalist): +# data_list = synthetic_datalist +# all_dims, sample_dims, feature_dims = get_dims_from_data_list(data_list) + +# stacker = DataListStacker() +# stacker.fit(data_list, sample_dims[0], feature_dims) + +# stacked_data = stacker.transform(data_list) +# scores = stacked_data.rename({"feature": "mode"}) +# unstacked_data = stacker.inverse_transform_scores(scores) + +# is_dask_before = data_is_dask(data_list) +# is_dask_after = data_is_dask(unstacked_data) + +# # Unstacked scores has correct feature dimensions +# assert_expected_dims(data_list[0], unstacked_data, policy="sample") +# # Unstacked data has coordinates of original data +# assert_expected_coords(data_list[0], unstacked_data, policy="sample") +# # inverse transform should not change dask-ness +# assert is_dask_before == is_dask_after diff --git a/tests/preprocessing/test_dataset_multiindex_converter.py b/tests/preprocessing/test_dataset_multiindex_converter.py new file mode 100644 index 0000000..b93e75b --- /dev/null +++ b/tests/preprocessing/test_dataset_multiindex_converter.py @@ -0,0 +1,88 @@ +import pytest +import pandas as pd + +from xeofs.preprocessing.multi_index_converter import ( + MultiIndexConverter, +) +from ..conftest import 
generate_synthetic_dataset +from xeofs.utils.data_types import DataArray +from ..utilities import assert_expected_dims, data_is_dask, data_has_multiindex + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_VARIABLES = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index"] +NAN_POLICY = ["no_nan"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (nv, ns, nf, index, nan, dask) + for nv in N_VARIABLES + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_transform(synthetic_dataset): + converter = MultiIndexConverter() + converter.fit(synthetic_dataset) + transformed_data = converter.transform(synthetic_dataset) + + is_dask_before = data_is_dask(synthetic_dataset) + is_dask_after = data_is_dask(transformed_data) + + # Transforming does not affect dimensions + assert_expected_dims(transformed_data, synthetic_dataset, policy="all") + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + # Transforming removes MultiIndex + assert data_has_multiindex(transformed_data) is False + + # Result is robust to calling the method multiple times + transformed_data = converter.transform(synthetic_dataset) + assert data_has_multiindex(transformed_data) is False + + # Transforming data twice won't change the data + transformed_data2 = converter.transform(transformed_data) + assert data_has_multiindex(transformed_data2) is False + assert transformed_data.identical(transformed_data2) + + +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_inverse_transform(synthetic_dataset): + converter = MultiIndexConverter() + converter.fit(synthetic_dataset) + transformed_data = converter.transform(synthetic_dataset) + inverse_transformed_data = converter.inverse_transform_data(transformed_data) + + is_dask_before = data_is_dask(synthetic_dataset) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + has_multiindex_before = data_has_multiindex(synthetic_dataset) + has_multiindex_after = data_has_multiindex(inverse_transformed_data) + + assert inverse_transformed_data.identical(synthetic_dataset) + assert has_multiindex_before == has_multiindex_after diff --git a/tests/preprocessing/test_dataset_renamer.py b/tests/preprocessing/test_dataset_renamer.py new file mode 100644 index 0000000..7ba110b --- /dev/null +++ b/tests/preprocessing/test_dataset_renamer.py @@ -0,0 +1,90 @@ +import pytest + +from xeofs.preprocessing.dimension_renamer import DimensionRenamer +from ..utilities import ( + data_is_dask, + get_dims_from_data, +) + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_VARIABLES = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (nv, ns, 
nf, index, nan, dask) + for nv in N_VARIABLES + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_transform(synthetic_dataset): + all_dims, sample_dims, feature_dims = get_dims_from_data(synthetic_dataset) + + n_dims = len(all_dims) + + base = "new" + start = 10 + expected_dims = set(base + str(i) for i in range(start, start + n_dims)) + + renamer = DimensionRenamer(base=base, start=start) + renamer.fit(synthetic_dataset, sample_dims, feature_dims) + transformed_data = renamer.transform(synthetic_dataset) + + is_dask_before = data_is_dask(synthetic_dataset) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + # Transforming converts dimension names + given_dims = set(transformed_data.dims) + assert given_dims == expected_dims + + # Result is robust to calling the method multiple times + transformed_data = renamer.transform(synthetic_dataset) + given_dims = set(transformed_data.dims) + assert given_dims == expected_dims + + +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_inverse_transform_data(synthetic_dataset): + all_dims, sample_dims, feature_dims = get_dims_from_data(synthetic_dataset) + + base = "new" + start = 10 + + renamer = DimensionRenamer(base=base, start=start) + renamer.fit(synthetic_dataset, sample_dims, feature_dims) + transformed_data = renamer.transform(synthetic_dataset) + inverse_transformed_data = renamer.inverse_transform_data(transformed_data) + + is_dask_before = data_is_dask(synthetic_dataset) + is_dask_after = data_is_dask(transformed_data) + + # Transforming doesn't change the dask-ness of the data + assert is_dask_before == is_dask_after + + assert inverse_transformed_data.identical(synthetic_dataset) + assert set(inverse_transformed_data.dims) == set(synthetic_dataset.dims) diff --git a/tests/preprocessing/test_single_dataset_scaler.py b/tests/preprocessing/test_dataset_scaler.py similarity index 52% rename from tests/preprocessing/test_single_dataset_scaler.py rename to tests/preprocessing/test_dataset_scaler.py index ce5ff62..c992d68 100644 --- a/tests/preprocessing/test_single_dataset_scaler.py +++ b/tests/preprocessing/test_dataset_scaler.py @@ -2,65 +2,35 @@ import xarray as xr import numpy as np -from xeofs.preprocessing.scaler import SingleDatasetScaler +from xeofs.preprocessing.scaler import Scaler @pytest.mark.parametrize( - "with_std, with_coslat, with_weights", + "with_std, with_coslat", [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), + (True, True), + (True, False), + (False, True), + (False, False), ], ) -def test_init_params(with_std, with_coslat, with_weights): - s = SingleDatasetScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - assert hasattr(s, "_params") - assert s._params["with_std"] == with_std - assert s._params["with_coslat"] == 
with_coslat - assert s._params["with_weights"] == with_weights +def test_init_params(with_std, with_coslat): + s = Scaler(with_std=with_std, with_coslat=with_coslat) + assert s.get_params()["with_std"] == with_std + assert s.get_params()["with_coslat"] == with_coslat @pytest.mark.parametrize( - "with_std, with_coslat, with_weights", + "with_std, with_coslat", [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), + (True, True), + (True, False), + (False, True), + (False, False), ], ) -def test_fit_params(with_std, with_coslat, with_weights, mock_dataset): - s = SingleDatasetScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) +def test_fit_params(with_std, with_coslat, mock_dataset): + s = Scaler(with_std=with_std, with_coslat=with_coslat) sample_dims = ["time"] feature_dims = ["lat", "lon"] size_lats = mock_dataset.lat.size @@ -68,53 +38,42 @@ def test_fit_params(with_std, with_coslat, with_weights, mock_dataset): np.random.rand(size_lats), dims=["lat"], name="weights" ).to_dataset() s.fit(mock_dataset, sample_dims, feature_dims, weights) - assert hasattr(s, "mean"), "Scaler has no mean attribute." + assert hasattr(s, "mean_"), "Scaler has no mean attribute." if with_std: - assert hasattr(s, "std"), "Scaler has no std attribute." + assert hasattr(s, "std_"), "Scaler has no std attribute." if with_coslat: - assert hasattr(s, "coslat_weights"), "Scaler has no coslat_weights attribute." - if with_weights: - assert hasattr(s, "weights"), "Scaler has no weights attribute." - assert s.mean is not None, "Scaler mean is None." + assert hasattr(s, "coslat_weights_"), "Scaler has no coslat_weights attribute." + assert s.mean_ is not None, "Scaler mean is None." if with_std: - assert s.std is not None, "Scaler std is None." + assert s.std_ is not None, "Scaler std is None." if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - if with_weights: - assert s.weights is not None, "Scaler weights is None." + assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." 
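The hunks above document the refactored scaler API that these tests exercise: a single `Scaler` class configured with `with_std`/`with_coslat`, constructor parameters exposed via `get_params()`, and fitted statistics stored with a trailing underscore. Below is a minimal sketch of that workflow, assuming a toy single-variable dataset (the `temp` variable, coordinate values and sizes are illustrative and not part of the patch):

import numpy as np
import xarray as xr

from xeofs.preprocessing.scaler import Scaler

# Illustrative toy data, not taken from the test suite
rng = np.random.default_rng(42)
temp = xr.DataArray(
    rng.normal(size=(5, 4, 3)),
    dims=("time", "lat", "lon"),
    coords={"time": np.arange(5), "lat": [10.0, 20.0, 30.0, 40.0], "lon": [0.0, 5.0, 10.0]},
)
ds = xr.Dataset({"temp": temp})

scaler = Scaler(with_std=True, with_coslat=True)
print(scaler.get_params())  # contains the constructor flags, e.g. 'with_std', 'with_coslat'

# fit takes the data, the sample dims and the feature dims; weights are optional
scaler.fit(ds, ["time"], ["lat", "lon"])
scaled = scaler.transform(ds)                     # centred, standardised, area-weighted data
restored = scaler.inverse_transform_data(scaled)  # round trip back to the original values

The positional `fit(data, sample_dims, feature_dims)` call and the `mean_`/`std_`/`coslat_weights_` attribute names mirror what the updated tests in this file assert.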
@pytest.mark.parametrize( "with_std, with_coslat, with_weights", [ (True, True, True), - (True, True, False), (True, False, True), - (True, False, False), (False, True, True), - (False, True, False), (False, False, True), - (False, False, False), - (True, True, True), (True, True, False), - (True, False, True), (True, False, False), - (False, True, True), (False, True, False), - (False, False, True), (False, False, False), ], ) def test_transform_params(with_std, with_coslat, with_weights, mock_dataset): - s = SingleDatasetScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) + s = Scaler(with_std=with_std, with_coslat=with_coslat) sample_dims = ["time"] feature_dims = ["lat", "lon"] size_lats = mock_dataset.lat.size - weights1 = xr.DataArray(np.random.rand(size_lats), dims=["lat"], name="t2m") - weights2 = xr.DataArray(np.random.rand(size_lats), dims=["lat"], name="prcp") - weights = xr.merge([weights1, weights2]) + if with_weights: + weights1 = xr.DataArray(np.random.rand(size_lats), dims=["lat"], name="t2m") + weights2 = xr.DataArray(np.random.rand(size_lats), dims=["lat"], name="prcp") + weights = xr.merge([weights1, weights2]) + else: + weights = None s.fit(mock_dataset, sample_dims, feature_dims, weights) transformed = s.transform(mock_dataset) assert transformed is not None, "Transformed data is None." @@ -136,13 +95,7 @@ def test_transform_params(with_std, with_coslat, with_weights, mock_dataset): ), "Standard deviation of the transformed data is not one." if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - assert not np.array_equal( - transformed, mock_dataset - ), "Data has not been transformed." - - if with_weights: - assert s.weights is not None, "Scaler weights is None." + assert s.coslat_weights_ is not None, "Scaler coslat_weights is None." assert not np.array_equal( transformed, mock_dataset ), "Data has not been transformed." 
@@ -152,30 +105,16 @@ def test_transform_params(with_std, with_coslat, with_weights, mock_dataset): @pytest.mark.parametrize( - "with_std, with_coslat, with_weights", + "with_std, with_coslat", [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), + (True, True), + (True, False), + (False, True), + (False, False), ], ) -def test_inverse_transform_params(with_std, with_coslat, with_weights, mock_dataset): - s = SingleDatasetScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) +def test_inverse_transform_params(with_std, with_coslat, mock_dataset): + s = Scaler(with_std=with_std, with_coslat=with_coslat) sample_dims = ["time"] feature_dims = ["lat", "lon"] size_lats = mock_dataset.lat.size @@ -184,7 +123,7 @@ def test_inverse_transform_params(with_std, with_coslat, with_weights, mock_data weights = xr.merge([weights1, weights2]) s.fit(mock_dataset, sample_dims, feature_dims, weights) transformed = s.transform(mock_dataset) - inverted = s.inverse_transform(transformed) + inverted = s.inverse_transform_data(transformed) xr.testing.assert_allclose(inverted, mock_dataset) @@ -198,17 +137,17 @@ def test_inverse_transform_params(with_std, with_coslat, with_weights, mock_data ], ) def test_fit_dims(dim_sample, dim_feature, mock_dataset): - s = SingleDatasetScaler() + s = Scaler(with_std=True) s.fit(mock_dataset, dim_sample, dim_feature) - assert hasattr(s, "mean"), "Scaler has no mean attribute." - assert s.mean is not None, "Scaler mean is None." - assert hasattr(s, "std"), "Scaler has no std attribute." - assert s.std is not None, "Scaler std is None." + assert hasattr(s, "mean_"), "Scaler has no mean attribute." + assert s.mean_ is not None, "Scaler mean is None." + assert hasattr(s, "std_"), "Scaler has no std attribute." + assert s.std_ is not None, "Scaler std is None." # check that all dimensions are present except the sample dimensions - assert set(s.mean.dims) == set(mock_dataset.dims) - set( + assert set(s.mean_.dims) == set(mock_dataset.dims) - set( dim_sample ), "Mean has wrong dimensions." - assert set(s.std.dims) == set(mock_dataset.dims) - set( + assert set(s.std_.dims) == set(mock_dataset.dims) - set( dim_sample ), "Standard deviation has wrong dimensions." 
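The `test_inverse_transform_params` and `test_fit_dims` hunks above pin down two further properties: fitted statistics keep every dimension except the sample dimensions, and `inverse_transform_data` undoes the scaling. A short, self-contained sketch of both checks (the two-variable toy dataset below is an assumption made purely for illustration):

import numpy as np
import xarray as xr

from xeofs.preprocessing.scaler import Scaler

# Illustrative two-variable dataset; variable names, sizes and coordinates are made up
rng = np.random.default_rng(0)
coords = {"time": np.arange(6), "lat": [10.0, 20.0, 30.0], "lon": [0.0, 5.0]}
t2m = xr.DataArray(rng.normal(size=(6, 3, 2)), dims=("time", "lat", "lon"), coords=coords)
prcp = xr.DataArray(rng.normal(size=(6, 3, 2)), dims=("time", "lat", "lon"), coords=coords)
ds = xr.Dataset({"t2m": t2m, "prcp": prcp})

scaler = Scaler(with_std=True)
scaled = scaler.fit_transform(ds, ["time"], ["lat", "lon"])

# Fitted statistics follow the scikit-learn trailing-underscore convention
# and retain only the feature dimensions
assert set(scaler.mean_.dims) == set(ds.dims) - {"time"}
assert set(scaler.std_.dims) == set(ds.dims) - {"time"}

# Inverse transform recovers the original data
xr.testing.assert_allclose(scaler.inverse_transform_data(scaled), ds)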
@@ -223,7 +162,7 @@ def test_fit_dims(dim_sample, dim_feature, mock_dataset): ], ) def test_fit_transform_dims(dim_sample, dim_feature, mock_dataset): - s = SingleDatasetScaler() + s = Scaler() transformed = s.fit_transform(mock_dataset, dim_sample, dim_feature) # check that all dimensions are present assert set(transformed.dims) == set( @@ -236,25 +175,20 @@ def test_fit_transform_dims(dim_sample, dim_feature, mock_dataset): # Test input types def test_fit_input_type(mock_dataset, mock_data_array, mock_data_array_list): - s = SingleDatasetScaler() - # Cannot fit DataArray - with pytest.raises(TypeError): - s.fit(mock_data_array, ["time"], ["lon", "lat"]) + s = Scaler() # Cannot fit list of DataArrays with pytest.raises(TypeError): s.fit(mock_data_array_list, ["time"], ["lon", "lat"]) s.fit(mock_dataset, ["time"], ["lon", "lat"]) - # Cannot transform DataArray - with pytest.raises(TypeError): - s.transform(mock_data_array) + # Cannot transform list of DataArrays with pytest.raises(TypeError): s.transform(mock_data_array_list) # def test_fit_weights_input_type(mock_dataset): -# s = SingleDatasetScaler() +# s = Scaler() # # Fitting with weights requires that the weights have the same variables as the dataset # # used for fitting; otherwise raise an error # size_lats = mock_dataset.lat.size diff --git a/tests/preprocessing/test_dataset_stacker.py b/tests/preprocessing/test_dataset_stacker.py index cfc53f9..e35d0e1 100644 --- a/tests/preprocessing/test_dataset_stacker.py +++ b/tests/preprocessing/test_dataset_stacker.py @@ -2,177 +2,240 @@ import xarray as xr import numpy as np -from xeofs.preprocessing.stacker import SingleDatasetStacker - +from xeofs.preprocessing.stacker import DataSetStacker +from xeofs.utils.data_types import DataSet, DataArray +from ..conftest import generate_synthetic_dataset +from ..utilities import ( + get_dims_from_data, + data_is_dask, + assert_expected_dims, + assert_expected_coords, +) +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_VARIABLES = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index"] +NAN_POLICY = ["no_nan"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +VALID_TEST_DATA = [ + (nv, ns, nf, index, nan, dask) + for nv in N_VARIABLES + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + + +# TESTS +# ============================================================================= @pytest.mark.parametrize( - "dim_sample, dim_feature", + "sample_name, feature_name, data_params", [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), + ("sample", "feature", (1, 1, 1)), + ("sample", "feature", (2, 1, 1)), + ("sample", "feature", (1, 2, 2)), + ("sample", "feature", (2, 2, 2)), + ("sample0", "feature", (1, 1, 1)), + ("sample0", "feature", (1, 1, 2)), + ("sample0", "feature", (2, 1, 2)), + ("another_sample", "another_feature", (1, 1, 1)), + ("another_sample", "another_feature", (1, 2, 2)), + ("another_sample", "another_feature", (2, 1, 1)), + ("another_sample", "another_feature", (2, 2, 2)), ], ) -def test_fit_transform(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - stacked = stacker.fit_transform(mock_dataset, dim_sample, dim_feature) +def test_fit_valid_dimension_names(sample_name, feature_name, 
data_params): + data = generate_synthetic_dataset(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - # check output type and dimensions - assert isinstance(stacked, xr.DataArray) - assert set(stacked.dims) == {"sample", "feature"} + stacker = DataSetStacker(sample_name=sample_name, feature_name=feature_name) + stacker.fit(data, sample_dims, feature_dims) + stacked_data = stacker.transform(data) + reconstructed_data = stacker.inverse_transform_data(stacked_data) - # check if all NaN rows or columns have been dropped - assert not stacked.isnull().any() - - # check the size of the output data - assert stacked.size > 0 + assert stacked_data.ndim == 2 + assert set(stacked_data.dims) == set((sample_name, feature_name)) + assert set(reconstructed_data.dims) == set(data.dims) @pytest.mark.parametrize( - "dim_sample, dim_feature", + "sample_name, feature_name, data_params", [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), + ("sample", "feature0", (1, 1, 1)), + ("sample0", "feature", (1, 2, 1)), + ("sample1", "feature1", (1, 3, 3)), + ("sample", "feature0", (2, 1, 1)), + ("sample0", "feature", (2, 2, 1)), + ("sample1", "feature1", (2, 3, 3)), ], ) -def test_transform(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - stacker.fit_transform(mock_dataset, dim_sample, dim_feature) +def test_fit_invalid_dimension_names(sample_name, feature_name, data_params): + data = generate_synthetic_dataset(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - # create a new dataset for testing the transform function - new_data = mock_dataset.copy() - transformed = stacker.transform(new_data) + stacker = DataSetStacker(sample_name=sample_name, feature_name=feature_name) - assert isinstance(transformed, xr.DataArray) - assert set(transformed.dims) == {"sample", "feature"} - assert not transformed.isnull().any() - assert transformed.size > 0 + with pytest.raises(ValueError): + stacker.fit(data, sample_dims, feature_dims) @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], ) -def test_inverse_transform_data(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - stacked = stacker.fit_transform(mock_dataset, dim_sample, dim_feature) +def test_fit(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - inverse_transformed = stacker.inverse_transform_data(stacked) - assert isinstance(inverse_transformed, xr.Dataset) + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) - for var in inverse_transformed.data_vars: - xr.testing.assert_equal(inverse_transformed[var], mock_dataset[var]) + +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_transform(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) + transformed_data = stacker.transform(data) + transformed_data2 = stacker.transform(data) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) + + assert isinstance(transformed_data, DataArray) + assert transformed_data.ndim == 2 + 
assert transformed_data.dims == ("sample", "feature") + assert is_dask_before == is_dask_after + assert transformed_data.identical(transformed_data2) @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], +) +def test_transform_invalid(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) + with pytest.raises(ValueError): + stacker.transform(data.isel(feature0=slice(0, 2))) + + +@pytest.mark.parametrize( + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], ) -def test_inverse_transform_components(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - stacked = stacker.fit_transform(mock_dataset, dim_sample, dim_feature) +def test_fit_transform(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - # dummy components - components = xr.DataArray( - np.random.normal(size=(len(stacker.coords_out_["feature"]), 10)), - dims=("feature", "mode"), - coords={"feature": stacker.coords_out_["feature"]}, - ) - inverse_transformed = stacker.inverse_transform_components(components) + stacker = DataSetStacker() + transformed_data = stacker.fit_transform(data, sample_dims, feature_dims) - # check output type and dimensions - assert isinstance(inverse_transformed, xr.Dataset) - assert set(inverse_transformed.dims) == set(dim_feature + ("mode",)) + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(transformed_data) - assert set(mock_dataset.data_vars) == set( - inverse_transformed.data_vars - ), "Dataset variables are not the same." 
+ assert isinstance(transformed_data, DataArray) + assert transformed_data.ndim == 2 + assert transformed_data.dims == ("sample", "feature") + assert is_dask_before == is_dask_after @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], ) -def test_inverse_transform_scores(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - stacked = stacker.fit_transform(mock_dataset, dim_sample, dim_feature) +def test_invserse_transform_data(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) - # dummy scores - scores = xr.DataArray( - np.random.rand(len(stacker.coords_out_["sample"]), 10), - dims=("sample", "mode"), - coords={"sample": stacker.coords_out_["sample"]}, - ) - inverse_transformed = stacker.inverse_transform_scores(scores) + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) + stacked_data = stacker.transform(data) + unstacked_data = stacker.inverse_transform_data(stacked_data) - assert isinstance(inverse_transformed, xr.DataArray) - assert set(inverse_transformed.dims) == set(dim_sample + ("mode",)) + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) - # check that sample coordinates are preserved - for dim, coords in mock_dataset.coords.items(): - if dim in dim_sample: - assert ( - inverse_transformed.coords[dim].size == coords.size - ), "Dimension {} has different size.".format(dim) + # Unstacked data has dimensions of original data + assert_expected_dims(data, unstacked_data, policy="all") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="all") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], ) -def test_fit_transform_raises_on_invalid_dims(mock_dataset, dim_sample, dim_feature): - stacker = SingleDatasetStacker() - with pytest.raises(ValueError): - stacker.fit_transform(mock_dataset, ("invalid_dim",), dim_feature) +def test_invserse_transform_components(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) -def test_fit_transform_raises_on_isolated_nans( - mock_data_array_isolated_nans, -): - stacker = SingleDatasetStacker() - invalid_dataset = xr.Dataset({"var": mock_data_array_isolated_nans}) - with pytest.raises(ValueError): - stacker.fit_transform(invalid_dataset, ("time",), ("lat", "lon")) + stacked_data = stacker.transform(data) + components = stacked_data.rename({"sample": "mode"}) + components.coords.update({"mode": range(components.mode.size)}) + unstacked_data = stacker.inverse_transform_components(components) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked components has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="feature") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="feature") + # inverse 
transform should not change dask-ness + assert is_dask_before == is_dask_after @pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], + "synthetic_dataset", + VALID_TEST_DATA, + indirect=["synthetic_dataset"], ) -def test_fit_transform_passes_on_full_dimensional_nans( - mock_data_array_full_dimensional_nans, dim_sample, dim_feature -): - stacker = SingleDatasetStacker() - valid_dataset = xr.Dataset({"var": mock_data_array_full_dimensional_nans}) - try: - stacker.fit_transform(valid_dataset, dim_sample, dim_feature) - except ValueError: - pytest.fail("fit_transform() raised ValueError unexpectedly!") +def test_invserse_transform_scores(synthetic_dataset): + data = synthetic_dataset + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + stacker = DataSetStacker() + stacker.fit(data, sample_dims, feature_dims) + + stacked_data = stacker.transform(data) + scores = stacked_data.rename({"feature": "mode"}) + scores.coords.update({"mode": range(scores.mode.size)}) + unstacked_data = stacker.inverse_transform_scores(scores) + + is_dask_before = data_is_dask(data) + is_dask_after = data_is_dask(unstacked_data) + + # Unstacked scores has correct feature dimensions + assert_expected_dims(data, unstacked_data, policy="sample") + # Unstacked data has coordinates of original data + assert_expected_coords(data, unstacked_data, policy="sample") + # inverse transform should not change dask-ness + assert is_dask_before == is_dask_after diff --git a/tests/preprocessing/test_dataset_stacker_stack.py b/tests/preprocessing/test_dataset_stacker_stack.py deleted file mode 100644 index a267b4e..0000000 --- a/tests/preprocessing/test_dataset_stacker_stack.py +++ /dev/null @@ -1,171 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.preprocessing.stacker import SingleDatasetStacker - - -def create_ds(dim_sample, dim_feature, seed=None): - n_dims = len(dim_sample) + len(dim_feature) - size = n_dims * [3] - rng = np.random.default_rng(seed) - dims = dim_sample + dim_feature - coords = {d: np.arange(i, i + 3) for i, d in enumerate(dims)} - da1 = xr.DataArray(rng.normal(0, 1, size=size), dims=dims, coords=coords) - da2 = da1.sel(**{dim_feature[0]: slice(0, 2)}).squeeze().copy() - ds = xr.Dataset({"da1": da1, "da2": da2}) - return ds - - -# Valid input -# ============================================================================= -valid_input_dims = [ - # This SHOULD work but currently doesn't, potentially due to a bug in xarray (https://github.com/pydata/xarray/discussions/8063) - # (("year", "month"), ("lon", "lat")), - (("year",), ("lat", "lon")), - (("year", "month"), ("lon",)), - (("year",), ("lon",)), -] - -valid_input = [] -for dim_sample, dim_feature in valid_input_dims: - da = create_ds(dim_sample, dim_feature) - valid_input.append((da, dim_sample, dim_feature)) - - -# Invalid input -# ============================================================================= -invalid_input_dims = [ - (("sample",), ("feature", "lat")), - (("sample",), ("month", "feature")), - (("sample", "month"), ("lon", "lat")), - (("sample",), ("lon", "lat")), - (("year",), ("month", "sample")), - (("year",), ("sample",)), - (("sample",), ("lon",)), - (("year", "month"), ("lon", "feature")), - (("year", "month"), ("feature",)), - (("year",), ("feature",)), - (("feature",), ("lon", "lat")), - (("feature",), ("lon",)), - (("feature",), ("sample",)), - (("sample",), 
("feature",)), -] -invalid_input = [] -for dim_sample, dim_feature in invalid_input_dims: - da = create_ds(dim_sample, dim_feature) - invalid_input.append((da, dim_sample, dim_feature)) - - -# Test stacking -# ============================================================================= -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_fit_transform(da, dim_sample, dim_feature): - """Test fit_transform with valid input.""" - stacker = SingleDatasetStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - - # Stacked data has dimensions (sample, feature) - err_msg = f"In: {da.dims}; Out: {da_stacked.dims}" - assert set(da_stacked.dims) == { - "sample", - "feature", - }, err_msg - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", invalid_input) -def test_fit_transform_invalid_input(da, dim_sample, dim_feature): - """Test fit_transform with invalid input.""" - stacker = SingleDatasetStacker() - with pytest.raises(ValueError): - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_data(da, dim_sample, dim_feature): - """Test inverse transform with valid input.""" - stacker = SingleDatasetStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - da_unstacked = stacker.inverse_transform_data(da_stacked) - - # Unstacked data has dimensions of original data - err_msg = f"Original: {da.dims}; Recovered: {da_unstacked.dims}" - assert set(da_unstacked.dims) == set(da.dims), err_msg - # Unstacked data has variables of original data - err_msg = f"Original: {set(da.data_vars)}; Recovered: {set(da_unstacked.data_vars)}" - assert set(da_unstacked.data_vars) == set(da.data_vars), err_msg - # Unstacked data has coordinates of original data - for var in da.data_vars: - # Check if the dimensions are correct - err_msg = f"Original: {da[var].dims}; Recovered: {da_unstacked[var].dims}" - assert set(da_unstacked[var].dims) == set(da[var].dims), err_msg - for coord in da[var].coords: - err_msg = f"Original: {da[var].coords[coord]}; Recovered: {da_unstacked[var].coords[coord]}" - coords_are_equal = da[var].coords[coord] == da_unstacked[var].coords[coord] - assert np.all(coords_are_equal), err_msg - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_components(da, dim_sample, dim_feature): - """Test inverse transform components with valid input.""" - stacker = SingleDatasetStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - # Mock components by dropping sampling dim from data - comps_stacked = da_stacked.drop_vars("sample").rename({"sample": "mode"}) - comps_stacked.coords.update({"mode": range(comps_stacked.mode.size)}) - - comps_unstacked = stacker.inverse_transform_components(comps_stacked) - - # Unstacked components are a Dataset - assert isinstance(comps_unstacked, xr.Dataset) - - # Unstacked data has variables of original data - err_msg = ( - f"Original: {set(da.data_vars)}; Recovered: {set(comps_unstacked.data_vars)}" - ) - assert set(comps_unstacked.data_vars) == set(da.data_vars), err_msg - - # Unstacked components has correct feature dimensions - expected_dims = dim_feature + ("mode",) - err_msg = f"Expected: {expected_dims}; Recovered: {comps_unstacked.dims}" - assert set(comps_unstacked.dims) == set(expected_dims), err_msg - - # Unstacked components has coordinates of original data - for var in da.data_vars: - # Feature dimensions in original 
and recovered comps are same for each variable - expected_dims_in_var = tuple(d for d in da[var].dims if d in dim_feature) - expected_dims_in_var += ("mode",) - err_msg = ( - f"Original: {expected_dims_in_var}; Recovered: {comps_unstacked[var].dims}" - ) - assert set(expected_dims_in_var) == set(comps_unstacked[var].dims), err_msg - # Coordinates in original and recovered comps are same for each variable - for dim in expected_dims_in_var: - if dim != "mode": - err_msg = f"Original: {da[var].coords[dim]}; Recovered: {comps_unstacked[var].coords[dim]}" - coords_are_equal = ( - da[var].coords[dim] == comps_unstacked[var].coords[dim] - ) - assert np.all(coords_are_equal), err_msg - - -@pytest.mark.parametrize("da, dim_sample, dim_feature", valid_input) -def test_inverse_transform_scores(da, dim_sample, dim_feature): - """Test inverse transform scores with valid input.""" - stacker = SingleDatasetStacker() - da_stacked = stacker.fit_transform(da, dim_sample, dim_feature) - # Mock scores by dropping feature dim from data - scores_stacked = da_stacked.drop_vars("feature").rename({"feature": "mode"}) - scores_stacked.coords.update({"mode": range(scores_stacked.mode.size)}) - - scores_unstacked = stacker.inverse_transform_scores(scores_stacked) - - # Unstacked scores are a DataArray - assert isinstance(scores_unstacked, xr.DataArray) - # Unstacked components has correct feature dimensions - expected_dims = dim_sample + ("mode",) - err_msg = f"Expected: {expected_dims}; Recovered: {scores_unstacked.dims}" - assert set(scores_unstacked.dims) == set(expected_dims), err_msg - # Unstacked data has coordinates of original data - for d in dim_sample: - assert np.all(scores_unstacked.coords[d].values == da.coords[d].values) diff --git a/tests/preprocessing/test_list_dataarray_scaler.py b/tests/preprocessing/test_list_dataarray_scaler.py deleted file mode 100644 index d0a466f..0000000 --- a/tests/preprocessing/test_list_dataarray_scaler.py +++ /dev/null @@ -1,281 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.preprocessing.scaler import ListDataArrayScaler - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_init_params(with_std, with_coslat, with_weights): - s = ListDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - assert hasattr(s, "_params") - assert s._params["with_std"] == with_std - assert s._params["with_coslat"] == with_coslat - assert s._params["with_weights"] == with_weights - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_fit_params(with_std, with_coslat, with_weights, mock_data_array_list): - listscalers = ListDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - 
) - data = mock_data_array_list.copy() - sample_dims = ["time"] - feature_dims = [["lat", "lon"]] * 3 - size_lats_list = [da.lat.size for da in data] - weights = [ - xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list - ] - listscalers.fit(mock_data_array_list, sample_dims, feature_dims, weights) - - for s in listscalers.scalers: - assert hasattr(s, "mean"), "Scaler has no mean attribute." - if with_std: - assert hasattr(s, "std"), "Scaler has no std attribute." - if with_coslat: - assert hasattr( - s, "coslat_weights" - ), "Scaler has no coslat_weights attribute." - if with_weights: - assert hasattr(s, "weights"), "Scaler has no weights attribute." - assert s.mean is not None, "Scaler mean is None." - if with_std: - assert s.std is not None, "Scaler std is None." - if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - if with_weights: - assert s.weights is not None, "Scaler weights is None." - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_transform_params(with_std, with_coslat, with_weights, mock_data_array_list): - listscalers = ListDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - data = mock_data_array_list.copy() - sample_dims = ["time"] - feature_dims = [["lat", "lon"]] * 3 - size_lats_list = [da.lat.size for da in data] - weights = [ - xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list - ] - listscalers.fit(mock_data_array_list, sample_dims, feature_dims, weights) - - transformed = listscalers.transform(mock_data_array_list) - transformed2 = listscalers.fit_transform( - mock_data_array_list, sample_dims, feature_dims, weights - ) - - for t, t2, s, ref in zip(transformed, transformed2, listscalers.scalers, data): - assert t is not None, "Transformed data is None." - - t_mean = t.mean(sample_dims, skipna=False) - assert np.allclose(t_mean, 0), "Mean of the transformed data is not zero." - - if with_std: - t_std = t.std(sample_dims, skipna=False) - if with_coslat or with_weights: - assert ( - t_std <= 1 - ).all(), "Standard deviation of the transformed data is larger one." - else: - assert np.allclose( - t_std, 1 - ), "Standard deviation of the transformed data is not one." - - if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - assert not np.array_equal( - t, mock_data_array_list - ), "Data has not been transformed." - - if with_weights: - assert s.weights is not None, "Scaler weights is None." - assert not np.array_equal(t, ref), "Data has not been transformed." 
- - xr.testing.assert_allclose(t, t2) - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_inverse_transform_params( - with_std, with_coslat, with_weights, mock_data_array_list -): - listscalers = ListDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - data = mock_data_array_list.copy() - sample_dims = ["time"] - feature_dims = [["lat", "lon"]] * 3 - size_lats_list = [da.lat.size for da in data] - weights = [ - xr.DataArray(np.random.rand(size), dims=["lat"]) for size in size_lats_list - ] - listscalers.fit(mock_data_array_list, sample_dims, feature_dims, weights) - transformed = listscalers.transform(mock_data_array_list) - inverted = listscalers.inverse_transform(transformed) - - # check that inverse transform is the same as the original data - for inv, ref in zip(inverted, mock_data_array_list): - xr.testing.assert_allclose(inv, ref) - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_dims(dim_sample, dim_feature, mock_data_array_list): - listscalers = ListDataArrayScaler(with_std=True) - data = mock_data_array_list.copy() - dim_feature = [dim_feature] * 3 - - for s in listscalers.scalers: - assert hasattr(s, "mean"), "Scaler has no mean attribute." - assert s.mean is not None, "Scaler mean is None." - assert hasattr(s, "std"), "Scaler has no std attribute." - assert s.std is not None, "Scaler std is None." - # check that all dimensions are present except the sample dimensions - assert set(s.mean.dims) == set(mock_data_array_list.dims) - set( - dim_sample - ), "Mean has wrong dimensions." - assert set(s.std.dims) == set(mock_data_array_list.dims) - set( - dim_sample - ), "Standard deviation has wrong dimensions." - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_transform_dims(dim_sample, dim_feature, mock_data_array_list): - listscalers = ListDataArrayScaler(with_std=True) - data = mock_data_array_list.copy() - dim_feature = [dim_feature] * 3 - transformed = listscalers.fit_transform( - mock_data_array_list, dim_sample, dim_feature - ) - - for trns, ref in zip(transformed, mock_data_array_list): - # check that all dimensions are present - assert set(trns.dims) == set(ref.dims), "Transformed data has wrong dimensions." 
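
The removed ListDataArrayScaler tests above all exercise the same round trip: fit the scaler on a list of DataArrays with per-array feature dimensions and optional latitude weights, transform, check that the result is centred along the sample dimensions, and verify that inverse_transform recovers the original data. A minimal sketch of that flow, written against the pre-refactor ListDataArrayScaler API imported by these (now deleted) tests and assumed to be importable at this revision; the synthetic data below is invented for illustration and is not part of the test suite:

import numpy as np
import xarray as xr

# Pre-refactor API exercised by the removed tests (assumption: still importable here)
from xeofs.preprocessing.scaler import ListDataArrayScaler

rng = np.random.default_rng(0)
# Two small synthetic fields standing in for the mock_data_array_list fixture
data_list = [
    xr.DataArray(
        rng.standard_normal((5, 4, 3)),
        dims=("time", "lat", "lon"),
        coords={"time": range(5), "lat": [10, 20, 30, 40], "lon": [0, 10, 20]},
    )
    for _ in range(2)
]

sample_dims = ["time"]
feature_dims = [["lat", "lon"]] * len(data_list)   # one feature-dim list per array
weights = [xr.DataArray(rng.random(da.lat.size), dims=["lat"]) for da in data_list]

scaler = ListDataArrayScaler(with_std=True, with_weights=True)
transformed = scaler.fit_transform(data_list, sample_dims, feature_dims, weights)

# What the removed tests asserted for every parameter combination:
for t in transformed:
    assert np.allclose(t.mean(sample_dims, skipna=False), 0)    # centred
for original, recovered in zip(data_list, scaler.inverse_transform(transformed)):
    xr.testing.assert_allclose(recovered, original)              # lossless round trip
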
- # check that the coordinates are the same - for dim in ref.dims: - xr.testing.assert_allclose(trns[dim], ref[dim]) - - -# Test input types -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_input_type( - dim_sample, dim_feature, mock_data_array, mock_dataset, mock_data_array_list -): - s = ListDataArrayScaler() - dim_feature = [dim_feature] * 3 - with pytest.raises(TypeError): - s.fit(mock_dataset, dim_sample, dim_feature) - with pytest.raises(TypeError): - s.fit(mock_data_array, dim_sample, dim_feature) - - s.fit(mock_data_array_list, dim_sample, dim_feature) - with pytest.raises(TypeError): - s.transform(mock_dataset) - with pytest.raises(TypeError): - s.transform(mock_data_array) diff --git a/tests/preprocessing/test_preprocessor.py b/tests/preprocessing/test_preprocessor.py deleted file mode 100644 index 6a28167..0000000 --- a/tests/preprocessing/test_preprocessor.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest -import numpy as np -import xarray as xr - -from xeofs.preprocessing.preprocessor import Preprocessor - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_init_params(with_std, with_coslat, with_weights): - prep = Preprocessor( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - - assert hasattr(prep, "_params") - assert prep._params["with_std"] == with_std - assert prep._params["with_coslat"] == with_coslat - assert prep._params["with_weights"] == with_weights - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_fit(with_std, with_coslat, with_weights, mock_data_array): - """fit method should not be implemented.""" - prep = Preprocessor( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - - with pytest.raises(NotImplementedError): - prep.fit(mock_data_array, dim="time") - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_fit_transform(with_std, with_coslat, with_weights, mock_data_array): - """fit method should not be implemented.""" - prep = Preprocessor( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - - weights = None - if with_weights: - weights = mock_data_array.mean("time").copy() - weights[:] = 1.0 - - 
data_trans = prep.fit_transform(mock_data_array, weights=weights, dim="time") - - assert hasattr(prep, "scaler") - assert hasattr(prep, "stacker") - - # Transformed data is centered - assert np.isclose(data_trans.mean("sample"), 0).all() diff --git a/tests/preprocessing/test_preprocessor_dataarray.py b/tests/preprocessing/test_preprocessor_dataarray.py new file mode 100644 index 0000000..2873423 --- /dev/null +++ b/tests/preprocessing/test_preprocessor_dataarray.py @@ -0,0 +1,182 @@ +import pytest +import numpy as np +import xarray as xr + +from xeofs.preprocessing.preprocessor import Preprocessor +from ..conftest import generate_synthetic_dataarray +from ..utilities import ( + get_dims_from_data, + data_is_dask, + data_has_multiindex, + assert_expected_dims, + assert_expected_coords, +) + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +TEST_DATA_PARAMS = [ + (ns, nf, index, nan, dask) + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + +SAMPLE_DIM_NAMES = ["sample", "sample_alternative"] +FEATURE_DIM_NAMES = ["feature", "feature_alternative"] + +VALID_TEST_CASES = [ + (sample_name, feature_name, data_params) + for sample_name in SAMPLE_DIM_NAMES + for feature_name in FEATURE_DIM_NAMES + for data_params in TEST_DATA_PARAMS +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "with_std, with_coslat, with_weights", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + (False, False, False), + ], +) +def test_fit_transform_scalings(with_std, with_coslat, with_weights, mock_data_array): + """fit method should not be implemented.""" + prep = Preprocessor(with_std=with_std, with_coslat=with_coslat) + + weights = None + if with_weights: + weights = mock_data_array.mean("time").copy() + weights[:] = 0.5 + + data_trans = prep.fit_transform( + mock_data_array, + weights=weights, + sample_dims=("time",), + ) + + assert hasattr(prep, "scaler") + assert hasattr(prep, "preconverter") + assert hasattr(prep, "stacker") + assert hasattr(prep, "postconverter") + assert hasattr(prep, "sanitizer") + + # Transformed data is centered + assert np.isclose(data_trans.mean("sample"), 0).all() + # Transformed data is standardized + if with_std and not with_coslat: + if with_weights: + assert np.isclose(data_trans.std("sample"), 0.5).all() + else: + assert np.isclose(data_trans.std("sample"), 1).all() + + +@pytest.mark.parametrize( + "index_policy, nan_policy, dask_policy", + [ + ("index", "no_nan", "no_dask"), + ("multiindex", "no_nan", "dask"), + ("index", "fulldim", "no_dask"), + ("multiindex", "fulldim", "dask"), + ], +) +def test_fit_transform_same_dim_names(index_policy, nan_policy, dask_policy): + data = generate_synthetic_dataarray(1, 1, index_policy, nan_policy, dask_policy) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name="sample0", feature_name="feature0") + transformed = prep.fit_transform(data, sample_dims) + reconstructed = prep.inverse_transform_data(transformed) + + 
data_is_dask_before = data_is_dask(data) + data_is_dask_interm = data_is_dask(transformed) + data_is_dask_after = data_is_dask(reconstructed) + + assert set(transformed.dims) == set(("sample0", "feature0")) + assert set(reconstructed.dims) == set(("sample0", "feature0")) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_interm + assert data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_fit_transform(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = prep.fit_transform(data, sample_dims) + + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(transformed) + + assert transformed.dims == (sample_name, feature_name) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_inverse_transform(sample_name, feature_name, data_params): + data = generate_synthetic_dataarray(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = prep.fit_transform(data, sample_dims) + components = transformed.rename({sample_name: "mode"}) + scores = transformed.rename({feature_name: "mode"}) + + reconstructed = prep.inverse_transform_data(transformed) + components = prep.inverse_transform_components(components) + scores = prep.inverse_transform_scores(scores) + + # Reconstructed data has the same dimensions as the original data + assert_expected_dims(data, reconstructed, policy="all") + assert_expected_dims(data, components, policy="feature") + assert_expected_dims(data, scores, policy="sample") + + # Reconstructed data has the same coordinates as the original data + assert_expected_coords(data, reconstructed, policy="all") + assert_expected_coords(data, components, policy="feature") + assert_expected_coords(data, scores, policy="sample") + + # Reconstructed data and original data have NaNs in the same FEATURES + # Note: NaNs in the same place is not guaranteed, since isolated NaNs will be propagated + # to all samples in the same feature + features_with_nans_before = data.isnull().any(sample_dims) + features_with_nans_after = reconstructed.isnull().any(sample_dims) + assert features_with_nans_before.equals(features_with_nans_after) + + # Reconstructed data has MultiIndex if and only if original data has MultiIndex + data_has_multiindex_before = data_has_multiindex(data) + data_has_multiindex_after = data_has_multiindex(reconstructed) + assert data_has_multiindex_before == data_has_multiindex_after + + # Reconstructed data is dask if and only if original data is dask + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(reconstructed) + assert data_is_dask_before == data_is_dask_after diff --git a/tests/preprocessing/test_preprocessor_datalist.py b/tests/preprocessing/test_preprocessor_datalist.py new file mode 100644 index 0000000..8602d08 --- /dev/null +++ b/tests/preprocessing/test_preprocessor_datalist.py @@ -0,0 +1,188 @@ +import pytest +import numpy as np +import xarray as xr + +from 
xeofs.preprocessing.preprocessor import Preprocessor +from ..conftest import generate_list_of_synthetic_dataarrays +from ..utilities import ( + get_dims_from_data_list, + data_is_dask, + data_has_multiindex, + assert_expected_dims, + assert_expected_coords, +) + +# ============================================================================= +# GENERALLY VALID TEST CASES +# ============================================================================= +N_ARRAYS = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +TEST_DATA_PARAMS = [ + (na, ns, nf, index, nan, dask) + for na in N_ARRAYS + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + +SAMPLE_DIM_NAMES = ["sample"] +FEATURE_DIM_NAMES = ["feature", "feature_alternative"] + +VALID_TEST_CASES = [ + (sample_name, feature_name, data_params) + for sample_name in SAMPLE_DIM_NAMES + for feature_name in FEATURE_DIM_NAMES + for data_params in TEST_DATA_PARAMS +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "with_std, with_coslat, with_weights", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + (False, False, False), + ], +) +def test_fit_transform_scalings( + with_std, with_coslat, with_weights, mock_data_array_list +): + prep = Preprocessor(with_std=with_std, with_coslat=with_coslat) + + n_data = len(mock_data_array_list) + sample_dims = ("time",) + weights = None + if with_weights: + weights = [da.mean("time").copy() for da in mock_data_array_list] + weights = [xr.ones_like(w) * 0.5 for w in weights] + + data_trans = prep.fit_transform(mock_data_array_list, sample_dims, weights) + + assert hasattr(prep, "scaler") + assert hasattr(prep, "preconverter") + assert hasattr(prep, "stacker") + assert hasattr(prep, "postconverter") + assert hasattr(prep, "sanitizer") + + # Transformed data is centered + assert np.isclose(data_trans.mean("sample"), 0).all() + # Transformed data is standardized + if with_std and not with_coslat: + if with_weights: + assert np.isclose(data_trans.std("sample"), 0.5).all() + else: + assert np.isclose(data_trans.std("sample"), 1).all() + + +@pytest.mark.parametrize( + "index_policy, nan_policy, dask_policy", + [ + ("index", "no_nan", "no_dask"), + ("multiindex", "no_nan", "dask"), + ("index", "fulldim", "no_dask"), + ("multiindex", "fulldim", "dask"), + ], +) +def test_fit_transform_same_dim_names(index_policy, nan_policy, dask_policy): + data = generate_list_of_synthetic_dataarrays( + 1, 1, 1, index_policy, nan_policy, dask_policy + ) + all_dims, sample_dims, feature_dims = get_dims_from_data_list(data) + + prep = Preprocessor(sample_name="sample0", feature_name="feature") + transformed = prep.fit_transform(data, sample_dims[0]) + reconstructed = prep.inverse_transform_data(transformed) + + data_is_dask_before = data_is_dask(data) + data_is_dask_interm = data_is_dask(transformed) + data_is_dask_after = data_is_dask(reconstructed) + + assert set(transformed.dims) == set(("sample0", "feature")) + assert all(set(rec.dims) == set(("sample0", "feature0")) for rec in reconstructed) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_interm + assert 
data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_fit_transform(sample_name, feature_name, data_params): + data = generate_list_of_synthetic_dataarrays(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data_list(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = prep.fit_transform(data, sample_dims[0]) + + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(transformed) + + assert transformed.dims == (sample_name, feature_name) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_inverse_transform(sample_name, feature_name, data_params): + data = generate_list_of_synthetic_dataarrays(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data_list(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = prep.fit_transform(data, sample_dims[0]) + components = transformed.rename({sample_name: "mode"}) + scores = transformed.rename({feature_name: "mode"}) + + reconstructed = prep.inverse_transform_data(transformed) + components = prep.inverse_transform_components(components) + scores = prep.inverse_transform_scores(scores) + + # Reconstructed data has the same dimensions as the original data + assert_expected_dims(data, reconstructed, policy="all") + assert_expected_dims(data, components, policy="feature") + assert_expected_dims(data[0], scores, policy="sample") + + # Reconstructed data has the same coordinates as the original data + assert_expected_coords(data, reconstructed, policy="all") + assert_expected_coords(data, components, policy="feature") + assert_expected_coords(data[0], scores, policy="sample") + + # Reconstructed data and original data have NaNs in the same FEATURES + # Note: NaNs in the same place is not guaranteed, since isolated NaNs will be propagated + # to all samples in the same feature + nan_features_before = [da.isnull().any(sample_dims[0]) for da in data] + nan_features_after = [rec.isnull().any(sample_dims[0]) for rec in reconstructed] + assert all( + before.equals(after) + for before, after in zip(nan_features_before, nan_features_after) + ) + + # Reconstructed data has MultiIndex if and only if original data has MultiIndex + data_has_multiindex_before = data_has_multiindex(data) + data_has_multiindex_after = data_has_multiindex(reconstructed) + assert data_has_multiindex_before == data_has_multiindex_after + + # Reconstructed data is dask if and only if original data is dask + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(reconstructed) + assert data_is_dask_before == data_is_dask_after diff --git a/tests/preprocessing/test_preprocessor_dataset.py b/tests/preprocessing/test_preprocessor_dataset.py new file mode 100644 index 0000000..9c047f3 --- /dev/null +++ b/tests/preprocessing/test_preprocessor_dataset.py @@ -0,0 +1,181 @@ +import pytest +import numpy as np +import xarray as xr + +from xeofs.preprocessing.preprocessor import Preprocessor +from ..conftest import generate_synthetic_dataset +from ..utilities import ( + get_dims_from_data, + data_is_dask, + data_has_multiindex, + assert_expected_dims, + assert_expected_coords, +) + +# ============================================================================= 
+# GENERALLY VALID TEST CASES +# ============================================================================= +N_VARIABLES = [1, 2] +N_SAMPLE_DIMS = [1, 2] +N_FEATURE_DIMS = [1, 2] +INDEX_POLICY = ["index", "multiindex"] +NAN_POLICY = ["no_nan", "fulldim"] +DASK_POLICY = ["no_dask", "dask"] +SEED = [0] + +TEST_DATA_PARAMS = [ + (nv, ns, nf, index, nan, dask) + for nv in N_VARIABLES + for ns in N_SAMPLE_DIMS + for nf in N_FEATURE_DIMS + for index in INDEX_POLICY + for nan in NAN_POLICY + for dask in DASK_POLICY +] + +SAMPLE_DIM_NAMES = ["sample"] +FEATURE_DIM_NAMES = ["feature", "feature_alternative"] + +VALID_TEST_CASES = [ + (sample_name, feature_name, data_params) + for sample_name in SAMPLE_DIM_NAMES + for feature_name in FEATURE_DIM_NAMES + for data_params in TEST_DATA_PARAMS +] + + +# TESTS +# ============================================================================= +@pytest.mark.parametrize( + "with_std, with_coslat, with_weights", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + (False, False, False), + ], +) +def test_fit_transform_scalings(with_std, with_coslat, with_weights, mock_dataset): + """fit method should not be implemented.""" + prep = Preprocessor(with_std=with_std, with_coslat=with_coslat) + + weights = None + if with_weights: + weights = mock_dataset.mean("time").copy() + weights = weights.where(weights == True, 0.5) + + data_trans = prep.fit_transform(mock_dataset, "time", weights) + + assert hasattr(prep, "scaler") + assert hasattr(prep, "renamer") + assert hasattr(prep, "preconverter") + assert hasattr(prep, "stacker") + assert hasattr(prep, "postconverter") + assert hasattr(prep, "sanitizer") + + # Transformed data is centered + assert np.isclose(data_trans.mean("sample"), 0).all() + # Transformed data is standardized + if with_std and not with_coslat: + if with_weights: + assert np.isclose(data_trans.std("sample"), 0.5).all() + else: + assert np.isclose(data_trans.std("sample"), 1).all() + + +@pytest.mark.parametrize( + "index_policy, nan_policy, dask_policy", + [ + ("index", "no_nan", "no_dask"), + ("multiindex", "no_nan", "dask"), + ("index", "fulldim", "no_dask"), + ("multiindex", "fulldim", "dask"), + ], +) +def test_fit_transform_same_dim_names(index_policy, nan_policy, dask_policy): + data = generate_synthetic_dataset(1, 1, 1, index_policy, nan_policy, dask_policy) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name="sample0", feature_name="feature") + transformed = prep.fit_transform(data, sample_dims) + reconstructed = prep.inverse_transform_data(transformed) + + data_is_dask_before = data_is_dask(data) + data_is_dask_interm = data_is_dask(transformed) + data_is_dask_after = data_is_dask(reconstructed) + + assert set(transformed.dims) == set(("sample0", "feature")) + assert set(reconstructed.dims) == set(("sample0", "feature0")) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_interm + assert data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_fit_transform(sample_name, feature_name, data_params): + data = generate_synthetic_dataset(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = 
prep.fit_transform(data, sample_dims) + + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(transformed) + + assert transformed.dims == (sample_name, feature_name) + assert not data_has_multiindex(transformed) + assert transformed.notnull().all() + assert data_is_dask_before == data_is_dask_after + + +@pytest.mark.parametrize( + "sample_name, feature_name, data_params", + VALID_TEST_CASES, +) +def test_inverse_transform(sample_name, feature_name, data_params): + data = generate_synthetic_dataset(*data_params) + all_dims, sample_dims, feature_dims = get_dims_from_data(data) + + prep = Preprocessor(sample_name=sample_name, feature_name=feature_name) + transformed = prep.fit_transform(data, sample_dims) + components = transformed.rename({sample_name: "mode"}) + scores = transformed.rename({feature_name: "mode"}) + + reconstructed = prep.inverse_transform_data(transformed) + components = prep.inverse_transform_components(components) + scores = prep.inverse_transform_scores(scores) + + # Reconstructed data has the same dimensions as the original data + assert_expected_dims(data, reconstructed, policy="all") + assert_expected_dims(data, components, policy="feature") + assert_expected_dims(data, scores, policy="sample") + + # Reconstructed data has the same coordinates as the original data + assert_expected_coords(data, reconstructed, policy="all") + assert_expected_coords(data, components, policy="feature") + assert_expected_coords(data, scores, policy="sample") + + # Reconstructed data and original data have NaNs in the same FEATURES + # Note: NaNs in the same place is not guaranteed, since isolated NaNs will be propagated + # to all samples in the same feature + features_with_nans_before = data.isnull().any(sample_dims) + features_with_nans_after = reconstructed.isnull().any(sample_dims) + assert features_with_nans_before.equals(features_with_nans_after) + + # Reconstructed data has MultiIndex if and only if original data has MultiIndex + data_has_multiindex_before = data_has_multiindex(data) + data_has_multiindex_after = data_has_multiindex(reconstructed) + assert data_has_multiindex_before == data_has_multiindex_after + + # Reconstructed data is dask if and only if original data is dask + data_is_dask_before = data_is_dask(data) + data_is_dask_after = data_is_dask(reconstructed) + assert data_is_dask_before == data_is_dask_after diff --git a/tests/preprocessing/test_single_dataarray_scaler.py b/tests/preprocessing/test_single_dataarray_scaler.py deleted file mode 100644 index 042c106..0000000 --- a/tests/preprocessing/test_single_dataarray_scaler.py +++ /dev/null @@ -1,245 +0,0 @@ -import pytest -import xarray as xr -import numpy as np - -from xeofs.preprocessing.scaler import SingleDataArrayScaler - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_init_params(with_std, with_coslat, with_weights): - s = SingleDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - assert hasattr(s, "_params") - assert s._params["with_std"] == with_std - assert s._params["with_coslat"] == with_coslat - assert 
s._params["with_weights"] == with_weights - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_fit_params(with_std, with_coslat, with_weights, mock_data_array): - s = SingleDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - sample_dims = ["time"] - feature_dims = ["lat", "lon"] - size_lats = mock_data_array.lat.size - weights = xr.DataArray(np.random.rand(size_lats), dims=["lat"]) - s.fit(mock_data_array, sample_dims, feature_dims, weights) - assert hasattr(s, "mean"), "Scaler has no mean attribute." - if with_std: - assert hasattr(s, "std"), "Scaler has no std attribute." - if with_coslat: - assert hasattr(s, "coslat_weights"), "Scaler has no coslat_weights attribute." - if with_weights: - assert hasattr(s, "weights"), "Scaler has no weights attribute." - assert s.mean is not None, "Scaler mean is None." - if with_std: - assert s.std is not None, "Scaler std is None." - if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - if with_weights: - assert s.weights is not None, "Scaler weights is None." - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_transform_params(with_std, with_coslat, with_weights, mock_data_array): - s = SingleDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - sample_dims = ["time"] - feature_dims = ["lat", "lon"] - size_lats = mock_data_array.lat.size - weights = xr.DataArray( - np.random.rand(size_lats), dims=["lat"], coords={"lat": mock_data_array.lat} - ) - s.fit(mock_data_array, sample_dims, feature_dims, weights) - transformed = s.transform(mock_data_array) - assert transformed is not None, "Transformed data is None." - - transformed_mean = transformed.mean(sample_dims, skipna=False) - assert np.allclose(transformed_mean, 0), "Mean of the transformed data is not zero." - - if with_std: - transformed_std = transformed.std(sample_dims, skipna=False) - if with_coslat or with_weights: - assert ( - transformed_std <= 1 - ).all(), "Standard deviation of the transformed data is larger one." - else: - assert np.allclose( - transformed_std, 1 - ), "Standard deviation of the transformed data is not one." - - if with_coslat: - assert s.coslat_weights is not None, "Scaler coslat_weights is None." - assert not np.array_equal( - transformed, mock_data_array - ), "Data has not been transformed." - - if with_weights: - assert s.weights is not None, "Scaler weights is None." - assert not np.array_equal( - transformed, mock_data_array - ), "Data has not been transformed." 
- - transformed2 = s.fit_transform(mock_data_array, sample_dims, feature_dims, weights) - xr.testing.assert_allclose(transformed, transformed2) - - -@pytest.mark.parametrize( - "with_std, with_coslat, with_weights", - [ - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - (True, True, True), - (True, True, False), - (True, False, True), - (True, False, False), - (False, True, True), - (False, True, False), - (False, False, True), - (False, False, False), - ], -) -def test_inverse_transform_params(with_std, with_coslat, with_weights, mock_data_array): - s = SingleDataArrayScaler( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - sample_dims = ["time"] - feature_dims = ["lat", "lon"] - size_lats = mock_data_array.lat.size - weights = xr.DataArray( - np.random.rand(size_lats), dims=["lat"], coords={"lat": mock_data_array.lat} - ) - s.fit(mock_data_array, sample_dims, feature_dims, weights) - transformed = s.transform(mock_data_array) - inverted = s.inverse_transform(transformed) - xr.testing.assert_allclose(inverted, mock_data_array) - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_dims(dim_sample, dim_feature, mock_data_array): - s = SingleDataArrayScaler() - s.fit(mock_data_array, dim_sample, dim_feature) - assert hasattr(s, "mean"), "Scaler has no mean attribute." - assert s.mean is not None, "Scaler mean is None." - assert hasattr(s, "std"), "Scaler has no std attribute." - assert s.std is not None, "Scaler std is None." - # check that all dimensions are present except the sample dimensions - assert set(s.mean.dims) == set(mock_data_array.dims) - set( - dim_sample - ), "Mean has wrong dimensions." - assert set(s.std.dims) == set(mock_data_array.dims) - set( - dim_sample - ), "Standard deviation has wrong dimensions." - - -@pytest.mark.parametrize( - "dim_sample, dim_feature", - [ - (("time",), ("lat", "lon")), - (("time",), ("lon", "lat")), - (("lat", "lon"), ("time",)), - (("lon", "lat"), ("time",)), - ], -) -def test_fit_transform_dims(dim_sample, dim_feature, mock_data_array): - s = SingleDataArrayScaler() - transformed = s.fit_transform(mock_data_array, dim_sample, dim_feature) - # check that all dimensions are present - assert set(transformed.dims) == set( - mock_data_array.dims - ), "Transformed data has wrong dimensions." 
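
The scaler-level test modules deleted here are superseded by the Preprocessor-level tests added earlier in this diff (test_preprocessor_dataarray.py, test_preprocessor_datalist.py, test_preprocessor_dataset.py), which check the whole pipeline — scaling, stacking, NaN handling, MultiIndex and dask round trips — through fit_transform and the inverse_transform_* methods. A minimal sketch of that round trip on a toy DataArray, restricted to the Preprocessor calls used in those new tests; the input field is invented for illustration:

import numpy as np
import xarray as xr

from xeofs.preprocessing.preprocessor import Preprocessor

rng = np.random.default_rng(0)
data = xr.DataArray(
    rng.standard_normal((6, 4, 3)),
    dims=("time", "lat", "lon"),
    coords={"time": range(6), "lat": [10.0, 20.0, 30.0, 40.0], "lon": [0.0, 10.0, 20.0]},
)

prep = Preprocessor(sample_name="sample", feature_name="feature")
transformed = prep.fit_transform(data, sample_dims=("time",))

# 2-D (sample x feature) matrix, centred along the sample dimension
assert transformed.dims == ("sample", "feature")
assert np.isclose(transformed.mean("sample"), 0).all()

# Back to the original dimensions and coordinates
reconstructed = prep.inverse_transform_data(transformed)
assert set(reconstructed.dims) == set(data.dims)

# Model output is mapped back the same way, after renaming the stacked
# dimension to "mode" (mirroring what the new tests do)
components = prep.inverse_transform_components(transformed.rename({"sample": "mode"}))
scores = prep.inverse_transform_scores(transformed.rename({"feature": "mode"}))
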
- # check that the coordinates are the same - for dim in mock_data_array.dims: - xr.testing.assert_allclose(transformed[dim], mock_data_array[dim]) - - -# Test input types -def test_fit_input_type(mock_data_array, mock_dataset, mock_data_array_list): - s = SingleDataArrayScaler() - with pytest.raises(TypeError): - s.fit(mock_dataset, ["time"], ["lon", "lat"]) - with pytest.raises(TypeError): - s.fit(mock_data_array_list, ["time"], ["lon", "lat"]) - - s.fit(mock_data_array, ["time"], ["lon", "lat"]) - with pytest.raises(TypeError): - s.transform(mock_dataset) - with pytest.raises(TypeError): - s.transform(mock_data_array_list) diff --git a/tests/utilities.py b/tests/utilities.py new file mode 100644 index 0000000..1609f35 --- /dev/null +++ b/tests/utilities.py @@ -0,0 +1,172 @@ +from typing import Tuple, List, Hashable +import numpy as np +import pandas as pd +import xarray as xr +import dask.array as da +from xeofs.utils.data_types import ( + DataArray, + DataSet, + DataList, + DaskArray, + Dims, + DimsList, + DimsTuple, + DimsListTuple, +) + + +def is_xdata(data): + return isinstance(data, (DataArray, DataSet)) + + +def get_dims_from_data(data: DataArray | DataSet) -> DimsTuple: + # If data is DataArray/Dataset + if is_xdata(data): + data_dims: Dims = tuple(data.dims) + sample_dims: Dims = tuple([dim for dim in data.dims if "sample" in str(dim)]) + feature_dims: Dims = tuple([dim for dim in data.dims if "feature" in str(dim)]) + return data_dims, sample_dims, feature_dims + else: + raise ValueError("unrecognized input type") + + +def get_dims_from_data_list(data_list: DataList) -> DimsListTuple: + # If data is list + if isinstance(data_list, list): + data_dims: DimsList = [data.dims for data in data_list] + sample_dims: DimsList = [] + feature_dims: DimsList = [] + for data in data_list: + sdims = tuple([dim for dim in data.dims if "sample" in str(dim)]) + fdims = tuple([dim for dim in data.dims if "feature" in str(dim)]) + sample_dims.append(sdims) + feature_dims.append(fdims) + return data_dims, sample_dims, feature_dims + + else: + raise ValueError("unrecognized input type") + + +def data_has_multiindex(data: DataArray | DataSet | DataList) -> bool: + """Check if the given data object has any MultiIndex.""" + if isinstance(data, DataArray) or isinstance(data, DataSet): + return any(isinstance(index, pd.MultiIndex) for index in data.indexes.values()) + elif isinstance(data, list): + return all(data_has_multiindex(da) for da in data) + else: + raise ValueError("unrecognized input type") + + +def data_is_dask(data: DataArray | DataSet | DataList) -> bool: + """Check if the given data is backed by a dask array.""" + + # If data is a DataArray, check its underlying data type + if isinstance(data, DataArray): + return isinstance(data.data, DaskArray) + + # If data is a DataSet, recursively check all contained DataArrays + if isinstance(data, DataSet): + return all(data_is_dask(da) for da in data.data_vars.values()) + + # If data is a list, recursively check each element in the list + if isinstance(data, list): + return all(data_is_dask(da) for da in data) + + # If none of the above, the data type is unrecognized + raise ValueError("unrecognized data type.") + + +def assert_expected_dims(data1, data2, policy="all"): + """ + Check if dimensions of two data objects matches. + + Parameters: + - data1: Reference data object (either a DataArray, DataSet, or list of DataArray) + - data2: Test data object (same type as data1) + - policy: Policy to check the dimensions. 
Can be either "all", "feature" or "sample" + + """ + + if is_xdata(data1) and is_xdata(data2): + all_dims1, sample_dims1, feature_dims1 = get_dims_from_data(data1) + all_dims2, sample_dims2, feature_dims2 = get_dims_from_data(data2) + + if policy == "all": + err_msg = "Dimensions do not match: {:} vs {:}".format(all_dims1, all_dims2) + assert set(all_dims1) == set(all_dims2), err_msg + elif policy == "feature": + err_msg = "Dimensions do not match: {:} vs {:}".format( + feature_dims1, feature_dims2 + ) + assert set(feature_dims1) == set(feature_dims2), err_msg + assert len(sample_dims2) == 0, "Sample dimensions should be empty" + assert "mode" in all_dims2, "Mode dimension is missing" + + elif policy == "sample": + err_msg = "Dimensions do not match: {:} vs {:}".format( + sample_dims1, sample_dims2 + ) + assert set(sample_dims1) == set(sample_dims2), err_msg + assert len(feature_dims2) == 0, "Feature dimensions should be empty" + assert "mode" in all_dims2, "Mode dimension is missing" + else: + raise ValueError("Unrecognized policy: {:}".format(policy)) + + elif isinstance(data1, list) and isinstance(data2, list): + for da1, da2 in zip(data1, data2): + assert_expected_dims(da1, da2, policy=policy) + + # If neither of the above conditions are met, raise an error + else: + raise ValueError( + "Cannot check coordinates. Unrecognized data type. data1: {:}, data2: {:}".format( + type(data1), type(data2) + ) + ) + + +def assert_expected_coords(data1, data2, policy="all") -> None: + """ + Check if coordinates of the data objects matches. + + Parameters: + - data1: Reference data object (either a DataArray, DataSet, or list of DataArray) + - data2: Test data object (same type as data1) + - policy: Policy to check the dimensions. Can be either "all", "feature" or "sample" + + """ + + # Data objects is either DataArray or DataSet + if is_xdata(data1) and is_xdata(data2): + all_dims1, sample_dims1, feature_dims1 = get_dims_from_data(data1) + all_dims2, sample_dims2, feature_dims2 = get_dims_from_data(data2) + if policy == "all": + assert all( + np.all(data1.coords[dim].values == data2.coords[dim].values) + for dim in all_dims1 + ) + elif policy == "feature": + assert all( + np.all(data1.coords[dim].values == data2.coords[dim].values) + for dim in feature_dims1 + ) + elif policy == "sample": + assert all( + np.all(data1.coords[dim].values == data2.coords[dim].values) + for dim in sample_dims1 + ) + else: + raise ValueError("Unrecognized policy: {:}".format(policy)) + + # Data object is list + elif isinstance(data1, list) and isinstance(data2, list): + for da1, da2 in zip(data1, data2): + assert_expected_coords(da1, da2, policy=policy) + + # If neither of the above conditions are met, raise an error + else: + raise ValueError( + "Cannot check coordinates. Unrecognized data type. 
data1: {:}, data2: {:}".format( + type(data1), type(data2) + ) + ) diff --git a/tests/validation/test_eof_bootstrapper.py b/tests/validation/test_eof_bootstrapper.py index e1a43f4..e78ef12 100644 --- a/tests/validation/test_eof_bootstrapper.py +++ b/tests/validation/test_eof_bootstrapper.py @@ -56,51 +56,51 @@ def test_fit(eof_model): # DataArrays are created assert isinstance( - bootstrapper.data.explained_variance, xr.DataArray + bootstrapper.data["explained_variance"], xr.DataArray ), "explained variance is not a DataArray" assert isinstance( - bootstrapper.data.components, xr.DataArray + bootstrapper.data["components"], xr.DataArray ), "components is not a DataArray" assert isinstance( - bootstrapper.data.scores, xr.DataArray + bootstrapper.data["scores"], xr.DataArray ), "scores is not a DataArray" # DataArrays have expected dims - expected_dims = set(eof_model.data.explained_variance.dims) + expected_dims = set(eof_model.data["explained_variance"].dims) expected_dims.add("n") - true_dims = set(bootstrapper.data.explained_variance.dims) + true_dims = set(bootstrapper.data["explained_variance"].dims) err_message = ( f"explained variance dimensions are {true_dims} instead of {expected_dims}" ) assert true_dims == expected_dims, err_message - expected_dims = set(eof_model.data.components.dims) + expected_dims = set(eof_model.data["components"].dims) expected_dims.add("n") - true_dims = set(bootstrapper.data.components.dims) + true_dims = set(bootstrapper.data["components"].dims) err_message = f"components dimensions are {true_dims} instead of {expected_dims}" assert true_dims == expected_dims, err_message - expected_dims = set(eof_model.data.scores.dims) + expected_dims = set(eof_model.data["scores"].dims) expected_dims.add("n") - true_dims = set(bootstrapper.data.scores.dims) + true_dims = set(bootstrapper.data["scores"].dims) err_message = f"scores dimensions are {true_dims} instead of {expected_dims}" assert true_dims == expected_dims, err_message # DataArrays have expected coords - ref_da = eof_model.data.explained_variance - test_da = bootstrapper.data.explained_variance + ref_da = eof_model.data["explained_variance"] + test_da = bootstrapper.data["explained_variance"] for dim, coords in ref_da.coords.items(): assert test_da[dim].equals( coords ), f"explained variance coords for {dim} are not equal" - ref_da = eof_model.data.components - test_da = bootstrapper.data.components + ref_da = eof_model.data["components"] + test_da = bootstrapper.data["components"] for dim, coords in ref_da.coords.items(): assert test_da[dim].equals(coords), f"components coords for {dim} are not equal" - ref_da = eof_model.data.scores - test_da = bootstrapper.data.scores + ref_da = eof_model.data["scores"] + test_da = bootstrapper.data["scores"] for dim, coords in ref_da.coords.items(): assert test_da[dim].equals(coords), f"scores coords for {dim} are not equal" diff --git a/xeofs/data_container/__init__.py b/xeofs/data_container/__init__.py index 4065e8c..3a8fcd5 100644 --- a/xeofs/data_container/__init__.py +++ b/xeofs/data_container/__init__.py @@ -1,12 +1 @@ -from ._base_model_data_container import _BaseModelDataContainer -from ._base_cross_model_data_container import _BaseCrossModelDataContainer -from .eof_data_container import EOFDataContainer, ComplexEOFDataContainer -from .mca_data_container import MCADataContainer, ComplexMCADataContainer -from .eof_rotator_data_container import ( - EOFRotatorDataContainer, - ComplexEOFRotatorDataContainer, -) -from .mca_rotator_data_container import ( - 
MCARotatorDataContainer, - ComplexMCARotatorDataContainer, -) +from .data_container import DataContainer diff --git a/xeofs/data_container/_base_cross_model_data_container.py b/xeofs/data_container/_base_cross_model_data_container.py deleted file mode 100644 index a53a9fa..0000000 --- a/xeofs/data_container/_base_cross_model_data_container.py +++ /dev/null @@ -1,128 +0,0 @@ -from abc import ABC -from typing import Optional - -from dask.diagnostics.progress import ProgressBar - -from ..utils.data_types import DataArray - - -class _BaseCrossModelDataContainer(ABC): - """Abstract base class that holds the cross model data.""" - - def __init__(self): - self._input_data1: Optional[DataArray] = None - self._input_data2: Optional[DataArray] = None - self._components1: Optional[DataArray] = None - self._components2: Optional[DataArray] = None - self._scores1: Optional[DataArray] = None - self._scores2: Optional[DataArray] = None - - @staticmethod - def _verify_dims(da: DataArray, dims_expected: tuple): - """Verify that the dimensions of the data are correct.""" - if not set(da.dims) == set(dims_expected): - raise ValueError(f"The data must have dimensions {dims_expected}.") - - @staticmethod - def _sanity_check(data) -> DataArray: - """Check whether the Data of the DataContainer has been set.""" - if data is None: - raise ValueError("There is no data. Have you called .fit()?") - else: - return data - - def set_data( - self, - input_data1: DataArray, - input_data2: DataArray, - components1: DataArray, - components2: DataArray, - scores1: DataArray, - scores2: DataArray, - ): - self._verify_dims(input_data1, ("sample", "feature")) - self._verify_dims(input_data2, ("sample", "feature")) - self._verify_dims(components1, ("feature", "mode")) - self._verify_dims(components2, ("feature", "mode")) - self._verify_dims(scores1, ("sample", "mode")) - self._verify_dims(scores2, ("sample", "mode")) - - components1.name = "left_components" - components2.name = "right_components" - scores1.name = "left_scores" - scores2.name = "right_scores" - - self._input_data1 = input_data1 - self._input_data2 = input_data2 - self._components1 = components1 - self._components2 = components2 - self._scores1 = scores1 - self._scores2 = scores2 - - @property - def input_data1(self) -> DataArray: - """Get the left input data.""" - data1 = self._sanity_check(self._input_data1) - return data1 - - @property - def input_data2(self) -> DataArray: - """Get the right input data.""" - data2 = self._sanity_check(self._input_data2) - return data2 - - @property - def components1(self) -> DataArray: - """Get the left components.""" - components1 = self._sanity_check(self._components1) - return components1 - - @property - def components2(self) -> DataArray: - """Get the right components.""" - components2 = self._sanity_check(self._components2) - return components2 - - @property - def scores1(self) -> DataArray: - """Get the left scores.""" - scores1 = self._sanity_check(self._scores1) - return scores1 - - @property - def scores2(self) -> DataArray: - """Get the right scores.""" - scores2 = self._sanity_check(self._scores2) - return scores2 - - def compute(self, verbose=False): - """Compute and load delayed dask DataArrays into memory. - - Parameters - ---------- - verbose : bool - Whether or not to provide additional information about the computing progress. 
- """ - if verbose: - with ProgressBar(): - self._components1 = self.components1.compute() - self._components2 = self.components2.compute() - self._scores1 = self.scores1.compute() - self._scores2 = self.scores2.compute() - else: - self._components1 = self.components1.compute() - self._components2 = self.components2.compute() - self._scores1 = self.scores1.compute() - self._scores2 = self.scores2.compute() - - def set_attrs(self, attrs: dict): - """Set the attributes of the results.""" - components1 = self._sanity_check(self._components1) - components2 = self._sanity_check(self._components2) - scores1 = self._sanity_check(self._scores1) - scores2 = self._sanity_check(self._scores2) - - components1.attrs.update(attrs) - components2.attrs.update(attrs) - scores1.attrs.update(attrs) - scores2.attrs.update(attrs) diff --git a/xeofs/data_container/_base_model_data_container.py b/xeofs/data_container/_base_model_data_container.py deleted file mode 100644 index 2dc00f3..0000000 --- a/xeofs/data_container/_base_model_data_container.py +++ /dev/null @@ -1,83 +0,0 @@ -from abc import ABC -from typing import Optional - -from dask.diagnostics.progress import ProgressBar - -from ..utils.data_types import DataArray - - -class _BaseModelDataContainer(ABC): - """Abstract base class that holds the model data.""" - - def __init__(self): - self._input_data: Optional[DataArray] = None - self._components: Optional[DataArray] = None - self._scores: Optional[DataArray] = None - - @staticmethod - def _verify_dims(da: DataArray, dims_expected: tuple): - """Verify that the dimensions of the data are correct.""" - if not set(da.dims) == set(dims_expected): - raise ValueError(f"The data must have dimensions {dims_expected}.") - - @staticmethod - def _sanity_check(data) -> DataArray: - """Check whether the Data of the DataContainer has been set.""" - if data is None: - raise ValueError("There is no data. Have you called .fit()?") - else: - return data - - def set_data(self, input_data: DataArray, components: DataArray, scores: DataArray): - self._verify_dims(input_data, ("sample", "feature")) - self._verify_dims(components, ("feature", "mode")) - self._verify_dims(scores, ("sample", "mode")) - - components.name = "components" - scores.name = "scores" - - self._input_data = input_data - self._components = components - self._scores = scores - - @property - def input_data(self) -> DataArray: - """Get the input data.""" - data = self._sanity_check(self._input_data) - return data - - @property - def components(self) -> DataArray: - """Get the components.""" - components = self._sanity_check(self._components) - return components - - @property - def scores(self) -> DataArray: - """Get the scores.""" - scores = self._sanity_check(self._scores) - return scores - - def compute(self, verbose=False): - """Compute and load delayed dask DataArrays into memory. - - Parameters - ---------- - verbose : bool - Whether or not to provide additional information about the computing progress. 
- """ - if verbose: - with ProgressBar(): - self._components = self.components.compute() - self._scores = self.scores.compute() - else: - self._components = self.components.compute() - self._scores = self.scores.compute() - - def set_attrs(self, attrs: dict): - """Set the attributes of the results.""" - components = self._sanity_check(self._components) - scores = self._sanity_check(self._scores) - - components.attrs.update(attrs) - scores.attrs.update(attrs) diff --git a/xeofs/data_container/data_container.py b/xeofs/data_container/data_container.py new file mode 100644 index 0000000..0f1e25f --- /dev/null +++ b/xeofs/data_container/data_container.py @@ -0,0 +1,51 @@ +from typing import Dict +from dask.diagnostics.progress import ProgressBar + +from ..utils.data_types import DataArray + + +class DataContainer(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._allow_compute = dict({k: True for k in self.keys()}) + + def add(self, data: DataArray, name: str, allow_compute: bool = True) -> None: + data.name = name + super().__setitem__(name, data) + self._allow_compute[name] = True if allow_compute else False + + def __setitem__(self, __key: str, __value: DataArray) -> None: + super().__setitem__(__key, __value) + self._allow_compute[__key] = self._allow_compute.get(__key, True) + + def __getitem__(self, __key: str) -> DataArray: + try: + return super().__getitem__(__key) + except KeyError: + raise KeyError( + f"Cannot find data '{__key}'. Please fit the model first by calling .fit()." + ) + + def compute(self, verbose=False): + for k, v in self.items(): + if self._allow_compute[k]: + if verbose: + with ProgressBar(): + self[k] = v.compute() + else: + self[k] = v.compute() + + def _validate_attrs(self, attrs: Dict) -> Dict: + """Convert any boolean and None values to strings""" + for key, value in attrs.items(): + if isinstance(value, bool): + attrs[key] = str(value) + elif value is None: + attrs[key] = "None" + + return attrs + + def set_attrs(self, attrs: Dict): + attrs = self._validate_attrs(attrs) + for key in self.keys(): + self[key].attrs = attrs diff --git a/xeofs/data_container/eof_bootstrapper_data_container.py b/xeofs/data_container/eof_bootstrapper_data_container.py deleted file mode 100644 index 0abe5c7..0000000 --- a/xeofs/data_container/eof_bootstrapper_data_container.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np - -from ..data_container.eof_data_container import EOFDataContainer -from ..utils.data_types import DataArray - - -class EOFBootstrapperDataContainer(EOFDataContainer): - """Container that holds the data related to a Bootstrapper EOF model.""" - - @staticmethod - def _verify_dims(da: DataArray, dims: tuple): - """Verify that the dimensions of the data are correct.""" - # Bootstrapper EOFs have an additional dimension for the bootstrap - expected_dims = dims - given_dims = da.dims - - # In the case of the input data, the dimensions are ('sample', 'feature') - # Otherwise, the data should have an additional dimension for the bootstrap `n` - has_input_data_dims = set(given_dims) == set(("sample", "feature")) - if not has_input_data_dims: - expected_dims = ("n",) + dims - - dims_are_equal = set(given_dims) == set(expected_dims) - if not dims_are_equal: - raise ValueError(f"The data must have dimensions {expected_dims}.") diff --git a/xeofs/data_container/eof_data_container.py b/xeofs/data_container/eof_data_container.py deleted file mode 100644 index b22a7da..0000000 --- a/xeofs/data_container/eof_data_container.py +++ 
/dev/null @@ -1,135 +0,0 @@ -from typing import Optional - -import numpy as np -import xarray as xr -from dask.diagnostics.progress import ProgressBar - -from ._base_model_data_container import _BaseModelDataContainer -from ..utils.data_types import DataArray - - -class EOFDataContainer(_BaseModelDataContainer): - """Container to store the results of an EOF analysis.""" - - def __init__(self): - super().__init__() - self._explained_variance: Optional[DataArray] = None - self._total_variance: Optional[DataArray] = None - self._idx_modes_sorted: Optional[DataArray] = None - - def set_data( - self, - input_data: DataArray, - components: DataArray, - scores: DataArray, - explained_variance: DataArray, - total_variance: DataArray, - idx_modes_sorted: DataArray, - ): - super().set_data(input_data=input_data, components=components, scores=scores) - - self._verify_dims(explained_variance, ("mode",)) - self._explained_variance = explained_variance - self._explained_variance.name = "explained_variance" - - self._total_variance = total_variance - self._total_variance.name = "total_variance" - - self._verify_dims(idx_modes_sorted, ("mode",)) - self._idx_modes_sorted = idx_modes_sorted - self._idx_modes_sorted.name = "idx_modes_sorted" - - @property - def total_variance(self) -> DataArray: - """Get the total variance.""" - total_var = super()._sanity_check(self._total_variance) - return total_var - - @property - def explained_variance(self) -> DataArray: - """Get the explained variance.""" - exp_var = super()._sanity_check(self._explained_variance) - return exp_var - - @property - def explained_variance_ratio(self) -> DataArray: - """Get the explained variance ratio.""" - expvar_ratio = self.explained_variance / self.total_variance - expvar_ratio.name = "explained_variance_ratio" - expvar_ratio.attrs.update(self.explained_variance.attrs) - return expvar_ratio - - @property - def idx_modes_sorted(self) -> DataArray: - """Get the index of the sorted explained variance.""" - idx_modes_sorted = super()._sanity_check(self._idx_modes_sorted) - return idx_modes_sorted - - @property - def singular_values(self) -> DataArray: - """Get the explained variance.""" - svals = np.sqrt((self.input_data.sample.size - 1) * self.explained_variance) - svals.attrs.update(self.explained_variance.attrs) - svals.name = "singular_values" - return svals - - def compute(self, verbose=False): - super().compute(verbose) - - if verbose: - with ProgressBar(): - self._explained_variance = self.explained_variance.compute() - self._total_variance = self.total_variance.compute() - self._idx_modes_sorted = self.idx_modes_sorted.compute() - else: - self._explained_variance = self.explained_variance.compute() - self._total_variance = self.total_variance.compute() - self._idx_modes_sorted = self.idx_modes_sorted.compute() - - def set_attrs(self, attrs: dict): - """Set the attributes of the results.""" - super().set_attrs(attrs) - - explained_variance = self._sanity_check(self._explained_variance) - total_variance = self._sanity_check(self._total_variance) - idx_modes_sorted = self._sanity_check(self._idx_modes_sorted) - - explained_variance.attrs.update(attrs) - total_variance.attrs.update(attrs) - idx_modes_sorted.attrs.update(attrs) - - -class ComplexEOFDataContainer(EOFDataContainer): - """Container to store the results of a complex EOF analysis.""" - - @property - def components_amplitude(self) -> DataArray: - """Get the components amplitude.""" - comp_abs = abs(self.components) - comp_abs.name = "components_amplitude" - return comp_abs 
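
The per-model result containers removed in this part of the diff (the _BaseModelDataContainer/_BaseCrossModelDataContainer hierarchy and the EOF/MCA-specific subclasses) are replaced by the single dict-based DataContainer added above; the EOF bootstrapper tests switch from attribute access (bootstrapper.data.scores) to item access (bootstrapper.data["scores"]) accordingly. A short usage sketch of the new container, limited to the methods shown in data_container.py; the DataArrays are invented placeholders:

import numpy as np
import xarray as xr

from xeofs.data_container import DataContainer

components = xr.DataArray(np.zeros((3, 2)), dims=("feature", "mode"))
scores = xr.DataArray(np.zeros((5, 2)), dims=("sample", "mode"))

data = DataContainer()
data.add(components, name="components")                # eligible for .compute()
data.add(scores, name="scores", allow_compute=False)   # kept lazy by .compute()

# Plain dict item access; a missing key raises a KeyError hinting to call .fit() first
comps = data["components"]

# Loads dask-backed entries into memory (wrapped in a ProgressBar when verbose=True),
# skipping entries registered with allow_compute=False
data.compute(verbose=False)

# Attributes are broadcast to every entry; bools and None are stringified first
data.set_attrs({"model": "EOF analysis", "standardize": True, "solver": None})
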
- - @property - def components_phase(self) -> DataArray: - """Get the components phase.""" - comp_phase = xr.apply_ufunc( - np.angle, self.components, dask="allowed", keep_attrs=True - ) - comp_phase.name = "components_phase" - return comp_phase - - @property - def scores_amplitude(self) -> DataArray: - """Get the scores amplitude.""" - score_abs = abs(self.scores) - score_abs.name = "scores_amplitude" - return score_abs - - @property - def scores_phase(self) -> DataArray: - """Get the scores phase.""" - score_phase = xr.apply_ufunc( - np.angle, self.scores, dask="allowed", keep_attrs=True - ) - score_phase.name = "scores_phase" - return score_phase diff --git a/xeofs/data_container/eof_rotator_data_container.py b/xeofs/data_container/eof_rotator_data_container.py deleted file mode 100644 index 2b18d4d..0000000 --- a/xeofs/data_container/eof_rotator_data_container.py +++ /dev/null @@ -1,120 +0,0 @@ -from abc import abstractmethod -from typing import TypeVar, Optional - -import numpy as np -import xarray as xr -from dask.diagnostics.progress import ProgressBar - -from .eof_data_container import EOFDataContainer, ComplexEOFDataContainer -from ..utils.data_types import DataArray - - -class EOFRotatorDataContainer(EOFDataContainer): - """Container to store the results of a rotated EOF analysis.""" - - def __init__(self): - super().__init__() - self._rotation_matrix: Optional[DataArray] = None - self._phi_matrix: Optional[DataArray] = None - self._modes_sign: Optional[DataArray] = None - - def set_data( - self, - input_data: DataArray, - components: DataArray, - scores: DataArray, - explained_variance: DataArray, - total_variance: DataArray, - idx_modes_sorted: DataArray, - modes_sign: DataArray, - rotation_matrix: DataArray, - phi_matrix: DataArray, - ): - super().set_data( - input_data=input_data, - components=components, - scores=scores, - explained_variance=explained_variance, - total_variance=total_variance, - idx_modes_sorted=idx_modes_sorted, - ) - - self._verify_dims(rotation_matrix, ("mode_m", "mode_n")) - self._rotation_matrix = rotation_matrix - self._rotation_matrix.name = "rotation_matrix" - - self._verify_dims(phi_matrix, ("mode_m", "mode_n")) - self._phi_matrix = phi_matrix - self._phi_matrix.name = "phi_matrix" - - self._verify_dims(modes_sign, ("mode",)) - self._modes_sign = modes_sign - self._modes_sign.name = "modes_sign" - - @property - def rotation_matrix(self) -> DataArray: - """Get the rotation matrix.""" - rotation_matrix = super()._sanity_check(self._rotation_matrix) - return rotation_matrix - - @property - def phi_matrix(self) -> DataArray: - """Get the phi matrix.""" - phi_matrix = super()._sanity_check(self._phi_matrix) - return phi_matrix - - @property - def modes_sign(self) -> DataArray: - """Get the modes sign.""" - modes_sign = super()._sanity_check(self._modes_sign) - return modes_sign - - def compute(self, verbose: bool = False): - super().compute(verbose) - - if verbose: - with ProgressBar(): - self._rotation_matrix = self.rotation_matrix.compute() - self._phi_matrix = self.phi_matrix.compute() - self._modes_sign = self.modes_sign.compute() - else: - self._rotation_matrix = self.rotation_matrix.compute() - self._phi_matrix = self.phi_matrix.compute() - self._modes_sign = self.modes_sign.compute() - - def set_attrs(self, attrs: dict): - super().set_attrs(attrs) - self.rotation_matrix.attrs.update(attrs) - self.phi_matrix.attrs.update(attrs) - self.modes_sign.attrs.update(attrs) - - -class ComplexEOFRotatorDataContainer(EOFRotatorDataContainer, 
ComplexEOFDataContainer): - """Container to store the results of a complex rotated EOF analysis.""" - - def __init__(self): - super(ComplexEOFRotatorDataContainer, self).__init__() - - def set_data( - self, - input_data: DataArray, - components: DataArray, - scores: DataArray, - explained_variance: DataArray, - total_variance: DataArray, - idx_modes_sorted: DataArray, - rotation_matrix: DataArray, - phi_matrix: DataArray, - modes_sign: DataArray, - ): - super().set_data( - input_data=input_data, - components=components, - scores=scores, - explained_variance=explained_variance, - total_variance=total_variance, - idx_modes_sorted=idx_modes_sorted, - rotation_matrix=rotation_matrix, - phi_matrix=phi_matrix, - modes_sign=modes_sign, - ) diff --git a/xeofs/data_container/mca_data_container.py b/xeofs/data_container/mca_data_container.py deleted file mode 100644 index 68fe1f5..0000000 --- a/xeofs/data_container/mca_data_container.py +++ /dev/null @@ -1,227 +0,0 @@ -from typing import Optional - -import numpy as np -import xarray as xr -from dask.diagnostics.progress import ProgressBar - -from ._base_cross_model_data_container import _BaseCrossModelDataContainer -from ..utils.data_types import DataArray - - -class MCADataContainer(_BaseCrossModelDataContainer): - """Container to store the results of a MCA.""" - - def __init__(self): - super().__init__() - self._squared_covariance: Optional[DataArray] = None - self._total_squared_covariance: Optional[DataArray] = None - self._idx_modes_sorted: Optional[DataArray] = None - self._norm1: Optional[DataArray] = None - self._norm2: Optional[DataArray] = None - - def set_data( - self, - input_data1: DataArray, - input_data2: DataArray, - components1: DataArray, - components2: DataArray, - scores1: DataArray, - scores2: DataArray, - squared_covariance: DataArray, - total_squared_covariance: DataArray, - idx_modes_sorted: DataArray, - norm1: DataArray, - norm2: DataArray, - ): - super().set_data( - input_data1=input_data1, - input_data2=input_data2, - components1=components1, - components2=components2, - scores1=scores1, - scores2=scores2, - ) - - self._verify_dims(squared_covariance, ("mode",)) - self._squared_covariance = squared_covariance - self._squared_covariance.name = "squared_covariance" - - self._total_squared_covariance = total_squared_covariance - self._total_squared_covariance.name = "total_squared_covariance" - - self._verify_dims(idx_modes_sorted, ("mode",)) - self._idx_modes_sorted = idx_modes_sorted - self._idx_modes_sorted.name = "idx_modes_sorted" - - self._verify_dims(norm1, ("mode",)) - self._norm1 = norm1 - self._norm1.name = "left_norm" - - self._verify_dims(norm2, ("mode",)) - self._norm2 = norm2 - self._norm2.name = "right_norm" - - @property - def total_squared_covariance(self) -> DataArray: - """Get the total squared covariance.""" - tsc = super()._sanity_check(self._total_squared_covariance) - return tsc - - @property - def squared_covariance(self) -> DataArray: - """Get the squared covariance.""" - sc = super()._sanity_check(self._squared_covariance) - return sc - - @property - def squared_covariance_fraction(self) -> DataArray: - """Get the squared covariance fraction (SCF).""" - scf = self.squared_covariance / self.total_squared_covariance - scf.attrs.update(self.squared_covariance.attrs) - scf.name = "squared_covariance_fraction" - return scf - - @property - def norm1(self) -> DataArray: - """Get the norm of the left scores.""" - norm1 = super()._sanity_check(self._norm1) - return norm1 - - @property - def norm2(self) -> 
DataArray: - """Get the norm of the right scores.""" - norm2 = super()._sanity_check(self._norm2) - return norm2 - - @property - def idx_modes_sorted(self) -> DataArray: - """Get the indices of the modes sorted by the squared covariance.""" - idx_modes_sorted = super()._sanity_check(self._idx_modes_sorted) - return idx_modes_sorted - - @property - def singular_values(self) -> DataArray: - """Get the singular values.""" - singular_values = xr.apply_ufunc( - np.sqrt, - self.squared_covariance, - dask="allowed", - vectorize=False, - keep_attrs=True, - ) - singular_values.name = "singular_values" - return singular_values - - @property - def total_covariance(self) -> DataArray: - """Get the total covariance. - - This measure follows the defintion of Cheng and Dunkerton (1995). - Note that this measure is not an invariant in MCA. - - """ - tot_cov = self.singular_values.sum() - tot_cov.attrs.update(self.singular_values.attrs) - tot_cov.name = "total_covariance" - return tot_cov - - @property - def covariance_fraction(self) -> DataArray: - """Get the covariance fraction (CF). - - This measure follows the defintion of Cheng and Dunkerton (1995). - Note that this measure is not an invariant in MCA. - - """ - cov_frac = self.singular_values / self.total_covariance - cov_frac.attrs.update(self.singular_values.attrs) - cov_frac.name = "covariance_fraction" - return cov_frac - - def compute(self, verbose=False): - super().compute(verbose) - - if verbose: - with ProgressBar(): - self._total_squared_covariance = self.total_squared_covariance.compute() - self._squared_covariance = self.squared_covariance.compute() - self._norm1 = self.norm1.compute() - self._norm2 = self.norm2.compute() - else: - self._total_squared_covariance = self.total_squared_covariance.compute() - self._squared_covariance = self.squared_covariance.compute() - self._norm1 = self.norm1.compute() - self._norm2 = self.norm2.compute() - - def set_attrs(self, attrs: dict): - super().set_attrs(attrs) - - total_squared_covariance = super()._sanity_check(self._total_squared_covariance) - squared_covariance = super()._sanity_check(self._squared_covariance) - norm1 = super()._sanity_check(self._norm1) - norm2 = super()._sanity_check(self._norm2) - - total_squared_covariance.attrs.update(attrs) - squared_covariance.attrs.update(attrs) - norm1.attrs.update(attrs) - norm2.attrs.update(attrs) - - -class ComplexMCADataContainer(MCADataContainer): - """Container that holds the data related to a Complex MCA model.""" - - @property - def components_amplitude1(self) -> DataArray: - """Get the component amplitudes of the left field.""" - comp_amps1 = abs(self.components1) - comp_amps1.name = "left_components_amplitude" - return comp_amps1 - - @property - def components_amplitude2(self) -> DataArray: - """Get the component amplitudes of the right field.""" - comp_amps2 = abs(self.components2) - comp_amps2.name = "right_components_amplitude" - return comp_amps2 - - @property - def components_phase1(self) -> DataArray: - """Get the component phases of the left field.""" - comp_phs1 = xr.apply_ufunc(np.angle, self.components1, keep_attrs=True) - comp_phs1.name = "left_components_phase" - return comp_phs1 - - @property - def components_phase2(self) -> DataArray: - """Get the component phases of the right field.""" - comp_phs2 = xr.apply_ufunc(np.angle, self._components2, keep_attrs=True) - comp_phs2.name = "right_components_phase" - return comp_phs2 - - @property - def scores_amplitude1(self) -> DataArray: - """Get the scores amplitudes of the left field.""" 
- scores_amps1 = abs(self.scores1) - scores_amps1.name = "left_scores_amplitude" - return scores_amps1 - - @property - def scores_amplitude2(self) -> DataArray: - """Get the scores amplitudes of the right field.""" - scores_amps2 = abs(self.scores2) - scores_amps2.name = "right_scores_amplitude" - return scores_amps2 - - @property - def scores_phase1(self) -> DataArray: - """Get the scores phases of the left field.""" - scores_phs1 = xr.apply_ufunc(np.angle, self.scores1, keep_attrs=True) - scores_phs1.name = "left_scores_phase" - return scores_phs1 - - @property - def scores_phase2(self) -> DataArray: - """Get the scores phases of the right field.""" - scores_phs2 = xr.apply_ufunc(np.angle, self.scores2, keep_attrs=True) - scores_phs2.name = "right_scores_phase" - return scores_phs2 diff --git a/xeofs/data_container/mca_rotator_data_container.py b/xeofs/data_container/mca_rotator_data_container.py deleted file mode 100644 index fa051f2..0000000 --- a/xeofs/data_container/mca_rotator_data_container.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Optional - -import numpy as np -import xarray as xr -from dask.diagnostics.progress import ProgressBar - -from xeofs.utils.data_types import DataArray - -from .mca_data_container import MCADataContainer, ComplexMCADataContainer -from ..utils.data_types import DataArray - - -class MCARotatorDataContainer(MCADataContainer): - """Container that holds the data related to a rotated MCA model.""" - - def __init__(self): - super().__init__() - self._rotation_matrix: Optional[DataArray] = None - self._phi_matrix: Optional[DataArray] = None - self._modes_sign: Optional[DataArray] = None - - def set_data( - self, - input_data1: DataArray, - input_data2: DataArray, - components1: DataArray, - components2: DataArray, - scores1: DataArray, - scores2: DataArray, - squared_covariance: DataArray, - total_squared_covariance: DataArray, - idx_modes_sorted: DataArray, - modes_sign: DataArray, - norm1: DataArray, - norm2: DataArray, - rotation_matrix: DataArray, - phi_matrix: DataArray, - ): - super().set_data( - input_data1=input_data1, - input_data2=input_data2, - components1=components1, - components2=components2, - scores1=scores1, - scores2=scores2, - squared_covariance=squared_covariance, - total_squared_covariance=total_squared_covariance, - idx_modes_sorted=idx_modes_sorted, - norm1=norm1, - norm2=norm2, - ) - - self._verify_dims(rotation_matrix, ("mode_m", "mode_n")) - self._rotation_matrix = rotation_matrix - self._rotation_matrix.name = "rotation_matrix" - - self._verify_dims(phi_matrix, ("mode_m", "mode_n")) - self._phi_matrix = phi_matrix - self._phi_matrix.name = "phi_matrix" - - self._verify_dims(modes_sign, ("mode",)) - self._modes_sign = modes_sign - self._modes_sign.name = "modes_sign" - - @property - def rotation_matrix(self) -> DataArray: - """Get the rotation matrix.""" - rotation_matrix = super()._sanity_check(self._rotation_matrix) - return rotation_matrix - - @property - def phi_matrix(self) -> DataArray: - """Get the phi matrix.""" - phi_matrix = super()._sanity_check(self._phi_matrix) - return phi_matrix - - @property - def modes_sign(self) -> DataArray: - """Get the mode signs.""" - modes_sign = super()._sanity_check(self._modes_sign) - return modes_sign - - def compute(self, verbose: bool = False): - """Compute the rotated MCA model.""" - super().compute(verbose=verbose) - - if verbose: - with ProgressBar(): - self._rotation_matrix = self.rotation_matrix.compute() - self._phi_matrix = self.phi_matrix.compute() - self._modes_sign = 
self.modes_sign.compute() - else: - self._rotation_matrix = self.rotation_matrix.compute() - self._phi_matrix = self.phi_matrix.compute() - self._modes_sign = self.modes_sign.compute() - - def set_attrs(self, attrs: dict): - """Set the attributes of the data container.""" - super().set_attrs(attrs) - - rotation_matrix = super()._sanity_check(self._rotation_matrix) - phi_matrix = super()._sanity_check(self._phi_matrix) - modes_sign = super()._sanity_check(self._modes_sign) - - rotation_matrix.attrs.update(attrs) - phi_matrix.attrs.update(attrs) - modes_sign.attrs.update(attrs) - - -class ComplexMCARotatorDataContainer(MCARotatorDataContainer, ComplexMCADataContainer): - """Container that holds the data related to a rotated complex MCA model.""" - - def __init__(self): - super(ComplexMCARotatorDataContainer, self).__init__() - - def set_data( - self, - input_data1: DataArray, - input_data2: DataArray, - components1: DataArray, - components2: DataArray, - scores1: DataArray, - scores2: DataArray, - squared_covariance: DataArray, - total_squared_covariance: DataArray, - idx_modes_sorted: DataArray, - modes_sign: DataArray, - norm1: DataArray, - norm2: DataArray, - rotation_matrix: DataArray, - phi_matrix: DataArray, - ): - super().set_data( - input_data1=input_data1, - input_data2=input_data2, - components1=components1, - components2=components2, - scores1=scores1, - scores2=scores2, - squared_covariance=squared_covariance, - total_squared_covariance=total_squared_covariance, - idx_modes_sorted=idx_modes_sorted, - modes_sign=modes_sign, - norm1=norm1, - norm2=norm2, - rotation_matrix=rotation_matrix, - phi_matrix=phi_matrix, - ) diff --git a/xeofs/data_container/opa_data_container.py b/xeofs/data_container/opa_data_container.py deleted file mode 100644 index 7f6930e..0000000 --- a/xeofs/data_container/opa_data_container.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Optional - -import numpy as np -import xarray as xr -from dask.diagnostics.progress import ProgressBar - -from xeofs.utils.data_types import DataArray - -from ._base_model_data_container import _BaseModelDataContainer -from ..utils.data_types import DataArray - - -class OPADataContainer(_BaseModelDataContainer): - """Container to store the results of an Optimal Persistence Analysis (OPA).""" - - def __init__(self): - super().__init__() - self._filter_patterns: Optional[DataArray] = None - self._decorrelation_time: Optional[DataArray] = None - - def set_data( - self, - input_data: DataArray, - components: DataArray, - scores: DataArray, - filter_patterns: DataArray, - decorrelation_time: DataArray, - ): - super().set_data(input_data=input_data, components=components, scores=scores) - - self._verify_dims(decorrelation_time, ("mode",)) - self._decorrelation_time = decorrelation_time - self._decorrelation_time.name = "decorrelation_time" - - self._verify_dims(filter_patterns, ("feature", "mode")) - self._filter_patterns = filter_patterns - self._filter_patterns.name = "filter_patterns" - - @property - def components(self) -> DataArray: - comps = super().components - comps.name = "optimal_persistence_pattern" - return comps - - @property - def decorrelation_time(self) -> DataArray: - """Get the decorrelation time.""" - decorr = super()._sanity_check(self._decorrelation_time) - decorr.name = "decorrelation_time" - return decorr - - @property - def filter_patterns(self) -> DataArray: - """Get the filter patterns.""" - filter_patterns = super()._sanity_check(self._filter_patterns) - filter_patterns.name = "filter_patterns" - return 
filter_patterns - - def compute(self, verbose=False): - super().compute(verbose) - - if verbose: - with ProgressBar(): - self._filter_patterns = self.filter_patterns.compute() - self._decorrelation_time = self.decorrelation_time.compute() - else: - self._filter_patterns = self.filter_patterns.compute() - self._decorrelation_time = self.decorrelation_time.compute() - - def set_attrs(self, attrs: dict): - """Set the attributes of the results.""" - super().set_attrs(attrs) - - filter_patterns = self._sanity_check(self._filter_patterns) - decorrelation_time = self._sanity_check(self._decorrelation_time) - - filter_patterns.attrs.update(attrs) - decorrelation_time.attrs.update(attrs) diff --git a/xeofs/models/__init__.py b/xeofs/models/__init__.py index d1d954e..49ee8ec 100644 --- a/xeofs/models/__init__.py +++ b/xeofs/models/__init__.py @@ -1,6 +1,26 @@ from .eof import EOF, ComplexEOF from .mca import MCA, ComplexMCA +from .eeof import ExtendedEOF from .opa import OPA +from .gwpca import GWPCA from .rotator_factory import RotatorFactory from .eof_rotator import EOFRotator, ComplexEOFRotator from .mca_rotator import MCARotator, ComplexMCARotator +from .cca import CCA + + +__all__ = [ + "EOF", + "ComplexEOF", + "ExtendedEOF", + "EOFRotator", + "ComplexEOFRotator", + "OPA", + "GWPCA", + "MCA", + "ComplexMCA", + "MCARotator", + "ComplexMCARotator", + "CCA", + "RotatorFactory", +] diff --git a/xeofs/models/_base_cross_model.py b/xeofs/models/_base_cross_model.py index a18b8a8..5075e6c 100644 --- a/xeofs/models/_base_cross_model.py +++ b/xeofs/models/_base_cross_model.py @@ -1,13 +1,14 @@ -from typing import Tuple, Hashable, Sequence, Dict, Any, Optional +from typing import Tuple, Hashable, Sequence, Dict, Optional, List +from typing_extensions import Self from abc import ABC, abstractmethod from datetime import datetime -from dask.diagnostics.progress import ProgressBar - from .eof import EOF from ..preprocessing.preprocessor import Preprocessor -from ..data_container import _BaseCrossModelDataContainer -from ..utils.data_types import AnyDataObject, DataArray +from ..data_container import DataContainer +from ..utils.data_types import DataObject, DataArray +from ..utils.xarray_utils import convert_to_dim_type +from ..utils.sanity_checks import validate_input_type from .._version import __version__ @@ -19,14 +20,20 @@ class _BaseCrossModel(ABC): ------------- n_modes: int, default=10 Number of modes to calculate. + center: bool, default=True + Whether to center the input data. standardize: bool, default=False Whether to standardize the input data. use_coslat: bool, default=False Whether to use cosine of latitude for scaling. - use_weights: bool, default=False - Whether to use weights. n_pca_modes: int, default=None Number of PCA modes to calculate. + compute : bool, default=True + Whether to compute the decomposition immediately. + sample_name: str, default="sample" + Name of the new sample dimension. + feature_name: str, default="feature" + Name of the new feature dimension. solver: {"auto", "full", "randomized"}, default="auto" Solver to use for the SVD computation. 
solver_kwargs: dict, default={} @@ -37,27 +44,49 @@ class _BaseCrossModel(ABC): def __init__( self, n_modes=10, + center=True, standardize=False, use_coslat=False, - use_weights=False, n_pca_modes=None, + compute=True, + sample_name="sample", + feature_name="feature", solver="auto", + random_state=None, solver_kwargs={}, ): + self.n_modes = n_modes + self.sample_name = sample_name + self.feature_name = feature_name + # Define model parameters self._params = { "n_modes": n_modes, + "center": center, "standardize": standardize, "use_coslat": use_coslat, - "use_weights": use_weights, "n_pca_modes": n_pca_modes, + "compute": compute, + "sample_name": sample_name, + "feature_name": feature_name, "solver": solver, + "random_state": random_state, } + self._solver_kwargs = solver_kwargs + self._solver_kwargs.update( + {"solver": solver, "random_state": random_state, "compute": compute} + ) + self._preprocessor_kwargs = { + "sample_name": sample_name, + "feature_name": feature_name, + "with_center": center, + "with_std": standardize, + "with_coslat": use_coslat, + } # Define analysis-relevant meta data self.attrs = {"model": "BaseCrossModel"} - self.attrs.update(self._params) self.attrs.update( { "software": "xeofs", @@ -65,98 +94,156 @@ def __init__( "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } ) + self.attrs.update(self._params) # Initialize preprocessors to scale and stack left (1) and right (2) data - self.preprocessor1 = Preprocessor( - with_std=standardize, - with_coslat=use_coslat, - with_weights=use_weights, - ) - self.preprocessor2 = Preprocessor( - with_std=standardize, - with_coslat=use_coslat, - with_weights=use_weights, - ) - # Initialize the data container only to avoid type errors - # The actual data container will be initialized in respective subclasses - self.data: _BaseCrossModelDataContainer = _BaseCrossModelDataContainer() + self.preprocessor1 = Preprocessor(**self._preprocessor_kwargs) + self.preprocessor2 = Preprocessor(**self._preprocessor_kwargs) + + # Initialize the data container that stores the results + self.data = DataContainer() # Initialize PCA objects - self.pca1 = EOF(n_modes=n_pca_modes) if n_pca_modes else None - self.pca2 = EOF(n_modes=n_pca_modes) if n_pca_modes else None + self.pca1 = ( + EOF(n_modes=n_pca_modes, compute=self._params["compute"]) + if n_pca_modes + else None + ) + self.pca2 = ( + EOF(n_modes=n_pca_modes, compute=self._params["compute"]) + if n_pca_modes + else None + ) - @abstractmethod def fit( self, - data1: AnyDataObject, - data2: AnyDataObject, + data1: DataObject, + data2: DataObject, dim: Hashable | Sequence[Hashable], - weights1: Optional[AnyDataObject] = None, - weights2: Optional[AnyDataObject] = None, - ): + weights1: Optional[DataObject] = None, + weights2: Optional[DataObject] = None, + ) -> Self: """ - Abstract method to fit the model. + Fit the model to the data. Parameters ---------- - data1: DataArray | Dataset | list of DataArray + data1: DataArray | Dataset | List[DataArray] Left input data. - data2: DataArray | Dataset | list of DataArray + data2: DataArray | Dataset | List[DataArray] Right input data. dim: Hashable | Sequence[Hashable] Define the sample dimensions. The remaining dimensions will be treated as feature dimensions. - weights1: Optional[AnyDataObject] + weights1: Optional[DataObject] Weights to be applied to the left input data. - weights2: Optional[AnyDataObject]=None + weights2: Optional[DataObject] Weights to be applied to the right input data. 
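Editor's note: the refactored fit method above validates and preprocesses both inputs before delegating to _fit_algorithm. A hedged usage sketch of the resulting two-dataset workflow, assuming the refactored MCA keeps this fit signature; the toy fields and the "time" sample dimension are illustrative only.

import numpy as np
import xarray as xr
from xeofs.models import MCA

# Illustrative toy fields with (time, lat, lon) dimensions
da1 = xr.DataArray(np.random.rand(20, 5, 6), dims=("time", "lat", "lon"))
da2 = xr.DataArray(np.random.rand(20, 5, 6), dims=("time", "lat", "lon"))

model = MCA(n_modes=3)
model.fit(da1, da2, dim="time")      # "time" is treated as the sample dimension
comps1, comps2 = model.components()  # coupled patterns on the original grid
scores1, scores2 = model.scores()    # one time series per mode and dataset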
""" - # Here follows the implementation to fit the model - # Typically you want to start by calling - # self.preprocessor1.fit_transform(data1, dim, weights) - # self.preprocessor2.fit_transform(data2, dim, weights) + validate_input_type(data1) + validate_input_type(data2) + if weights1 is not None: + validate_input_type(weights1) + if weights2 is not None: + validate_input_type(weights2) + + self.sample_dims = convert_to_dim_type(dim) + # Preprocess data1 + data1 = self.preprocessor1.fit_transform(data1, self.sample_dims, weights1) + # Preprocess data2 + data2 = self.preprocessor2.fit_transform(data2, self.sample_dims, weights2) + + return self._fit_algorithm(data1, data2) + + def transform( + self, data1: Optional[DataObject] = None, data2: Optional[DataObject] = None + ) -> Sequence[DataArray]: + """ + Abstract method to transform the data. + + + """ + if data1 is None and data2 is None: + raise ValueError("Either data1 or data2 must be provided.") + + if data1 is not None: + validate_input_type(data1) + # Preprocess data1 + data1 = self.preprocessor1.transform(data1) + if data2 is not None: + validate_input_type(data2) + # Preprocess data2 + data2 = self.preprocessor2.transform(data2) + + return self._transform_algorithm(data1, data2) + + @abstractmethod + def _fit_algorithm(self, data1: DataArray, data2: DataArray) -> Self: + """ + Fit the model to the preprocessed data. This method needs to be implemented in the respective + subclass. + + Parameters + ---------- + data1, data2: DataArray + Preprocessed input data of two dimensions: (`sample_name`, `feature_name`) + + """ raise NotImplementedError @abstractmethod - def transform( - self, data1: Optional[AnyDataObject], data2: Optional[AnyDataObject] - ) -> Tuple[DataArray, DataArray]: + def _transform_algorithm( + self, data1: Optional[DataArray] = None, data2: Optional[DataArray] = None + ) -> Sequence[DataArray]: + """ + Transform the preprocessed data. This method needs to be implemented in the respective + subclass. 
+ + Parameters + ---------- + data1, data2: DataArray + Preprocessed input data of two dimensions: (`sample_name`, `feature_name`) + + """ raise NotImplementedError @abstractmethod - def inverse_transform(self, mode) -> Tuple[AnyDataObject, AnyDataObject]: + def inverse_transform(self, mode) -> Tuple[DataObject, DataObject]: raise NotImplementedError - def components(self) -> Tuple[AnyDataObject, AnyDataObject]: + def components(self) -> Tuple[DataObject, DataObject]: """Get the components.""" - comps1 = self.data.components1 - comps2 = self.data.components2 + comps1 = self.data["components1"] + comps2 = self.data["components2"] - components1: AnyDataObject = self.preprocessor1.inverse_transform_components( + components1: DataObject = self.preprocessor1.inverse_transform_components( comps1 ) - components2: AnyDataObject = self.preprocessor2.inverse_transform_components( + components2: DataObject = self.preprocessor2.inverse_transform_components( comps2 ) return components1, components2 def scores(self) -> Tuple[DataArray, DataArray]: """Get the scores.""" - scores1 = self.data.scores1 - scores2 = self.data.scores2 + scores1 = self.data["scores1"] + scores2 = self.data["scores2"] scores1: DataArray = self.preprocessor1.inverse_transform_scores(scores1) scores2: DataArray = self.preprocessor2.inverse_transform_scores(scores2) return scores1, scores2 def compute(self, verbose: bool = False): - """Compute the results.""" - if verbose: - with ProgressBar(): - self.data.compute() - else: - self.data.compute() + """Compute and load delayed model results. + + Parameters + ---------- + verbose : bool + Whether or not to provide additional information about the computing progress. + + """ + self.data.compute(verbose=verbose) def get_params(self) -> Dict: """Get the model parameters.""" diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index 8963f40..bbf0639 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -1,17 +1,31 @@ import warnings -from typing import Optional, Sequence, Hashable, Dict, Any +from typing import Optional, Sequence, Hashable, Dict, Any, List, TypeVar, Tuple +from typing_extensions import Self from abc import ABC, abstractmethod from datetime import datetime -from dask.diagnostics.progress import ProgressBar + +import numpy as np +import xarray as xr from ..preprocessing.preprocessor import Preprocessor -from ..data_container import _BaseModelDataContainer -from ..utils.data_types import AnyDataObject, DataArray +from ..data_container import DataContainer +from ..utils.data_types import DataObject, Data, DataArray, DataSet, DataList, Dims +from ..utils.sanity_checks import validate_input_type +from ..utils.xarray_utils import ( + convert_to_dim_type, + get_dims, + feature_ones_like, + convert_to_list, + process_parameter, + _check_parameter_number, +) from .._version import __version__ # Ignore warnings from numpy casting with additional coordinates warnings.filterwarnings("ignore", message=r"^invalid value encountered in cast*") +xr.set_options(keep_attrs=True) + class _BaseModel(ABC): """ @@ -21,12 +35,20 @@ class _BaseModel(ABC): ---------- n_modes: int, default=10 Number of modes to calculate. + center: bool, default=True + Whether to center the input data. standardize: bool, default=False Whether to standardize the input data. use_coslat: bool, default=False Whether to use cosine of latitude for scaling. - use_weights: bool, default=False - Whether to use weights. 
+ sample_name: str, default="sample" + Name of the sample dimension. + feature_name: str, default="feature" + Name of the feature dimension. + compute: bool, default=True + Whether to compute the decomposition immediately. This is recommended + if the SVD result for the first ``n_modes`` can be accommodated in memory, as it + boosts computational efficiency compared to deferring the computation. solver: {"auto", "full", "randomized"}, default="auto" Solver to use for the SVD computation. solver_kwargs: dict, default={} @@ -37,25 +59,39 @@ class _BaseModel(ABC): def __init__( self, n_modes=10, + center=True, standardize=False, use_coslat=False, - use_weights=False, + sample_name="sample", + feature_name="feature", + compute=True, + random_state=None, solver="auto", solver_kwargs={}, ): + self.n_modes = n_modes + self.sample_name = sample_name + self.feature_name = feature_name + # Define model parameters self._params = { "n_modes": n_modes, + "center": center, "standardize": standardize, "use_coslat": use_coslat, - "use_weights": use_weights, + "sample_name": sample_name, + "feature_name": feature_name, + "random_state": random_state, + "compute": compute, "solver": solver, } self._solver_kwargs = solver_kwargs + self._solver_kwargs.update( + {"solver": solver, "random_state": random_state, "compute": compute} + ) # Define analysis-relevant meta data self.attrs = {"model": "BaseModel"} - self.attrs.update(self._params) self.attrs.update( { "software": "xeofs", @@ -63,28 +99,31 @@ def __init__( "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } ) + self.attrs.update(self._params) # Initialize the Preprocessor to scale and stack the data self.preprocessor = Preprocessor( - with_std=standardize, with_coslat=use_coslat, with_weights=use_weights + sample_name=sample_name, + feature_name=feature_name, + with_center=center, + with_std=standardize, + with_coslat=use_coslat, ) - # Initialize the data container only to avoid type errors - # The actual data container will be initialized in respective subclasses - self.data: _BaseModelDataContainer = _BaseModelDataContainer() + # Initialize the data container that stores the results + self.data = DataContainer() - @abstractmethod def fit( self, - data: AnyDataObject, + X: List[Data] | Data, dim: Sequence[Hashable] | Hashable, - weights: Optional[AnyDataObject] = None, - ): + weights: Optional[List[Data] | Data] = None, + ) -> Self: """ Fit the model to the input data. Parameters ---------- - data: DataArray | Dataset | List[DataArray] + X: DataArray | Dataset | List[DataArray] Input data. dim: Sequence[Hashable] | Hashable Specify the sample dimensions. The remaining dimensions @@ -93,27 +132,169 @@ def fit( Weighting factors for the input data. """ - # Here follows the implementation to fit the model - # Typically you want to start by calling the Preprocessor first: - # self.preprocessor.fit_transform(data, dim, weights) + # Check for invalid types + validate_input_type(X) + if weights is not None: + validate_input_type(weights) + + self.sample_dims = convert_to_dim_type(dim) + + # Preprocess the data & transform to 2D + data2D: DataArray = self.preprocessor.fit_transform( + X, self.sample_dims, weights + ) + + return self._fit_algorithm(data2D) + + @abstractmethod + def _fit_algorithm(self, data: DataArray) -> Self: + """Fit the model to the input data assuming a 2D DataArray. + + Parameters + ---------- + data: DataArray + Input data with dimensions (sample_name, feature_name) + + Returns + ------- + self: Self + The fitted model. 
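Editor's note: the single-dataset base model follows the same template method pattern, with fit reducing the input to a 2D (sample, feature) array and _fit_algorithm doing the decomposition, while fit_transform chains fit and transform. A hedged sketch using the EOF subclass; the toy data and dimension names are illustrative.

import numpy as np
import xarray as xr
from xeofs.models import EOF

da = xr.DataArray(np.random.rand(30, 4, 5), dims=("time", "lat", "lon"))

model = EOF(n_modes=2)
scores = model.fit_transform(da, dim="time")  # equivalent to fit(...).transform(...)
components = model.components()               # EOF patterns on the original grid
recon = model.inverse_transform(mode=1)       # reconstruct the data from the first mode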
+ + """ raise NotImplementedError + def transform(self, data: List[Data] | Data, normalized=True) -> DataArray: + """Project data onto the components. + + Parameters + ---------- + data: DataArray | Dataset | List[DataArray] + Data to be transformed. + normalized: bool, default=True + Whether to normalize the scores by the L2 norm. + + Returns + ------- + projections: DataArray + Projections of the data onto the components. + + """ + validate_input_type(data) + + data2D = self.preprocessor.transform(data) + data2D = self._transform_algorithm(data2D) + if normalized: + data2D = data2D / self.data["norms"] + data2D.name = "scores" + return self.preprocessor.inverse_transform_scores(data2D) + @abstractmethod - def transform(self): + def _transform_algorithm(self, data: DataArray) -> DataArray: + """Project data onto the components. + + Parameters + ---------- + data: DataArray + Input data with dimensions (sample_name, feature_name) + + Returns + ------- + projections: DataArray + Projections of the data onto the components. + + """ raise NotImplementedError + def fit_transform( + self, + data: List[Data] | Data, + dim: Sequence[Hashable] | Hashable, + weights: Optional[List[Data] | Data] = None, + **kwargs + ) -> DataArray: + """Fit the model to the input data and project the data onto the components. + + Parameters + ---------- + data: DataObject + Input data. + dim: Sequence[Hashable] | Hashable + Specify the sample dimensions. The remaining dimensions + will be treated as feature dimensions. + weights: Optional[DataObject] + Weighting factors for the input data. + **kwargs + Additional keyword arguments to pass to the transform method. + + Returns + ------- + projections: DataArray + Projections of the data onto the components. + + """ + return self.fit(data, dim, weights).transform(data, **kwargs) + + def inverse_transform(self, mode) -> DataObject: + """Reconstruct the original data from transformed data. + + Parameters + ---------- + mode: integer, a list of integers, or a slice object. + The mode(s) used to reconstruct the data. If a scalar is given, + the data will be reconstructed using the given mode. If a slice + is given, the data will be reconstructed using the modes in the + given slice. If a list of integers is given, the data will be reconstructed + using the modes in the given list. + + Returns + ------- + data: DataArray | Dataset | List[DataArray] + Reconstructed data. + + """ + data_reconstructed = self._inverse_transform_algorithm(mode) + return self.preprocessor.inverse_transform_data(data_reconstructed) + @abstractmethod - def inverse_transform(self, mode): + def _inverse_transform_algorithm(self, mode) -> DataArray: + """Reconstruct the original data from transformed data. + + Parameters + ---------- + mode: integer, a list of integers, or a slice object. + The mode(s) used to reconstruct the data. If a scalar is given, + the data will be reconstructed using the given mode. If a slice + is given, the data will be reconstructed using the modes in the + given slice. If a list of integers is given, the data will be reconstructed + using the modes in the given list. 
+ + Returns + ------- + data: DataArray + Reconstructed 2D data with dimensions (sample_name, feature_name) + + """ raise NotImplementedError - def components(self) -> AnyDataObject: + def components(self) -> DataObject: """Get the components.""" - components = self.data.components + components = self.data["components"] return self.preprocessor.inverse_transform_components(components) - def scores(self) -> DataArray: - """Get the scores.""" - scores = self.data.scores + def scores(self, normalized=True) -> DataArray: + """Get the scores. + + Parameters + ---------- + normalized: bool, default=True + Whether to normalize the scores by the L2 norm. + """ + scores = self.data["scores"].copy() + if normalized: + attrs = scores.attrs.copy() + scores = scores / self.data["norms"] + scores.attrs.update(attrs) + scores.name = "scores" return self.preprocessor.inverse_transform_scores(scores) def compute(self, verbose: bool = False): @@ -125,11 +306,7 @@ def compute(self, verbose: bool = False): Whether or not to provide additional information about the computing progress. """ - if verbose: - with ProgressBar(): - self.data.compute() - else: - self.data.compute() + self.data.compute(verbose=verbose) def get_params(self) -> Dict[str, Any]: """Get the model parameters.""" diff --git a/xeofs/models/cca.py b/xeofs/models/cca.py index 1955887..0d3029e 100644 --- a/xeofs/models/cca.py +++ b/xeofs/models/cca.py @@ -1,743 +1,701 @@ -# from typing import Tuple - -# import numpy as np -# import xarray as xr - -# from ._base_cross_model import _BaseCrossModel -# from .decomposer import CrossDecomposer -# from ..utils.data_types import AnyDataObject, DataArray -# from ..data_container.mca_data_container import ( -# MCADataContainer, -# ComplexMCADataContainer, -# ) -# from ..utils.statistics import pearson_correlation -# from ..utils.xarray_utils import hilbert_transform - - -# class MCA(_BaseCrossModel): -# """Maximum Covariance Analyis (MCA). - -# MCA is a statistical method that finds patterns of maximum covariance between two datasets. - -# Parameters -# ---------- -# n_modes: int, default=10 -# Number of modes to calculate. -# standardize: bool, default=False -# Whether to standardize the input data. -# use_coslat: bool, default=False -# Whether to use cosine of latitude for scaling. -# use_weights: bool, default=False -# Whether to use additional weights. -# solver_kwargs: dict, default={} -# Additional keyword arguments passed to the SVD solver. -# n_pca_modes: int, default=None -# The number of principal components to retain during the PCA preprocessing -# step applied to both data sets prior to executing MCA. -# If set to None, PCA preprocessing will be bypassed, and the MCA will be performed on the original datasets. -# Specifying an integer value greater than 0 for `n_pca_modes` will trigger the PCA preprocessing, retaining -# only the specified number of principal components. This reduction in dimensionality can be especially beneficial -# when dealing with high-dimensional data, where computing the cross-covariance matrix can become computationally -# intensive or in scenarios where multicollinearity is a concern. - -# Notes -# ----- -# MCA is similar to Principal Component Analysis (PCA) and Canonical Correlation Analysis (CCA), -# but while PCA finds modes of maximum variance and CCA finds modes of maximum correlation, -# MCA finds modes of maximum covariance. See [1]_ [2]_ for more details. - -# References -# ---------- -# .. [1] Bretherton, C., Smith, C., Wallace, J., 1992. 
An intercomparison of methods for finding coupled patterns in climate data. Journal of climate 5, 541–560. -# .. [2] Cherry, S., 1996. Singular value decomposition analysis and canonical correlation analysis. Journal of Climate 9, 2003–2009. - -# Examples -# -------- -# >>> model = MCA(n_modes=5, standardize=True) -# >>> model.fit(data1, data2) - -# """ - -# def __init__(self, solver_kwargs={}, **kwargs): -# super().__init__(solver_kwargs=solver_kwargs, **kwargs) -# self.attrs.update({"model": "MCA"}) - -# # Initialize the DataContainer to store the results -# self.data: MCADataContainer = MCADataContainer() - -# def fit( -# self, -# data1: AnyDataObject, -# data2: AnyDataObject, -# dim, -# weights1=None, -# weights2=None, -# ): -# # Preprocess the data -# data1_processed: DataArray = self.preprocessor1.fit_transform( -# data1, dim, weights1 -# ) -# data2_processed: DataArray = self.preprocessor2.fit_transform( -# data2, dim, weights2 -# ) - -# # Perform the decomposition of the cross-covariance matrix -# decomposer = CrossDecomposer( -# n_modes=self._params["n_modes"], **self._solver_kwargs -# ) - -# # Perform SVD on PCA-reduced data -# if (self.pca1 is not None) and (self.pca2 is not None): -# # Fit the PCA models -# self.pca1.fit(data1_processed, "sample") -# self.pca2.fit(data2_processed, "sample") -# # Get the PCA scores -# pca_scores1 = self.pca1.data.scores * self.pca1.data.singular_values -# pca_scores2 = self.pca2.data.scores * self.pca2.data.singular_values -# # Rename the dimensions to adhere to the CrossDecomposer API -# pca_scores1 = pca_scores1.rename({"mode": "feature"}) -# pca_scores2 = pca_scores2.rename({"mode": "feature"}) - -# # Perform the SVD -# decomposer.fit(pca_scores1, pca_scores2) -# V1 = decomposer.singular_vectors1_.rename({"feature": "core"}) -# V2 = decomposer.singular_vectors2_.rename({"feature": "core"}) - -# V1pre = self.pca1.data.components.rename({"mode": "core"}) -# V2pre = self.pca2.data.components.rename({"mode": "core"}) - -# # Compute the singular vectors based on the PCA eigenvectors -# singular_vectors1 = xr.dot(V1pre, V1, dims="core") -# singular_vectors2 = xr.dot(V2pre, V2, dims="core") - -# # Perform SVD directly on data -# else: -# decomposer.fit(data1_processed, data2_processed) -# singular_vectors1 = decomposer.singular_vectors1_ -# singular_vectors2 = decomposer.singular_vectors2_ - -# # Store the results -# singular_values = decomposer.singular_values_ - -# squared_covariance = singular_values**2 -# total_squared_covariance = decomposer.total_squared_covariance_ - -# norm1 = np.sqrt(singular_values) -# norm2 = np.sqrt(singular_values) - -# # Index of the sorted squared covariance -# idx_sorted_modes = squared_covariance.compute().argsort()[::-1] -# idx_sorted_modes.coords.update(squared_covariance.coords) - -# # Project the data onto the singular vectors -# scores1 = xr.dot(data1_processed, singular_vectors1, dims="feature") / norm1 -# scores2 = xr.dot(data2_processed, singular_vectors2, dims="feature") / norm2 - -# self.data.set_data( -# input_data1=data1_processed, -# input_data2=data2_processed, -# components1=singular_vectors1, -# components2=singular_vectors2, -# scores1=scores1, -# scores2=scores2, -# squared_covariance=squared_covariance, -# total_squared_covariance=total_squared_covariance, -# idx_modes_sorted=idx_sorted_modes, -# norm1=norm1, -# norm2=norm2, -# ) -# # Assign analysis-relevant meta data -# self.data.set_attrs(self.attrs) - -# def transform(self, **kwargs): -# """Project new unseen data onto the singular vectors. 
- -# Parameters -# ---------- -# data1: xr.DataArray or list of xarray.DataArray -# Left input data. Must be provided if `data2` is not provided. -# data2: xr.DataArray or list of xarray.DataArray -# Right input data. Must be provided if `data1` is not provided. - -# Returns -# ------- -# scores1: DataArray | Dataset | List[DataArray] -# Left scores. -# scores2: DataArray | Dataset | List[DataArray] -# Right scores. - -# """ -# results = [] -# if "data1" in kwargs.keys(): -# # Preprocess input data -# data1 = kwargs["data1"] -# data1 = self.preprocessor1.transform(data1) -# # Project data onto singular vectors -# comps1 = self.data.components1 -# norm1 = self.data.norm1 -# scores1 = xr.dot(data1, comps1) / norm1 -# # Inverse transform scores -# scores1 = self.preprocessor1.inverse_transform_scores(scores1) -# results.append(scores1) - -# if "data2" in kwargs.keys(): -# # Preprocess input data -# data2 = kwargs["data2"] -# data2 = self.preprocessor2.transform(data2) -# # Project data onto singular vectors -# comps2 = self.data.components2 -# norm2 = self.data.norm2 -# scores2 = xr.dot(data2, comps2) / norm2 -# # Inverse transform scores -# scores2 = self.preprocessor2.inverse_transform_scores(scores2) -# results.append(scores2) - -# return results - -# def inverse_transform(self, mode): -# """Reconstruct the original data from transformed data. - -# Parameters -# ---------- -# mode: scalars, slices or array of tick labels. -# The mode(s) used to reconstruct the data. If a scalar is given, -# the data will be reconstructed using the given mode. If a slice -# is given, the data will be reconstructed using the modes in the -# given slice. If a array is given, the data will be reconstructed -# using the modes in the given array. - -# Returns -# ------- -# Xrec1: DataArray | Dataset | List[DataArray] -# Reconstructed data of left field. -# Xrec2: DataArray | Dataset | List[DataArray] -# Reconstructed data of right field. - -# """ -# # Singular vectors -# comps1 = self.data.components1.sel(mode=mode) -# comps2 = self.data.components2.sel(mode=mode) - -# # Scores = projections -# scores1 = self.data.scores1.sel(mode=mode) -# scores2 = self.data.scores2.sel(mode=mode) - -# # Norms -# norm1 = self.data.norm1.sel(mode=mode) -# norm2 = self.data.norm2.sel(mode=mode) - -# # Reconstruct the data -# data1 = xr.dot(scores1, comps1.conj() * norm1, dims="mode") -# data2 = xr.dot(scores2, comps2.conj() * norm2, dims="mode") - -# # Enforce real output -# data1 = data1.real -# data2 = data2.real - -# # Unstack and rescale the data -# data1 = self.preprocessor1.inverse_transform_data(data1) -# data2 = self.preprocessor2.inverse_transform_data(data2) - -# return data1, data2 - -# def squared_covariance(self): -# """Get the squared covariance. - -# The squared covariance corresponds to the explained variance in PCA and is given by the -# squared singular values of the covariance matrix. - -# """ -# return self.data.squared_covariance - -# def squared_covariance_fraction(self): -# """Calculate the squared covariance fraction (SCF). - -# The SCF is a measure of the proportion of the total squared covariance that is explained by each mode `i`. It is computed -# as follows: - -# .. math:: -# SCF_i = \\frac{\\sigma_i^2}{\\sum_{i=1}^{m} \\sigma_i^2} - -# where `m` is the total number of modes and :math:`\\sigma_i` is the `ith` singular value of the covariance matrix. 
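Editor's note: the squared covariance fraction and covariance fraction documented here reduce to simple ratios of the singular values of the cross-covariance matrix. A minimal numerical sketch with made-up singular values, not model output.

import numpy as np

# Illustrative singular values of a cross-covariance matrix
svals = np.array([4.0, 2.0, 1.0])

scf = svals**2 / np.sum(svals**2)  # SCF_i = sigma_i^2 / sum_j sigma_j^2
cf = svals / np.sum(svals)         # CF_i  = sigma_i / sum_j sigma_j (Cheng & Dunkerton, 1995)

print(scf)  # -> approx. [0.762, 0.190, 0.048]
print(cf)   # -> approx. [0.571, 0.286, 0.143]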
- -# """ -# return self.data.squared_covariance_fraction - -# def singular_values(self): -# """Get the singular values of the cross-covariance matrix.""" -# return self.data.singular_values - -# def covariance_fraction(self): -# """Get the covariance fraction (CF). - -# Cheng and Dunkerton (1995) define the CF as follows: - -# .. math:: -# CF_i = \\frac{\\sigma_i}{\\sum_{i=1}^{m} \\sigma_i} - -# where `m` is the total number of modes and :math:`\\sigma_i` is the `ith` singular value of the covariance matrix. - -# In this implementation the sum of singular values is estimated from the first `n` modes, therefore one should aim to -# retain as many modes as possible to get a good estimate of the covariance fraction. - -# Note -# ---- -# It is important to differentiate the CF from the squared covariance fraction (SCF). While the SCF is an invariant quantity in MCA, the CF is not. -# Therefore, the SCF is used to assess the relative importance of each mode. Cheng and Dunkerton (1995) [1]_ introduced the CF in the context of -# Varimax-rotated MCA to compare the relative importance of each mode before and after rotation. In the special case of both data fields in MCA being identical, -# the CF is equivalent to the explained variance ratio in EOF analysis. - -# References -# ---------- -# .. [1] Cheng, X., Dunkerton, T.J., 1995. Orthogonal Rotation of Spatial Patterns Derived from Singular Value Decomposition Analysis. J. Climate 8, 2631–2643. https://doi.org/10.1175/1520-0442(1995)008<2631:OROSPD>2.0.CO;2 - - -# """ -# # Check how sensitive the CF is to the number of modes -# svals = self.data.singular_values -# cf = svals[0] / svals.cumsum() -# change_per_mode = cf.shift({"mode": 1}) - cf -# change_in_cf_in_last_mode = change_per_mode.isel(mode=-1) -# if change_in_cf_in_last_mode > 0.001: -# print( -# f"Warning: CF is sensitive to the number of modes retained. Please increase `n_modes` for a better estimate." -# ) -# return self.data.covariance_fraction - -# def components(self): -# """Return the singular vectors of the left and right field. - -# Returns -# ------- -# components1: DataArray | Dataset | List[DataArray] -# Left components of the fitted model. -# components2: DataArray | Dataset | List[DataArray] -# Right components of the fitted model. - -# """ -# return super().components() - -# def scores(self): -# """Return the scores of the left and right field. - -# The scores in MCA are the projection of the left and right field onto the -# left and right singular vector of the cross-covariance matrix. - -# Returns -# ------- -# scores1: DataArray -# Left scores. -# scores2: DataArray -# Right scores. - -# """ -# return super().scores() - -# def homogeneous_patterns(self, correction=None, alpha=0.05): -# """Return the homogeneous patterns of the left and right field. - -# The homogeneous patterns are the correlation coefficients between the -# input data and the scores. - -# More precisely, the homogeneous patterns `r_{hom}` are defined as - -# .. math:: -# r_{hom, x} = corr \\left(X, A_x \\right) -# .. math:: -# r_{hom, y} = corr \\left(Y, A_y \\right) - -# where :math:`X` and :math:`Y` are the input data, :math:`A_x` and :math:`A_y` -# are the scores of the left and right field, respectively. - -# Parameters -# ---------- -# correction: str, default=None -# Method to apply a multiple testing correction. If None, no correction -# is applied. 
Available methods are: -# - bonferroni : one-step correction -# - sidak : one-step correction -# - holm-sidak : step down method using Sidak adjustments -# - holm : step-down method using Bonferroni adjustments -# - simes-hochberg : step-up method (independent) -# - hommel : closed method based on Simes tests (non-negative) -# - fdr_bh : Benjamini/Hochberg (non-negative) (default) -# - fdr_by : Benjamini/Yekutieli (negative) -# - fdr_tsbh : two stage fdr correction (non-negative) -# - fdr_tsbky : two stage fdr correction (non-negative) -# alpha: float, default=0.05 -# The desired family-wise error rate. Not used if `correction` is None. - -# Returns -# ------- -# patterns1: DataArray | Dataset | List[DataArray] -# Left homogenous patterns. -# patterns2: DataArray | Dataset | List[DataArray] -# Right homogenous patterns. -# pvals1: DataArray | Dataset | List[DataArray] -# Left p-values. -# pvals2: DataArray | Dataset | List[DataArray] -# Right p-values. - -# """ -# input_data1 = self.data.input_data1 -# input_data2 = self.data.input_data2 - -# scores1 = self.data.scores1 -# scores2 = self.data.scores2 - -# hom_pat1, pvals1 = pearson_correlation( -# input_data1, scores1, correction=correction, alpha=alpha -# ) -# hom_pat2, pvals2 = pearson_correlation( -# input_data2, scores2, correction=correction, alpha=alpha -# ) - -# hom_pat1 = self.preprocessor1.inverse_transform_components(hom_pat1) -# hom_pat2 = self.preprocessor2.inverse_transform_components(hom_pat2) - -# pvals1 = self.preprocessor1.inverse_transform_components(pvals1) -# pvals2 = self.preprocessor2.inverse_transform_components(pvals2) - -# hom_pat1.name = "left_homogeneous_patterns" -# hom_pat2.name = "right_homogeneous_patterns" - -# pvals1.name = "pvalues_of_left_homogeneous_patterns" -# pvals2.name = "pvalues_of_right_homogeneous_patterns" - -# return (hom_pat1, hom_pat2), (pvals1, pvals2) - -# def heterogeneous_patterns(self, correction=None, alpha=0.05): -# """Return the heterogeneous patterns of the left and right field. - -# The heterogeneous patterns are the correlation coefficients between the -# input data and the scores of the other field. - -# More precisely, the heterogeneous patterns `r_{het}` are defined as - -# .. math:: -# r_{het, x} = corr \\left(X, A_y \\right) -# .. math:: -# r_{het, y} = corr \\left(Y, A_x \\right) - -# where :math:`X` and :math:`Y` are the input data, :math:`A_x` and :math:`A_y` -# are the scores of the left and right field, respectively. - -# Parameters -# ---------- -# correction: str, default=None -# Method to apply a multiple testing correction. If None, no correction -# is applied. Available methods are: -# - bonferroni : one-step correction -# - sidak : one-step correction -# - holm-sidak : step down method using Sidak adjustments -# - holm : step-down method using Bonferroni adjustments -# - simes-hochberg : step-up method (independent) -# - hommel : closed method based on Simes tests (non-negative) -# - fdr_bh : Benjamini/Hochberg (non-negative) (default) -# - fdr_by : Benjamini/Yekutieli (negative) -# - fdr_tsbh : two stage fdr correction (non-negative) -# - fdr_tsbky : two stage fdr correction (non-negative) -# alpha: float, default=0.05 -# The desired family-wise error rate. Not used if `correction` is None. 
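Editor's note: conceptually, the homogeneous patterns described above are per-feature Pearson correlations between a field and its own scores, which xarray's built-in corr can express outside of xeofs. A hedged sketch with toy arrays and no multiple-testing correction; all names are illustrative.

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
field = xr.DataArray(rng.standard_normal((50, 10)), dims=("sample", "feature"))
scores = xr.DataArray(rng.standard_normal((50, 3)), dims=("sample", "mode"))

# r_hom = corr(X, A) over the sample dimension: one value per (feature, mode) pair
hom_patterns = xr.corr(field, scores, dim="sample")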
- -# """ -# input_data1 = self.data.input_data1 -# input_data2 = self.data.input_data2 - -# scores1 = self.data.scores1 -# scores2 = self.data.scores2 - -# patterns1, pvals1 = pearson_correlation( -# input_data1, scores2, correction=correction, alpha=alpha -# ) -# patterns2, pvals2 = pearson_correlation( -# input_data2, scores1, correction=correction, alpha=alpha -# ) - -# patterns1 = self.preprocessor1.inverse_transform_components(patterns1) -# patterns2 = self.preprocessor2.inverse_transform_components(patterns2) - -# pvals1 = self.preprocessor1.inverse_transform_components(pvals1) -# pvals2 = self.preprocessor2.inverse_transform_components(pvals2) - -# patterns1.name = "left_heterogeneous_patterns" -# patterns2.name = "right_heterogeneous_patterns" - -# pvals1.name = "pvalues_of_left_heterogeneous_patterns" -# pvals2.name = "pvalues_of_right_heterogeneous_patterns" - -# return (patterns1, patterns2), (pvals1, pvals2) - - -# class ComplexMCA(MCA): -# """Complex Maximum Covariance Analysis (MCA). - -# Complex MCA, also referred to as Analytical SVD (ASVD) by Shane et al. (2017)[1]_, -# enhances traditional MCA by accommodating both amplitude and phase information. -# It achieves this by utilizing the Hilbert transform to preprocess the data, -# thus allowing for a more comprehensive analysis in the subsequent MCA computation. - -# An optional padding with exponentially decaying values can be applied prior to -# the Hilbert transform in order to mitigate the impact of spectral leakage. - -# Parameters -# ---------- -# n_modes: int, default=10 -# Number of modes to calculate. -# standardize: bool, default=False -# Whether to standardize the input data. -# use_coslat: bool, default=False -# Whether to use cosine of latitude for scaling. -# use_weights: bool, default=False -# Whether to use additional weights. -# padding : str, optional -# Specifies the method used for padding the data prior to applying the Hilbert -# transform. This can help to mitigate the effect of spectral leakage. -# Currently, only 'exp' for exponential padding is supported. Default is 'exp'. -# decay_factor : float, optional -# Specifies the decay factor used in the exponential padding. This parameter -# is only used if padding='exp'. The recommended value typically ranges between 0.05 to 0.2 -# but ultimately depends on the variability in the data. -# A smaller value (e.g. 0.05) is recommended for -# data with high variability, while a larger value (e.g. 0.2) is recommended -# for data with low variability. Default is 0.2. -# solver_kwargs: dict, default={} -# Additional keyword arguments passed to the SVD solver. - -# Notes -# ----- -# Complex MCA extends MCA to complex-valued data that contain both magnitude and phase information. -# The Hilbert transform is used to transform real-valued data to complex-valued data, from which both -# amplitude and phase can be extracted. - -# Similar to MCA, Complex MCA is used in climate science to identify coupled patterns of variability -# between two different climate variables. But unlike MCA, Complex MCA can identify coupled patterns -# that involve phase shifts. - -# References -# ---------- -# [1]_: Elipot, S., Frajka-Williams, E., Hughes, C.W., Olhede, S., Lankhorst, M., 2017. Observed Basin-Scale Response of the North Atlantic Meridional Overturning Circulation to Wind Stress Forcing. Journal of Climate 30, 2029–2054. 
https://doi.org/10.1175/JCLI-D-16-0664.1 - - -# Examples -# -------- -# >>> model = ComplexMCA(n_modes=5, standardize=True) -# >>> model.fit(data1, data2) - -# """ - -# def __init__(self, padding="exp", decay_factor=0.2, **kwargs): -# super().__init__(**kwargs) -# self.attrs.update({"model": "Complex MCA"}) -# self._params.update({"padding": padding, "decay_factor": decay_factor}) - -# # Initialize the DataContainer to store the results -# self.data: ComplexMCADataContainer = ComplexMCADataContainer() - -# def fit( -# self, -# data1: AnyDataObject, -# data2: AnyDataObject, -# dim, -# weights1=None, -# weights2=None, -# ): -# """Fit the model. - -# Parameters -# ---------- -# data1: xr.DataArray or list of xarray.DataArray -# Left input data. -# data2: xr.DataArray or list of xarray.DataArray -# Right input data. -# dim: tuple -# Tuple specifying the sample dimensions. The remaining dimensions -# will be treated as feature dimensions. -# weights1: xr.DataArray or xr.Dataset or None, default=None -# If specified, the left input data will be weighted by this array. -# weights2: xr.DataArray or xr.Dataset or None, default=None -# If specified, the right input data will be weighted by this array. - -# """ - -# data1_processed: DataArray = self.preprocessor1.fit_transform( -# data1, dim, weights2 -# ) -# data2_processed: DataArray = self.preprocessor2.fit_transform( -# data2, dim, weights2 -# ) - -# # apply hilbert transform: -# padding = self._params["padding"] -# decay_factor = self._params["decay_factor"] -# data1_processed = hilbert_transform( -# data1_processed, dim="sample", padding=padding, decay_factor=decay_factor -# ) -# data2_processed = hilbert_transform( -# data2_processed, dim="sample", padding=padding, decay_factor=decay_factor -# ) - -# decomposer = CrossDecomposer( -# n_modes=self._params["n_modes"], **self._solver_kwargs -# ) -# decomposer.fit(data1_processed, data2_processed) - -# # Note: -# # - explained variance is given by the singular values of the SVD; -# # - We use the term singular_values_pca as used in the context of PCA: -# # Considering data X1 = X2, MCA is the same as PCA. In this case, -# # singular_values_pca is equivalent to the singular values obtained -# # when performing PCA of X1 or X2. -# singular_values = decomposer.singular_values_ -# singular_vectors1 = decomposer.singular_vectors1_ -# singular_vectors2 = decomposer.singular_vectors2_ - -# squared_covariance = singular_values**2 -# total_squared_covariance = decomposer.total_squared_covariance_ - -# norm1 = np.sqrt(singular_values) -# norm2 = np.sqrt(singular_values) - -# # Index of the sorted squared covariance -# idx_sorted_modes = squared_covariance.compute().argsort()[::-1] -# idx_sorted_modes.coords.update(squared_covariance.coords) - -# # Project the data onto the singular vectors -# scores1 = xr.dot(data1_processed, singular_vectors1) / norm1 -# scores2 = xr.dot(data2_processed, singular_vectors2) / norm2 - -# self.data.set_data( -# input_data1=data1_processed, -# input_data2=data2_processed, -# components1=singular_vectors1, -# components2=singular_vectors2, -# scores1=scores1, -# scores2=scores2, -# squared_covariance=squared_covariance, -# total_squared_covariance=total_squared_covariance, -# idx_modes_sorted=idx_sorted_modes, -# norm1=norm1, -# norm2=norm2, -# ) -# # Assign analysis relevant meta data -# self.data.set_attrs(self.attrs) - -# def components_amplitude(self) -> Tuple[AnyDataObject, AnyDataObject]: -# """Compute the amplitude of the components. 
- -# The amplitude of the components are defined as - -# .. math:: -# A_ij = |C_ij| - -# where :math:`C_{ij}` is the :math:`i`-th entry of the :math:`j`-th component and -# :math:`|\\cdot|` denotes the absolute value. - -# Returns -# ------- -# AnyDataObject -# Amplitude of the left components. -# AnyDataObject -# Amplitude of the left components. - -# """ -# comps1 = self.data.components_amplitude1 -# comps2 = self.data.components_amplitude2 - -# comps1 = self.preprocessor1.inverse_transform_components(comps1) -# comps2 = self.preprocessor2.inverse_transform_components(comps2) - -# return (comps1, comps2) - -# def components_phase(self) -> Tuple[AnyDataObject, AnyDataObject]: -# """Compute the phase of the components. - -# The phase of the components are defined as - -# .. math:: -# \\phi_{ij} = \\arg(C_{ij}) - -# where :math:`C_{ij}` is the :math:`i`-th entry of the :math:`j`-th component and -# :math:`\\arg(\\cdot)` denotes the argument of a complex number. - -# Returns -# ------- -# AnyDataObject -# Phase of the left components. -# AnyDataObject -# Phase of the right components. - -# """ -# comps1 = self.data.components_phase1 -# comps2 = self.data.components_phase2 - -# comps1 = self.preprocessor1.inverse_transform_components(comps1) -# comps2 = self.preprocessor2.inverse_transform_components(comps2) - -# return (comps1, comps2) - -# def scores_amplitude(self) -> Tuple[DataArray, DataArray]: -# """Compute the amplitude of the scores. - -# The amplitude of the scores are defined as - -# .. math:: -# A_ij = |S_ij| - -# where :math:`S_{ij}` is the :math:`i`-th entry of the :math:`j`-th score and -# :math:`|\\cdot|` denotes the absolute value. - -# Returns -# ------- -# DataArray -# Amplitude of the left scores. -# DataArray -# Amplitude of the right scores. - -# """ -# scores1 = self.data.scores_amplitude1 -# scores2 = self.data.scores_amplitude2 - -# scores1 = self.preprocessor1.inverse_transform_scores(scores1) -# scores2 = self.preprocessor2.inverse_transform_scores(scores2) -# return (scores1, scores2) - -# def scores_phase(self) -> Tuple[DataArray, DataArray]: -# """Compute the phase of the scores. - -# The phase of the scores are defined as - -# .. math:: -# \\phi_{ij} = \\arg(S_{ij}) - -# where :math:`S_{ij}` is the :math:`i`-th entry of the :math:`j`-th score and -# :math:`\\arg(\\cdot)` denotes the argument of a complex number. - -# Returns -# ------- -# DataArray -# Phase of the left scores. -# DataArray -# Phase of the right scores. - -# """ -# scores1 = self.data.scores_phase1 -# scores2 = self.data.scores_phase2 - -# scores1 = self.preprocessor1.inverse_transform_scores(scores1) -# scores2 = self.preprocessor2.inverse_transform_scores(scores2) - -# return (scores1, scores2) - -# def transform(self, data1: AnyDataObject, data2: AnyDataObject): -# raise NotImplementedError("Complex MCA does not support transform method.") - -# def homogeneous_patterns(self, correction=None, alpha=0.05): -# raise NotImplementedError( -# "Complex MCA does not support homogeneous_patterns method." -# ) - -# def heterogeneous_patterns(self, correction=None, alpha=0.05): -# raise NotImplementedError( -# "Complex MCA does not support heterogeneous_patterns method." -# ) +""" +This code is based on the work of James Chapman from cca-zoo. +Source: https://github.com/jameschapman19/cca_zoo + +The original code is licensed under the MIT License. 
+ +Copyright (c) 2020-2023 James Chapman +""" + +from abc import abstractmethod +from datetime import datetime +from typing import Sequence, List, Hashable +from typing_extensions import Self + +import dask.array as da +import numpy as np +import xarray as xr +from scipy.linalg import eigh +from sklearn.base import BaseEstimator +from sklearn.utils.validation import FLOAT_DTYPES +from xeofs.models import EOF + +from .._version import __version__ +from ..preprocessing.preprocessor import Preprocessor +from ..utils.data_types import DataObject, DataArray, DataList + + +def _check_parameter_number(parameter_name: str, parameter, n_views: int): + if len(parameter) != n_views: + raise ValueError( + f"number of views passed should match number of parameter {parameter_name}" + f"len(views)={n_views} and " + f"len({parameter_name})={len(parameter)}" + ) + + +def _process_parameter(parameter_name: str, parameter, default, n_views: int): + if parameter is None: + parameter = [default] * n_views + elif not isinstance(parameter, (list, tuple)): + parameter = [parameter] * n_views + _check_parameter_number(parameter_name, parameter, n_views) + return parameter + + +class CCABaseModel(BaseEstimator): + def __init__( + self, + n_modes: int = 10, + use_coslat: bool = False, + pca: bool = False, + variance_fraction: float = 0.99, + init_pca_modes: int | float = 0.75, + compute: bool = True, + sample_name: str = "sample", + feature_name: str = "feature", + ): + self.sample_name = sample_name + self.feature_name = feature_name + self.n_modes = n_modes + self.use_coslat = use_coslat + self.pca = pca + self.compute = compute + self.variance_fraction = variance_fraction + self.init_pca_modes = init_pca_modes + + self.dtypes = FLOAT_DTYPES + + self._preprocessor_kwargs = { + "sample_name": sample_name, + "feature_name": feature_name, + "with_std": False, + } + + # Define analysis-relevant meta data + self.attrs = {"model": "BaseCrossModel"} + self.attrs.update( + { + "software": "xeofs", + "version": __version__, + "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + } + ) + + # Initialize the data container only to avoid type errors + # The actual data container will be initialized in respective subclasses + # self.data: _BaseCrossModelDataContainer = _BaseCrossModelDataContainer() + self.data = {} + + def _validate_data(self, views: Sequence[DataArray]): + if not all( + data[self.sample_name].size == views[0][self.sample_name].size + for data in views + ): + raise ValueError("All views must have the same number of samples") + if not all(data.ndim == 2 for data in views): + raise ValueError("All views must have 2 dimensions") + if not all(data.dtype in self.dtypes for data in views): + raise ValueError("All views must have dtype of {}.".format(self.dtypes)) + if not all(data[self.feature_name].size >= self.n_modes for data in views): + raise ValueError( + "All views must have at least {} features.".format(self.n_modes) + ) + + def _process_init_pca_modes(self, n_modes): + err_msg = "init_pca_modes must be either a float <= 1.0 or an integer > 1" + n_modes_list = [] + n_modes_max = [ + min(self.n_samples_, n_features) for n_features in self.n_features_ + ] + for n, n_max in zip(n_modes, n_modes_max): + if isinstance(n, float): + if n > 1.0: + raise ValueError(err_msg) + n = int(n * n_max) + n_modes_list.append(n) + elif isinstance(n, int): + if n <= 1: + raise ValueError(err_msg) + n_modes_list.append(n) + else: + raise ValueError(err_msg) + return n_modes_list + + def fit( + self, + views: 
Sequence[DataObject], + dim: Hashable | Sequence[Hashable], + ) -> Self: + self.n_views_ = len(views) + self.use_coslat = _process_parameter( + "use_coslat", self.use_coslat, False, self.n_views_ + ) + self.init_pca_modes = _process_parameter( + "init_pca_modes", self.init_pca_modes, 0.75, self.n_views_ + ) + + # Preprocess the input data + self.preprocessors = [ + Preprocessor(with_coslat=self.use_coslat[i], **self._preprocessor_kwargs) + for i in range(self.n_views_) + ] + views2D: List[DataArray] = [ + preprocessor.fit_transform(data, dim) + for preprocessor, data in zip(self.preprocessors, views) + ] + self._validate_data(views2D) + self.n_features_ = [data.coords[self.feature_name].size for data in views2D] + self.n_samples_ = views2D[0][self.sample_name].size + + self.data["input_data"] = views2D + views2D = self._process_data(views2D) + self.data["pca_data"] = views2D + + self._fit_algorithm(views2D) + + return self + + def _process_data(self, views: DataList) -> DataList: + if self.pca: + views = self._apply_pca(views) + return views + + def _apply_pca(self, views: DataList): + self.pca_models = [] + + n_pca_modes = self._process_init_pca_modes(self.init_pca_modes) + + view_transformed = [] + + for i, view in enumerate(views): + pca = EOF(n_modes=n_pca_modes[i], compute=self.compute) + pca.fit(view, dim=self.sample_name) + if self.compute: + pca.compute() + self.pca_models.append(pca) + + # TODO: method to get cumulative explained variance + cum_exp_var_ratio = pca.explained_variance_ratio().cumsum() + # Ensure that the sum of the explained variance ratio is always less than 1 + # Due to rounding errors the total sum may be slightly larger than 1, + # which we counter by a small correction + cum_exp_var_ratio -= 1e-6 + max_exp_var_ratio = cum_exp_var_ratio.isel(mode=-1).item() + if ( + max_exp_var_ratio <= self.variance_fraction + and max_exp_var_ratio <= 0.9999 + ): + print( + "Warning: variance fraction {:.4f} is not reached. ".format( + self.variance_fraction + ) + + "Only {:.4f} of variance is explained.".format( + cum_exp_var_ratio.isel(mode=-1).item() + ) + ) + n_modes_keep = cum_exp_var_ratio.where( + cum_exp_var_ratio <= self.variance_fraction, drop=True + ).size + if n_modes_keep == 0: + n_modes_keep += 1 + + # TODO: it's more convinient to work the common scaling of sklearn; provide additional parameter + # provide this parameter to transform method as well + scores = pca.scores().isel(mode=slice(0, n_modes_keep)) + svals = pca.singular_values().isel(mode=slice(0, n_modes_keep)) + scores = ( + (scores * svals) + .rename({"mode": self.feature_name}) + .transpose(self.sample_name, self.feature_name) + ) + view_transformed.append(scores) + return view_transformed + + @abstractmethod + def _fit_algorithm(self, views: List[DataArray]) -> Self: + raise NotImplementedError + + +class CCA(CCABaseModel): + r"""Canonical Correlation Analysis (CCA) model. + + Regularised CCA (canonical ridge) model. + + CCA identifies linear combinations of variables from multiple datasets that + maximize their mutual correlations. An optional regularisation parameter can be used to + improve the conditioning of the covariance matrix. + + The objective function of (regularised) CCA is: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\ + + \text{subject to:} + + (1-c_1)w_1^TX_1^TX_1w_1+c_1w_1^Tw_1=n + + (1-c_2)w_2^TX_2^TX_2w_2+c_2w_2^Tw_2=n + + where :math:`c_i` are the regularization parameters for dataset. 
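To make the objective above concrete: it is equivalent to the generalized eigenvalue problem C w = lambda D w that the `_solve_gevp`, `_C` and `_D` helpers assemble further below, with C holding the between-view covariances and D the ridge-regularised within-view covariances. The following standalone NumPy/SciPy sketch is illustrative only and not part of xeofs; the arrays X1, X2 and the ridge values c1, c2 are made up.

# Illustrative only: a two-view canonical-ridge eigenproblem in plain NumPy/SciPy.
import numpy as np
from scipy.linalg import eigh

rng = np.random.default_rng(42)
X1 = rng.standard_normal((100, 5))  # view 1: 100 samples x 5 features (toy data)
X2 = rng.standard_normal((100, 4))  # view 2: 100 samples x 4 features (toy data)
c1, c2 = 0.1, 0.1                   # per-view regularisation parameters

C11 = np.cov(X1, rowvar=False)
C22 = np.cov(X2, rowvar=False)
C12 = np.cov(X1, X2, rowvar=False)[:5, 5:]  # cross-covariance block

# C: between-view covariances only; D: ridge-regularised within-view covariances
C = np.block([[np.zeros((5, 5)), C12], [C12.T, np.zeros((4, 4))]])
D = np.block([
    [(1 - c1) * C11 + c1 * np.eye(5), np.zeros((5, 4))],
    [np.zeros((4, 5)), (1 - c2) * C22 + c2 * np.eye(4)],
])

# Leading generalized eigenvectors of C w = lambda D w give the stacked canonical weights
eigvals, eigvecs = eigh(C, D)
order = np.argsort(eigvals)[::-1][:2]            # keep the 2 leading modes
w1, w2 = eigvecs[:5, order], eigvecs[5:, order]  # split weights back into the two views

The class below carries out the same steps on xarray objects (optionally in PCA space).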
+ + Parameters + ---------- + n_modes : int, optional + Number of latent dimensions to use, by default 2 + use_coslat : bool, optional + Whether to use the square root of the cosine of the latitude as weights, by default False + pca : bool, optional + Whether to perform PCA on the input data, by default True + variance_fraction : float, optional + Fraction of variance to keep when performing PCA, by default 0.99 + init_pca_modes : int | float, optional + Number of PCA modes to compute. If float, the number of modes is given by the fraction of maximum number of modes for the given data. + A value of 1.0 will perform a full SVD of the data. Choosing a smaller value can increase computation speed. Default 0.75 + c : Sequence[float] | float, optional + Regularisation parameter, by default 0 (no regularization) + compute : bool, optional + Whether to compute the decomposition immediately, by default True + + + Notes + ----- + This implementation is largely based on the MCCA class from the cca_zoo repository [3]_. + + + References + ---------- + .. [1] Vinod, Hrishikesh D. "Canonical ridge and econometrics of joint production." Journal of Econometrics 4.2 (1976): 147-166. + .. [2] Hotelling, Harold. "Relations between two sets of variates." Breakthroughs in Statistics. Springer, New York, NY, 1992. 162-190. + .. [3] Chapman et al. (2021). CCA-Zoo: A collection of Regularized, Deep Learning based, Kernel, and Probabilistic CCA methods in a scikit-learn style framework. Journal of Open Source Software, 6(68), 3823. + + Examples + -------- + >>> from xe.models import CCA + >>> model = CCA(n_modes=5) + >>> model.fit([data1, data2], dim="time") + >>> can_loadings = model.components() + + """ + + def __init__( + self, + n_modes: int = 2, + use_coslat: bool = False, + c: float = 0, + pca: bool = True, + variance_fraction: float = 0.99, + init_pca_modes: float = 0.75, + compute: bool = True, + eps: float = 1e-6, + ): + super().__init__( + n_modes=n_modes, + use_coslat=use_coslat, + pca=pca, + compute=compute, + variance_fraction=variance_fraction, + init_pca_modes=init_pca_modes, + ) + self.attrs.update({"model": "CCA"}) + self.c = c + self.eps = eps + + def _fit_algorithm(self, views: List[DataArray]) -> Self: + self.c = _process_parameter("c", self.c, 0, self.n_views_) + eigvals, eigvecs = self._solve_gevp(views) + self.eigvals = eigvals + self.eigvecs = eigvecs + # Compute the weights for each view + self._weights(eigvals, eigvecs, views) + # Compute loadings (= normalized weights) + self.data["loadings"] = [ + wght / self._apply_norm(wght, [self.feature_name]) + for wght in self.data["weights"] + ] + canonical_variates = self._transform(self.data["input_data"]) + self.data["variates"] = canonical_variates + + self.data["canonical_loadings"] = [ + xr.dot(data, vari, dims=self.sample_name, optimize=True) + for data, vari in zip(self.data["input_data"], canonical_variates) + ] + + # Compute explained variance + # Transform the views using the loadings + transformed_views = [ + xr.dot(view, loading, dims=self.feature_name) + for view, loading in zip(views, self.data["loadings"]) + ] + # Calculate the variance of each latent dimension in the transformed views + self.data["explained_variance"] = [ + transformed.var(self.sample_name) for transformed in transformed_views + ] + + # Explained variance ratio + self.data["total_variance"] = [ + view.var(self.sample_name).sum() for view in views + ] + + # Calculate the explained variance ratio for each latent dimension for each view + self.data["explained_variance_ratio"]
= [ + exp_var / total_var + for exp_var, total_var in zip( + self.data["explained_variance"], self.data["total_variance"] + ) + ] + + # Explained Covariance + k = self.n_modes + explained_covariance = [] + + # just take the kth column of each transformed view and _compute_covariance + for i in range(k): + transformed_views_k = [ + view.isel(mode=slice(i, i + 1)) for view in transformed_views + ] + cov_ = self._apply_compute_covariance( + transformed_views_k, dims_in=["sample", "mode"] + ) + svals = self._compute_singular_values(cov_, dims_in=["mode1", "mode2"]) + explained_covariance.append(svals.isel(mode=0).item()) + self.data["explained_covariance"] = xr.DataArray( + explained_covariance, dims=["mode"], coords={"mode": range(1, k + 1)} + ) + + minimum_dimension = min([view[self.feature_name].size for view in views]) + + cov = self._apply_compute_covariance(views, dims_in=["sample", "feature"]) + S = self._compute_singular_values(cov, dims_in=["feature1", "feature2"]) + # select every other element starting from the first until the minimum dimension + self.data["total_explained_covariance"] = ( + S.isel(mode=slice(0, None, 2)).isel(mode=slice(0, minimum_dimension)).sum() + ) + self.data["explained_covariance_ratio"] = ( + self.data["explained_covariance"] / self.data["total_explained_covariance"] + ) + + return self + + def _compute_singular_values( + self, x, dims_in=["feature1", "feature2"], dims_out=["mode"] + ): + svals = xr.apply_ufunc( + np.linalg.svd, + x, + input_core_dims=[dims_in], + output_core_dims=[dims_out], + kwargs={"compute_uv": False}, + vectorize=False, + dask="allowed", + ) + svals = svals.assign_coords({"mode": range(1, svals.mode.size + 1)}) + return svals + + def _apply_norm(self, x, dims): + return xr.apply_ufunc( + np.linalg.norm, + x, + input_core_dims=[dims], + output_core_dims=[[]], + kwargs={"axis": -1}, + vectorize=True, + dask="allowed", + ) + + def _solve_gevp(self, views: Sequence[DataArray], y=None, **kwargs): + # Setup the eigenvalue problem + C = self._C(views, dims_in=[self.sample_name, self.feature_name]) + D = self._D(views, **kwargs) + self.splits = np.cumsum([view.shape[1] for view in views]) + # Solve the eigenvalue problem + # Get the dimension of _C + p = C.shape[0] + subset_by_index = [p - self.n_modes, p - 1] + # Solve the generalized eigenvalue problem Cx=lambda Dx using a subset of eigenvalues and eigenvectors + [eigvals, eigvecs] = self._apply_eigh(C, D, subset_by_index=subset_by_index) + # Sort the eigenvalues and eigenvectors in descending order + idx_sorted_modes = eigvals.compute().argsort()[::-1] + idx_sorted_modes = idx_sorted_modes.assign_coords( + {"mode": range(idx_sorted_modes.mode.size)} + ) + eigvals = eigvals.isel(mode=idx_sorted_modes) + eigvecs = eigvecs.isel(mode=idx_sorted_modes).real + # Set coordiantes + coords_mode = range(1, eigvals.mode.size + 1) + coords_feature = C.coords[self.feature_name + "1"].values + eigvals = eigvals.assign_coords({"mode": coords_mode}) + eigvecs = eigvecs.assign_coords( + { + "mode": coords_mode, + self.feature_name: coords_feature, + } + ) + return eigvals, eigvecs + + def _weights(self, eigvals, eigvecs, views, **kwargs): + # split eigvecs into weights for each view + # add 0 before the np ndarray splits + idx = np.concatenate([[0], self.splits]) + self.data["weights"] = [ + eigvecs.isel({self.feature_name: slice(idx[i], idx[i + 1])}) + for i in range(len(idx) - 1) + ] + if self.pca: + # go from weights in PCA space to weights in original space + n_modes = [data.feature.size for data in 
self.data["pca_data"]] + self.data["weights"] = [ + xr.dot( + pca.components() + .isel(mode=slice(0, n_modes[i])) + .rename({"mode": "temp_dim"}), + self.data["weights"][i].rename({"feature": "temp_dim"}), + dims="temp_dim", + optimize=True, + ) + for i, pca in enumerate(self.pca_models) + ] + + def _apply_eigh(self, a, b, subset_by_index): + return xr.apply_ufunc( + eigh, + a, + b, + input_core_dims=[ + [self.feature_name + "1", self.feature_name + "2"], + [self.feature_name + "1", self.feature_name + "2"], + ], + output_core_dims=[["mode"], ["feature", "mode"]], + kwargs={"subset_by_index": subset_by_index}, + vectorize=False, + dask="allowed", + ) + + def _C(self, views, dims_in): + C = self._apply_compute_covariance(views, dims_in=dims_in) + return C / len(views) + + def _apply_compute_covariance( + self, views: Sequence[DataArray], dims_in, dims_out=None + ) -> DataArray: + if dims_out is None: + dims_out = [dims_in[1] + "1", dims_in[1] + "2"] + all_views = xr.concat(views, dim=dims_in[1]) + C = self._apply_cov(all_views, dims_in=dims_in, dims_out=dims_out) + Ci = [ + self._apply_cov(view, dims_in=dims_in, dims_out=dims_out) for view in views + ] + return C - self._block_diag_dask(Ci, dims_in=dims_out) + + def _apply_cov( + self, x, dims_in=["sample", "feature"], dims_out=["feature1", "feature2"] + ): + if x[dims_in[1]].size == 1: + return xr.apply_ufunc( + np.cov, + x, + input_core_dims=[dims_in], + output_core_dims=[[]], + kwargs={"rowvar": False}, + vectorize=False, + dask="allowed", + ) + else: + C = xr.apply_ufunc( + np.cov, + x, + input_core_dims=[dims_in], + output_core_dims=[dims_out], + kwargs={"rowvar": False}, + vectorize=False, + dask="allowed", + ) + feature_coords = x.coords[dims_in[1]].values + C = C.assign_coords( + {dims_out[0]: feature_coords, dims_out[1]: feature_coords} + ) + return C + + def _block_diag_dask(self, views, dims_in=["feature1", "featur2"], dims_out=None): + if dims_out is None: + dims_out = dims_in + if all(view.size == 1 for view in views): + result = da.diag(np.array([view.item() for view in views])) + else: + # Extract underlying Dask arrays + arrays = [da.asarray(view.data) for view in views] + + # Construct a block-diagonal dask array + blocks = [ + [ + darr2 if j == i else da.zeros((darr2.shape[0], darr1.shape[0])) + for j, darr1 in enumerate(views) + ] + for i, darr2 in enumerate(arrays) + ] + + # Use Dask's block to stack the arrays + blocked_array = da.block(blocks) + + # Convert the result back to a DataArray + feature_coords = xr.concat(views, dim=dims_in[0])[dims_in[0]].values + result = xr.DataArray( + blocked_array, + dims=dims_out, + coords={dims_out[0]: feature_coords, dims_out[1]: feature_coords}, + ) + if any(isinstance(view.data, da.Array) for view in views): + return result + else: + return result.compute() + + def _D(self, views): + if self.pca: + blocks = [] + for i, view in enumerate(views): + pc = self.pca_models[i] + feature_coords = view.coords[self.feature_name] + n_features = feature_coords.size + expvar = pc.explained_variance().isel(mode=slice(0, n_features)) + block = xr.DataArray( + da.diag((1 - self.c[i]) * expvar.data + self.c[i]), + dims=[self.feature_name + "1", self.feature_name + "2"], + coords={ + self.feature_name + "1": feature_coords.values, + self.feature_name + "2": feature_coords.values, + }, + ) + block = block.compute() + blocks.append(block) + + else: + blocks = [self._apply_E(view, c) for view, c in zip(views, self.c)] + + D = self._block_diag_dask(blocks, dims_in=["feature1", "feature2"]) + + 
D_smallest_eig = self._apply_smallest_eigval(D, dims=["feature1", "feature2"]) + D_smallest_eig = D_smallest_eig - self.eps + identity_matrix = xr.DataArray(np.eye(D.shape[0]), dims=D.dims, coords=D.coords) + D = D - D_smallest_eig * identity_matrix + return D / len(views) + + def _apply_E(self, view, c): + E = xr.apply_ufunc( + self._E, + view, + input_core_dims=[[self.sample_name, self.feature_name]], + output_core_dims=[[self.feature_name + "1", self.feature_name + "2"]], + kwargs={"c": c}, + vectorize=False, + dask="allowed", + ) + feature_coords = view.coords[self.feature_name].values + E = E.assign_coords( + { + self.feature_name + "1": feature_coords, + self.feature_name + "2": feature_coords, + } + ) + return E + + def _E(self, view, c): + return (1 - c) * np.cov(view, rowvar=False) + c * np.eye(view.shape[1]) + + def _apply_smallest_eigval(self, D, dims): + return xr.apply_ufunc( + self._smallest_eigval, + D, + input_core_dims=[dims], + output_core_dims=[[]], + vectorize=True, + dask="allowed", + ) + + def _smallest_eigval(self, D): + return min(0, np.linalg.eigvalsh(D).min()) + + def weights(self) -> List[DataObject]: + weights = [ + prep.inverse_transform_components(wghts) + for prep, wghts in zip(self.preprocessors, self.data["weights"]) + ] + return weights + + def _transform(self, views: Sequence[DataArray]) -> List[DataArray]: + transformed_views = [] + for i, view in enumerate(views): + transformed_view = xr.dot(view, self.data["weights"][i], dims=self.feature_name) + transformed_views.append(transformed_view) + return transformed_views + + def transform(self, views: Sequence[DataObject]) -> List[DataArray]: + """Transform the input data into the canonical space. + + Parameters + ---------- + views : List[DataArray | Dataset] + Input data to transform. + + """ + view_preprocessed = [] + for i, view in enumerate(views): + view_preprocessed.append(self.preprocessors[i].transform(view)) + + transformed_views = self._transform(view_preprocessed) + + unstacked_transformed_views = [] + for i, view in enumerate(transformed_views): + unstacked_view = self.preprocessors[i].inverse_transform_scores(view) + unstacked_transformed_views.append(unstacked_view) + return unstacked_transformed_views + + def components(self, normalize: bool = True) -> List[DataObject]: + """Get the canonical loadings for each view.""" + can_loads = self.data["canonical_loadings"] + input_data = self.data["input_data"] + variates = self.data["variates"] + + if normalize: + # Compute correlations + loadings = [ + ( + loads + / data[self.sample_name].size + / data.std(self.sample_name) + / vari.std(self.sample_name) + ).clip(-1, 1) + for loads, data, vari in zip(can_loads, input_data, variates) + ] + else: + loadings = can_loads + + loadings = [ + prep.inverse_transform_components(load) + for prep, load in zip(self.preprocessors, loadings) + ] + return loadings + + def scores(self) -> List[DataArray]: + """Get the canonical variates for each view.""" + variates = [] + for i, view in enumerate(self.data["variates"]): + vari = self.preprocessors[i].inverse_transform_scores(view) + variates.append(vari) + return variates + + def explained_variance(self) -> List[DataArray]: + """Get the explained variance for each view.""" + return self.data["explained_variance"] + + def explained_variance_ratio(self) -> List[DataArray]: + """Get the explained variance ratio for each view.""" + return self.data["explained_variance_ratio"] + + def explained_covariance(self) -> DataArray: + """Get the explained covariance.""" + return
self.data["explained_covariance"] + + def explained_covariance_ratio(self) -> DataArray: + """Get the explained covariance ratio.""" + return self.data["explained_covariance_ratio"] diff --git a/xeofs/models/decomposer.py b/xeofs/models/decomposer.py index 3a0df3b..2416308 100644 --- a/xeofs/models/decomposer.py +++ b/xeofs/models/decomposer.py @@ -1,10 +1,12 @@ import numpy as np import xarray as xr from dask.array import Array as DaskArray # type: ignore +from dask.diagnostics.progress import ProgressBar from numpy.linalg import svd from sklearn.utils.extmath import randomized_svd from scipy.sparse.linalg import svds as complex_svd # type: ignore from dask.array.linalg import svd_compressed as dask_svd +from typing import Optional class Decomposer: @@ -18,20 +20,36 @@ class Decomposer: ---------- n_modes : int Number of components to be computed. + flip_signs : bool, default=True + Whether to flip the sign of the components to ensure deterministic output. + compute : bool, default=True + Whether to compute the decomposition immediately. solver : {'auto', 'full', 'randomized'}, default='auto' The solver is selected by a default policy based on size of `X` and `n_modes`: if the input data is larger than 500x500 and the number of modes to extract is lower than 80% of the smallest dimension of the data, then the more efficient `randomized` method is enabled. Otherwise the exact full SVD is computed and optionally truncated afterwards. + random_state : Optional[int], default=None + Seed for the random number generator. **kwargs Additional keyword arguments passed to the SVD solver. """ - def __init__(self, n_modes=100, flip_signs=True, solver="auto", **kwargs): + def __init__( + self, + n_modes: int, + flip_signs: bool = True, + compute: bool = True, + solver: str = "auto", + random_state: Optional[int] = None, + **kwargs, + ): self.n_modes = n_modes self.flip_signs = flip_signs + self.compute = compute self.solver = solver + self.random_state = random_state self.solver_kwargs = kwargs def fit(self, X, dims=("sample", "feature")): @@ -65,54 +83,34 @@ def fit(self, X, dims=("sample", "feature")): is_small_data = max(n_coords1, n_coords2) < 500 - if self.solver == "auto": - use_exact = ( - True if is_small_data and self.n_modes > int(0.8 * rank) else False - ) - elif self.solver == "full": - use_exact = True - elif self.solver == "randomized": - use_exact = False - else: - raise ValueError( - f"Unrecognized solver '{self.solver}'. " - "Valid options are 'auto', 'full', and 'randomized'." - ) + match self.solver: + case "auto": + use_exact = ( + True if is_small_data and self.n_modes > int(0.8 * rank) else False + ) + case "full": + use_exact = True + case "randomized": + use_exact = False + case _: + raise ValueError( + f"Unrecognized solver '{self.solver}'. " + "Valid options are 'auto', 'full', and 'randomized'." 
+ ) # Use exact SVD for small data sets if use_exact: - U, s, VT = xr.apply_ufunc( - np.linalg.svd, - X, - kwargs=self.solver_kwargs, - input_core_dims=[dims], - output_core_dims=[ - [dims[0], "mode"], - ["mode"], - ["mode", dims[1]], - ], - dask="allowed", - vectorize=False, - ) + U, s, VT = self._svd(X, dims, np.linalg.svd, self.solver_kwargs) U = U[:, : self.n_modes] s = s[: self.n_modes] VT = VT[: self.n_modes, :] # Use randomized SVD for large, real-valued data sets elif (not use_complex) and (not use_dask): - self.solver_kwargs.update({"n_components": self.n_modes}) - - U, s, VT = xr.apply_ufunc( - randomized_svd, - X, - kwargs=self.solver_kwargs, - input_core_dims=[dims], - output_core_dims=[ - [dims[0], "mode"], - ["mode"], - ["mode", dims[1]], - ], + self.solver_kwargs.update( + {"n_components": self.n_modes, "random_state": self.random_state} ) + U, s, VT = self._svd(X, dims, randomized_svd, self.solver_kwargs) # Use scipy sparse SVD for large, complex-valued data sets elif use_complex and (not use_dask): @@ -121,19 +119,10 @@ def fit(self, X, dims=("sample", "feature")): { "k": self.n_modes, "solver": "lobpcg", + "random_state": self.random_state, } ) - U, s, VT = xr.apply_ufunc( - complex_svd, - X, - kwargs=self.solver_kwargs, - input_core_dims=[dims], - output_core_dims=[ - [dims[0], "mode"], - ["mode"], - ["mode", dims[1]], - ], - ) + U, s, VT = self._svd(X, dims, complex_svd, self.solver_kwargs) idx_sort = np.argsort(s)[::-1] U = U[:, idx_sort] s = s[idx_sort] @@ -141,19 +130,9 @@ def fit(self, X, dims=("sample", "feature")): # Use dask SVD for large, real-valued, delayed data sets elif (not use_complex) and use_dask: - self.solver_kwargs.update({"k": self.n_modes}) - U, s, VT = xr.apply_ufunc( - dask_svd, - X, - kwargs=self.solver_kwargs, - input_core_dims=[dims], - output_core_dims=[ - [dims[0], "mode"], - ["mode"], - ["mode", dims[1]], - ], - dask="allowed", - ) + self.solver_kwargs.update({"k": self.n_modes, "seed": self.random_state}) + U, s, VT = self._svd(X, dims, dask_svd, self.solver_kwargs) + U, s, VT = self._compute_svd_result(U, s, VT) else: err_msg = ( "Complex data together with dask is currently not implemented. See dask issue 7639 " @@ -184,3 +163,75 @@ def fit(self, X, dims=("sample", "feature")): self.U_ = U self.s_ = s self.V_ = VT.conj().transpose(dims[1], "mode") + + def _svd(self, X, dims, func, kwargs): + """Performs SVD on the data + + Parameters + ---------- + X : DataArray + A 2-dimensional data object to be decomposed. + dims : tuple of str + Dimensions of the data object. + func : Callable + Method to perform SVD. + kwargs : dict + Additional keyword arguments passed to the SVD solver. + + Returns + ------- + U : DataArray + Left singular vectors. + s : DataArray + Singular values. + VT : DataArray + Right singular vectors. + """ + try: + U, s, VT = xr.apply_ufunc( + func, + X, + kwargs=kwargs, + input_core_dims=[dims], + output_core_dims=[ + [dims[0], "mode"], + ["mode"], + ["mode", dims[1]], + ], + dask="allowed", + ) + return U, s, VT + except ValueError: + raise ValueError( + "SVD failed. This may be due to isolated NaN values in the data. Please consider the following steps:\n" + "1. Check for and remove any isolated NaNs in your dataset.\n" + "2. If the error persists, please raise an issue at https://github.com/nicrie/xeofs/issues." + ) + + def _compute_svd_result(self, U, s, VT): + """Computes the SVD result. + + Parameters + ---------- + U : DataArray + Left singular vectors. + s : DataArray + Singular values. 
+ VT : DataArray + Right singular vectors. + + Returns + ------- + U : DataArray + Left singular vectors. + s : DataArray + Singular values. + VT : DataArray + Right singular vectors. + """ + if self.compute: + with ProgressBar(): + U = U.compute() + s = s.compute() + VT = VT.compute() + return U, s, VT diff --git a/xeofs/models/eeof.py b/xeofs/models/eeof.py new file mode 100644 index 0000000..82a8654 --- /dev/null +++ b/xeofs/models/eeof.py @@ -0,0 +1,167 @@ +from typing import Optional + +import numpy as np +import xarray as xr + +from ._base_model import _BaseModel +from .eof import EOF +from .decomposer import Decomposer +from ..utils.data_types import DataArray, Data, Dims +from ..data_container import DataContainer +from ..utils.xarray_utils import total_variance as compute_total_variance + + +class ExtendedEOF(EOF): + """Extended EOF analysis. + + Extended EOF (EEOF) analysis [1]_ [2]_, often referred to as + Multivariate/Multichannel Singular Spectrum Analysis, enhances + traditional EOF analysis by identifying propagating signals or + oscillations in multivariate datasets. This approach integrates the + spatial correlation of EOFs with the temporal auto- and cross-correlation + derived from the lagged covariance matrix. + + Parameters + ---------- + n_modes : int + Number of modes to be computed. + tau : int + Time delay used to construct a time-delayed version of the original time series. + embedding : int + Embedding dimension is the number of dimensions in the delay-coordinate space used to represent + the dynamics of the system. It determines the number of delayed copies + of the time series that are used to construct the delay-coordinate space. + n_pca_modes : Optional[int] + If provided, the input data is first preprocessed using PCA with the + specified number of modes. The EEOF analysis is then performed on the + resulting PCA scores. This approach can lead to important computational + savings. + **kwargs : + Additional keyword arguments passed to the EOF model. + + References + ---------- + .. [1] Weare, B. C. & Nasstrom, J. S. Examples of Extended Empirical Orthogonal Function Analyses. Monthly Weather Review 110, 481–485 (1982). + .. [2] Broomhead, D. S. & King, G. P. Extracting qualitative dynamics from experimental data. Physica D: Nonlinear Phenomena 20, 217–236 (1986). 
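As a rough sketch of the delay-coordinate construction described above (illustrative only, not the ExtendedEOF API; the toy series x and the dimension name "sample" are made up), the embedded data can be built by stacking lagged copies of the series, mirroring the `_fit_algorithm` further below.

# Toy delay-coordinate embedding with xarray (illustrative only).
import numpy as np
import xarray as xr

x = xr.DataArray(np.sin(np.linspace(0, 20, 200)), dims="sample")
tau, embedding = 2, 4                # lag step and number of delayed copies
shifts = np.arange(embedding) * tau  # lags 0, 2, 4, 6

# Stack lagged copies of the series along a new "embedding" dimension
lagged = xr.concat([x.shift(sample=-s) for s in shifts], dim="embedding")
lagged = lagged.isel(sample=slice(None, -(embedding - 1) * tau))  # drop the NaN tail
lagged = lagged.assign_coords(embedding=shifts)
# An ordinary EOF/PCA of `lagged` over "sample" then yields the extended EOFs.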
+ + + Examples + -------- + >>> from xeofs.models import EEOF + >>> model = EEOF(n_modes=5, tau=1, embedding=20, n_pca_modes=20) + >>> model.fit(data, dim=("time")) + + Retrieve the extended empirical orthogonal functions (EEOFs) and their explained variance: + + >>> eeofs = model.components() + >>> exp_var = model.explained_variance() + + Retrieve the time-dependent coefficients corresponding to the EEOF modes: + + >>> scores = model.scores() + """ + + def __init__( + self, + n_modes: int, + tau: int, + embedding: int, + n_pca_modes: Optional[int] = None, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + sample_name: str = "sample", + feature_name: str = "feature", + compute: bool = True, + solver: str = "auto", + random_state: Optional[int] = None, + solver_kwargs: dict = {}, + ): + super().__init__( + n_modes=n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + sample_name=sample_name, + feature_name=feature_name, + compute=compute, + solver=solver, + random_state=random_state, + solver_kwargs=solver_kwargs, + ) + self.attrs.update({"model": "Extended EOF Analysis"}) + self._params.update( + {"tau": tau, "embedding": embedding, "n_pca_modes": n_pca_modes} + ) + + # Initialize the DataContainer to store the results + self.data = DataContainer() + self.pca = ( + EOF( + n_modes=n_pca_modes, + center=True, + standardize=False, + use_coslat=False, + compute=self._params["compute"], + sample_name=self.sample_name, + feature_name=self.feature_name, + ) + if n_pca_modes + else None + ) + + def _fit_algorithm(self, X: DataArray): + self.data.add(X.copy(), "input_data", allow_compute=False) + + # Preprocess the data using PCA + if self.pca: + self.pca.fit(X, dim=self.sample_name) + X = self.pca.data["scores"] + X = X.rename({"mode": self.feature_name}) + + # Construct the time-delayed version of the original time series + tau = self._params["tau"] + embedding = self._params["embedding"] + shift = np.arange(embedding) * tau + X_extended = [] + for i in shift: + X_extended.append(X.shift(sample=-i)) + X_extended = xr.concat(X_extended, dim="embedding") + n_samples_cut = (embedding - 1) * tau + X_extended = X_extended.isel(sample=slice(None, -n_samples_cut)) + X_extended.coords.update({"embedding": shift}) + + # Perform standard PCA on extended data + n_modes = self._params["n_modes"] + model = EOF( + n_modes=n_modes, + center=True, + standardize=False, + use_coslat=False, + compute=self._params["compute"], + sample_name=self.sample_name, + feature_name=self.feature_name, + solver=self._params["solver"], + solver_kwargs=self._solver_kwargs, + ) + model.fit(X_extended, dim=self.sample_name) + + self.model = model + self.data = model.data + self.data["components"] = model.components() + self.data["scores"] = model.scores(normalized=False) + + if self.pca: + self.data["components"] = xr.dot( + self.pca.data["components"].rename({"mode": "temp"}), + self.data["components"].rename({"feature": "temp"}), + dims="temp", + ) + + self.data.set_attrs(self.attrs) + + def _transform_algorithm(self, X): + raise NotImplementedError("EEOF does currently not support transform") + + def _inverse_transform_algorithm(self, X): + raise NotImplementedError("EEOF does currently not support inverse transform") diff --git a/xeofs/models/eof.py b/xeofs/models/eof.py index ad29064..027094a 100644 --- a/xeofs/models/eof.py +++ b/xeofs/models/eof.py @@ -1,28 +1,38 @@ +from typing import Optional, Dict +from typing_extensions import Self +import numpy as np import xarray 
as xr from ._base_model import _BaseModel from .decomposer import Decomposer -from ..utils.data_types import AnyDataObject, DataArray -from ..data_container import EOFDataContainer, ComplexEOFDataContainer -from ..utils.xarray_utils import hilbert_transform +from ..utils.data_types import DataObject, DataArray, Dims +from ..utils.hilbert_transform import hilbert_transform from ..utils.xarray_utils import total_variance as compute_total_variance class EOF(_BaseModel): """Empirical Orthogonal Functions (EOF) analysis. - EOF analysis is more commonly referend to as principal component analysis (PCA). + More commonly known as Principal Component Analysis (PCA). Parameters ---------- n_modes: int, default=10 Number of modes to calculate. + center: bool, default=True + Whether to center the input data. standardize: bool, default=False Whether to standardize the input data. use_coslat: bool, default=False Whether to use cosine of latitude for scaling. - use_weights: bool, default=False - Whether to use weights. + sample_name: str, default="sample" + Name of the sample dimension. + feature_name: str, default="feature" + Name of the feature dimension. + compute: bool, default=True + Whether to compute the decomposition immediately. This is recommended + if the SVD result for the first ``n_modes`` can be accommodated in memory, as it + boosts computational efficiency compared to deferring the computation. solver: {"auto", "full", "randomized"}, default="auto" Solver to use for the SVD computation. solver_kwargs: dict, default={} @@ -38,94 +48,79 @@ class EOF(_BaseModel): def __init__( self, - n_modes=10, - standardize=False, - use_coslat=False, - use_weights=False, - solver="auto", - solver_kwargs={}, + n_modes: int = 2, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + sample_name: str = "sample", + feature_name: str = "feature", + compute: bool = True, + random_state: Optional[int] = None, + solver: str = "auto", + solver_kwargs: Dict = {}, + **kwargs, ): super().__init__( n_modes=n_modes, + center=center, standardize=standardize, use_coslat=use_coslat, - use_weights=use_weights, + sample_name=sample_name, + feature_name=feature_name, + compute=compute, + random_state=random_state, solver=solver, solver_kwargs=solver_kwargs, + **kwargs, ) self.attrs.update({"model": "EOF analysis"}) - # Initialize the DataContainer to store the results - self.data: EOFDataContainer = EOFDataContainer() - - def fit(self, data: AnyDataObject, dim, weights=None): - # Preprocess the data - input_data: DataArray = self.preprocessor.fit_transform(data, dim, weights) + def _fit_algorithm(self, data: DataArray) -> Self: + sample_name = self.sample_name + feature_name = self.feature_name # Compute the total variance - total_variance = compute_total_variance(input_data, dim="sample") + total_variance = compute_total_variance(data, dim=sample_name) # Decompose the data n_modes = self._params["n_modes"] - decomposer = Decomposer( - n_modes=n_modes, solver=self._params["solver"], **self._solver_kwargs - ) - decomposer.fit(input_data, dims=("sample", "feature")) + decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs) + decomposer.fit(data, dims=(sample_name, feature_name)) singular_values = decomposer.s_ components = decomposer.V_ - scores = decomposer.U_ - - # Compute the explained variance - explained_variance = singular_values**2 / (input_data.sample.size - 1) - - # Index of the sorted explained variance - # It's already sorted, we just need to assign it to the DataContainer - # for 
the sake of consistency - idx_modes_sorted = explained_variance.compute().argsort()[::-1] - idx_modes_sorted.coords.update(explained_variance.coords) - - # Assign the results to the data container - self.data.set_data( - input_data=input_data, - components=components, - scores=scores, - explained_variance=explained_variance, - total_variance=total_variance, - idx_modes_sorted=idx_modes_sorted, - ) - self.data.set_attrs(self.attrs) - - def transform(self, data: AnyDataObject) -> DataArray: - """Project new unseen data onto the components (EOFs/eigenvectors). + scores = decomposer.U_ * decomposer.s_ + scores.name = "scores" + + # Compute the explained variance per mode + n_samples = data.coords[self.sample_name].size + exp_var = singular_values**2 / (n_samples - 1) + exp_var.name = "explained_variance" + + # Store the results + self.data.add(data, "input_data", allow_compute=False) + self.data.add(components, "components") + self.data.add(scores, "scores") + self.data.add(singular_values, "norms") + self.data.add(exp_var, "explained_variance") + self.data.add(total_variance, "total_variance") - Parameters - ---------- - data: AnyDataObject - Data to be transformed. - - Returns - ------- - projections: DataArray - Projections of the new data onto the components. + self.data.set_attrs(self.attrs) + return self - """ - # Preprocess the data - data_stacked: DataArray = self.preprocessor.transform(data) + def _transform_algorithm(self, data: DataObject) -> DataArray: + feature_name = self.preprocessor.feature_name - components = self.data.components - singular_values = self.data.singular_values + components = self.data["components"] # Project the data - projections = xr.dot(data_stacked, components, dims="feature") / singular_values + projections = xr.dot(data, components, dims=feature_name) projections.name = "scores" - # Unstack the projections - projections = self.preprocessor.inverse_transform_scores(projections) return projections - def inverse_transform(self, mode) -> AnyDataObject: + def _inverse_transform_algorithm(self, mode) -> DataArray: """Reconstruct the original data from transformed data. Parameters @@ -144,22 +139,17 @@ def inverse_transform(self, mode) -> AnyDataObject: """ # Reconstruct the data - svals = self.data.singular_values.sel(mode=mode) - comps = self.data.components.sel(mode=mode) - scores = self.data.scores.sel(mode=mode) * svals + comps = self.data["components"].sel(mode=mode) + scores = self.data["scores"].sel(mode=mode) reconstructed_data = xr.dot(comps.conj(), scores) reconstructed_data.name = "reconstructed_data" # Enforce real output reconstructed_data = reconstructed_data.real - # Unstack and unscale the data - reconstructed_data = self.preprocessor.inverse_transform_data( - reconstructed_data - ) return reconstructed_data - def components(self) -> AnyDataObject: + def components(self) -> DataObject: """Return the (EOF) components. The components in EOF anaylsis are the eigenvectors of the covariance/correlation matrix. @@ -171,24 +161,27 @@ def components(self) -> AnyDataObject: Components of the fitted model. """ - components = self.data.components - return self.preprocessor.inverse_transform_components(components) + return super().components() - def scores(self) -> DataArray: + def scores(self, normalized: bool = True) -> DataArray: """Return the (PC) scores. The scores in EOF anaylsis are the projection of the data matrix onto the eigenvectors of the covariance matrix (or correlation) matrix. 
Other names include the principal component (PC) scores or just PCs. + Parameters + ---------- + normalized : bool, default=True + Whether to normalize the scores by the L2 norm (singular values). + Returns ------- components: DataArray | Dataset | List[DataArray] Scores of the fitted model. """ - scores = self.data.scores - return self.preprocessor.inverse_transform_scores(scores) + return super().scores(normalized=normalized) def singular_values(self) -> DataArray: """Return the singular values of the Singular Value Decomposition. @@ -199,7 +192,7 @@ def singular_values(self) -> DataArray: Singular values obtained from the SVD. """ - return self.data.singular_values + return self.data["norms"] def explained_variance(self) -> DataArray: """Return explained variance. @@ -218,7 +211,7 @@ def explained_variance(self) -> DataArray: explained_variance: DataArray Explained variance. """ - return self.data.explained_variance + return self.data["explained_variance"] def explained_variance_ratio(self) -> DataArray: """Return explained variance ratio. @@ -236,13 +229,16 @@ def explained_variance_ratio(self) -> DataArray: explained_variance_ratio: DataArray Explained variance ratio. """ - return self.data.explained_variance_ratio + exp_var_ratio = self.data["explained_variance"] / self.data["total_variance"] + exp_var_ratio.attrs.update(self.data["explained_variance"].attrs) + exp_var_ratio.name = "explained_variance_ratio" + return exp_var_ratio class ComplexEOF(EOF): """Complex Empirical Orthogonal Functions (Complex EOF) analysis. - The Complex EOF analysis [1]_ [2]_ (also known as Hilbert EOF analysis) applies a Hilbert transform + The Complex EOF analysis [1]_ [2]_ [3]_ [4]_ (also known as Hilbert EOF analysis) applies a Hilbert transform to the data before performing the standard EOF analysis. The Hilbert transform is applied to each feature of the data individually. @@ -253,12 +249,6 @@ class ComplexEOF(EOF): ---------- n_modes : int Number of modes to calculate. - standardize : bool - Whether to standardize the input data. - use_coslat : bool - Whether to use cosine of latitude for scaling. - use_weights : bool - Whether to use weights. padding : str, optional Specifies the method used for padding the data prior to applying the Hilbert transform. This can help to mitigate the effect of spectral leakage. @@ -270,13 +260,33 @@ class ComplexEOF(EOF): A smaller value (e.g. 0.05) is recommended for data with high variability, while a larger value (e.g. 0.2) is recommended for data with low variability. Default is 0.2. - solver_kwargs : dict, optional - Additional keyword arguments to be passed to the SVD solver. + center: bool, default=True + Whether to center the input data. + standardize : bool + Whether to standardize the input data. + use_coslat : bool + Whether to use cosine of latitude for scaling. + sample_name: str, default="sample" + Name of the sample dimension. + feature_name: str, default="feature" + Name of the feature dimension. + compute: bool, default=True + Whether to compute the decomposition immediately. This is recommended + if the SVD result for the first ``n_modes`` can be accommodated in memory, as it + boosts computational efficiency compared to deferring the computation. + solver: {"auto", "full", "randomized"}, default="auto" + Solver to use for the SVD computation. + solver_kwargs: dict, default={} + Additional keyword arguments to be passed to the SVD solver.
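For intuition on the amplitude and phase quantities exposed further below (`components_amplitude`, `scores_phase`, ...): the Hilbert transform turns each real-valued series into a complex analytic signal whose modulus and argument are the local amplitude and phase. A minimal standalone sketch using `scipy.signal.hilbert` rather than the package's padded implementation; the toy signal is made up.

# Minimal analytic-signal example (plain SciPy, no padding; illustrative only).
import numpy as np
from scipy.signal import hilbert

t = np.linspace(0, 10, 500)
x = np.cos(2 * np.pi * 0.5 * t)  # real-valued toy series

z = hilbert(x)          # complex analytic signal x + i * H(x)
amplitude = np.abs(z)   # instantaneous amplitude (envelope)
phase = np.angle(z)     # instantaneous phase in radians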
References ---------- - .. [1] Horel, J., 1984. Complex Principal Component Analysis: Theory and Examples. J. Climate Appl. Meteor. 23, 1660–1673. https://doi.org/10.1175/1520-0450(1984)023<1660:CPCATA>2.0.CO;2 - .. [2] Hannachi, A., Jolliffe, I., Stephenson, D., 2007. Empirical orthogonal functions and related techniques in atmospheric science: A review. International Journal of Climatology 27, 1119–1152. https://doi.org/10.1002/joc.1499 + .. [1] Rasmusson, E. M., Arkin, P. A., Chen, W.-Y. & Jalickee, J. B. Biennial variations in surface temperature over the United States as revealed by singular decomposition. Monthly Weather Review 109, 587–598 (1981). + .. [2] Barnett, T. P. Interaction of the Monsoon and Pacific Trade Wind System at Interannual Time Scales Part I: The Equatorial Zone. Monthly Weather Review 111, 756–773 (1983). + .. [3] Horel, J., 1984. Complex Principal Component Analysis: Theory and Examples. J. Climate Appl. Meteor. 23, 1660–1673. https://doi.org/10.1175/1520-0450(1984)023<1660:CPCATA>2.0.CO;2 + .. [4] Hannachi, A., Jolliffe, I., Stephenson, D., 2007. Empirical orthogonal functions and related techniques in atmospheric science: A review. International Journal of Climatology 27, 1119–1152. https://doi.org/10.1002/joc.1499 Examples -------- @@ -285,64 +295,84 @@ class ComplexEOF(EOF): """ - def __init__(self, padding="exp", decay_factor=0.2, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + n_modes: int = 2, + padding: str = "exp", + decay_factor: float = 0.2, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + sample_name: str = "sample", + feature_name: str = "feature", + compute: bool = True, + random_state: Optional[int] = None, + solver: str = "auto", + solver_kwargs: Dict = {}, + ): + super().__init__( + n_modes=n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + sample_name=sample_name, + feature_name=feature_name, + compute=compute, + random_state=random_state, + solver=solver, + solver_kwargs=solver_kwargs, + ) self.attrs.update({"model": "Complex EOF analysis"}) self._params.update({"padding": padding, "decay_factor": decay_factor}) - # Initialize the DataContainer to store the results - self.data: ComplexEOFDataContainer = ComplexEOFDataContainer() - - def fit(self, data: AnyDataObject, dim, weights=None): - # Preprocess the data - input_data: DataArray = self.preprocessor.fit_transform(data, dim, weights) + def _fit_algorithm(self, data: DataArray) -> Self: + sample_name = self.sample_name + feature_name = self.feature_name # Apply hilbert transform: padding = self._params["padding"] decay_factor = self._params["decay_factor"] - input_data = hilbert_transform( - input_data, dim="sample", padding=padding, decay_factor=decay_factor + data = hilbert_transform( + data, + dims=(sample_name, feature_name), + padding=padding, + decay_factor=decay_factor, ) # Compute the total variance - total_variance = compute_total_variance(input_data, dim="sample") + total_variance = compute_total_variance(data, dim=sample_name) # Decompose the complex data n_modes = self._params["n_modes"] - decomposer = Decomposer( - n_modes=n_modes, solver=self._params["solver"], **self._solver_kwargs - ) - decomposer.fit(input_data) + decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs) + decomposer.fit(data) singular_values = decomposer.s_ components = decomposer.V_ - scores = decomposer.U_ - - # Compute the explained variance - explained_variance = singular_values**2 / (input_data.sample.size - 1) - - # 
Index of the sorted explained variance - # It's already sorted, we just need to assign it to the DataContainer - # for the sake of consistency - idx_modes_sorted = explained_variance.compute().argsort()[::-1] - idx_modes_sorted.coords.update(explained_variance.coords) - - self.data.set_data( - input_data=input_data, - components=components, - scores=scores, - explained_variance=explained_variance, - total_variance=total_variance, - idx_modes_sorted=idx_modes_sorted, - ) + scores = decomposer.U_ * decomposer.s_ + + # Compute the explained variance per mode + n_samples = data.coords[self.sample_name].size + exp_var = singular_values**2 / (n_samples - 1) + exp_var.name = "explained_variance" + + # Store the results + self.data.add(data, "input_data", allow_compute=False) + self.data.add(components, "components") + self.data.add(scores, "scores") + self.data.add(singular_values, "norms") + self.data.add(exp_var, "explained_variance") + self.data.add(total_variance, "total_variance") + # Assign analysis-relevant meta data to the results self.data.set_attrs(self.attrs) + return self - def transform(self, data: AnyDataObject) -> DataArray: - raise NotImplementedError("ComplexEOF does not support transform method.") + def _transform_algorithm(self, data: DataArray) -> DataArray: + raise NotImplementedError("Complex EOF does not support transform method.") - def components_amplitude(self) -> AnyDataObject: + def components_amplitude(self) -> DataObject: """Return the amplitude of the (EOF) components. The amplitude of the components are defined as @@ -359,10 +389,11 @@ def components_amplitude(self) -> AnyDataObject: Amplitude of the components of the fitted model. """ - amplitudes = self.data.components_amplitude + amplitudes = abs(self.data["components"]) + amplitudes.name = "components_amplitude" return self.preprocessor.inverse_transform_components(amplitudes) - def components_phase(self) -> AnyDataObject: + def components_phase(self) -> DataObject: """Return the phase of the (EOF) components. The phase of the components are defined as @@ -379,10 +410,12 @@ def components_phase(self) -> AnyDataObject: Phase of the components of the fitted model. """ - phases = self.data.components_phase - return self.preprocessor.inverse_transform_components(phases) + comps = self.data["components"] + comp_phase = xr.apply_ufunc(np.angle, comps, dask="allowed", keep_attrs=True) + comp_phase.name = "components_phase" + return self.preprocessor.inverse_transform_components(comp_phase) - def scores_amplitude(self) -> DataArray: + def scores_amplitude(self, normalized=True) -> DataArray: """Return the amplitude of the (PC) scores. The amplitude of the scores are defined as @@ -393,13 +426,23 @@ def scores_amplitude(self) -> DataArray: where :math:`S_{ij}` is the :math:`i`-th entry of the :math:`j`-th score and :math:`|\\cdot|` denotes the absolute value. + Parameters + ---------- + normalized : bool, default=True + Whether to normalize the scores by the singular values. + Returns ------- scores_amplitude: DataArray | Dataset | List[DataArray] Amplitude of the scores of the fitted model. """ - amplitudes = self.data.scores_amplitude + scores = self.data["scores"].copy() + if normalized: + scores = scores / self.data["norms"] + + amplitudes = abs(scores) + amplitudes.name = "scores_amplitude" return self.preprocessor.inverse_transform_scores(amplitudes) def scores_phase(self) -> DataArray: @@ -419,5 +462,7 @@ def scores_phase(self) -> DataArray: Phase of the scores of the fitted model. 
""" - phases = self.data.scores_phase + scores = self.data["scores"] + phases = xr.apply_ufunc(np.angle, scores, dask="allowed", keep_attrs=True) + phases.name = "scores_phase" return self.preprocessor.inverse_transform_scores(phases) diff --git a/xeofs/models/eof_rotator.py b/xeofs/models/eof_rotator.py index ab5d0f5..fa97d3c 100644 --- a/xeofs/models/eof_rotator.py +++ b/xeofs/models/eof_rotator.py @@ -1,23 +1,14 @@ from datetime import datetime import numpy as np import xarray as xr -from dask.diagnostics.progress import ProgressBar -from typing import List +from typing_extensions import Self from .eof import EOF, ComplexEOF -from ..data_container.eof_rotator_data_container import ( - EOFRotatorDataContainer, - ComplexEOFRotatorDataContainer, -) - +from ..data_container import DataContainer from ..utils.rotation import promax -from ..utils.data_types import DataArray, AnyDataObject - -from typing import TypeVar +from ..utils.data_types import DataArray from .._version import __version__ -Model = TypeVar("Model", EOF, ComplexEOF) - class EOFRotator(EOF): """Rotate a solution obtained from ``xe.models.EOF``. @@ -30,7 +21,7 @@ class EOFRotator(EOF): Parameters ---------- - n_modes : int, default=10 + n_modes : int, default=2 Specify the number of modes to be rotated. power : int, default=1 Set the power for the Promax rotation. A ``power`` value of 1 results @@ -41,6 +32,8 @@ class EOFRotator(EOF): rtol : float, default=1e-8 Define the relative tolerance required to achieve convergence and terminate the iterative process. + compute: bool, default=True + Whether to compute the decomposition immediately. References ---------- @@ -58,10 +51,11 @@ class EOFRotator(EOF): def __init__( self, - n_modes: int = 10, + n_modes: int = 2, power: int = 1, max_iter: int = 1000, rtol: float = 1e-8, + compute: bool = True, ): # Define model parameters self._params = { @@ -69,6 +63,7 @@ def __init__( "power": power, "max_iter": max_iter, "rtol": rtol, + "compute": compute, } # Define analysis-relevant meta data @@ -82,12 +77,25 @@ def __init__( } ) - # Initialize the DataContainer to store the results - self.data: EOFRotatorDataContainer = EOFRotatorDataContainer() + # Define data container + self.data = DataContainer() + + def fit(self, model) -> Self: + """Rotate the solution obtained from ``xe.models.EOF``. + + Parameters + ---------- + model : ``xe.models.EOF`` + The EOF model to be rotated. 
+ + """ + return self._fit_algorithm(model) - def fit(self, model): + def _fit_algorithm(self, model) -> Self: self.model = model self.preprocessor = model.preprocessor + self.sample_name = model.sample_name + self.feature_name = model.feature_name n_modes = self._params.get("n_modes") power = self._params.get("power") @@ -95,24 +103,19 @@ def fit(self, model): rtol = self._params.get("rtol") # Select modes to rotate - components = model.data.components.sel(mode=slice(1, n_modes)) - expvar = model.data.explained_variance.sel(mode=slice(1, n_modes)) + components = model.data["components"].sel(mode=slice(1, n_modes)) + expvar = model.explained_variance().sel(mode=slice(1, n_modes)) # Rotate loadings loadings = components * np.sqrt(expvar) - rot_loadings, rot_matrix, phi_matrix = xr.apply_ufunc( - promax, + promax_kwargs = {"power": power, "max_iter": max_iter, "rtol": rtol} + rot_loadings, rot_matrix, phi_matrix = promax( loadings, - power, - input_core_dims=[["feature", "mode"], []], - output_core_dims=[ - ["feature", "mode"], - ["mode_m", "mode_n"], - ["mode_m", "mode_n"], - ], - kwargs={"max_iter": max_iter, "rtol": rtol}, - dask="allowed", + feature_dim=self.feature_name, + compute=self._params["compute"], + **promax_kwargs ) + # Assign coordinates to the rotation/correlation matrices rot_matrix = rot_matrix.assign_coords( mode_m=np.arange(1, rot_matrix.mode_m.size + 1), @@ -124,7 +127,7 @@ def fit(self, model): ) # Reorder according to variance - expvar = (abs(rot_loadings) ** 2).sum("feature") + expvar = (abs(rot_loadings) ** 2).sum(self.feature_name) # NOTE: For delayed objects, the index must be computed. # NOTE: The index must be computed before sorting since argsort is not (yet) implemented in dask idx_sort = expvar.compute().argsort()[::-1] @@ -138,8 +141,16 @@ def fit(self, model): # Normalize loadings rot_components = rot_loadings / np.sqrt(expvar) + # Compute "pseudo" norms + n_samples = model.data["input_data"].coords[self.sample_name].size + norms = (expvar * (n_samples - 1)) ** 0.5 + norms.name = "singular_values" + + # Get unrotated, normalized scores + svals = model.data["norms"].sel(mode=slice(1, n_modes)) + scores = model.data["scores"].sel(mode=slice(1, n_modes)) + scores = scores / svals # Rotate scores - scores = model.data.scores.sel(mode=slice(1, n_modes)) RinvT = self._compute_rot_mat_inv_trans( rot_matrix, input_dims=("mode_m", "mode_n") ) @@ -151,8 +162,11 @@ def fit(self, model): # Reorder according to variance scores = scores.isel(mode=idx_sort.values).assign_coords(mode=scores.mode) - # Ensure consitent signs for deterministic output - idx_max_value = abs(rot_loadings).argmax("feature").compute() + # Scale scores by "pseudo" norms + scores = scores * norms + + # Ensure consistent signs for deterministic output + idx_max_value = abs(rot_loadings).argmax(self.feature_name).compute() modes_sign = xr.apply_ufunc( np.sign, rot_loadings.isel(feature=idx_max_value), dask="allowed" ) @@ -163,39 +177,36 @@ def fit(self, model): rot_components = rot_components * modes_sign scores = scores * modes_sign - # Create the data container - self.data.set_data( - input_data=model.data.input_data, - components=rot_components, - scores=scores, - explained_variance=expvar, - total_variance=model.data.total_variance, - idx_modes_sorted=idx_sort, - rotation_matrix=rot_matrix, - phi_matrix=phi_matrix, - modes_sign=modes_sign, - ) + # Store the results + self.data.add(model.data["input_data"], "input_data", allow_compute=False) + self.data.add(rot_components, "components") + 
self.data.add(scores, "scores") + self.data.add(norms, "norms") + self.data.add(expvar, "explained_variance") + self.data.add(model.data["total_variance"], "total_variance") + self.data.add(idx_sort, "idx_modes_sorted") + self.data.add(rot_matrix, "rotation_matrix") + self.data.add(phi_matrix, "phi_matrix") + self.data.add(modes_sign, "modes_sign") + # Assign analysis-relevant meta data self.data.set_attrs(self.attrs) + return self - def transform(self, data: AnyDataObject) -> DataArray: + def _transform_algorithm(self, data: DataArray) -> DataArray: n_modes = self._params["n_modes"] - svals = self.model.data.singular_values.sel( - mode=slice(1, self._params["n_modes"]) - ) + svals = self.model.singular_values().sel(mode=slice(1, self._params["n_modes"])) + pseudo_norms = self.data["norms"] # Select the (non-rotated) singular vectors of the first dataset - components = self.model.data.components.sel(mode=slice(1, n_modes)) + components = self.model.data["components"].sel(mode=slice(1, n_modes)) - # Preprocess the data - da: DataArray = self.preprocessor.transform(data) - - # Compute non-rotated scores by project the data onto non-rotated components - projections = xr.dot(da, components) / svals + # Compute non-rotated scores by projecting the data onto non-rotated components + projections = xr.dot(data, components) / svals projections.name = "scores" # Rotate the scores - R = self.data.rotation_matrix + R = self.data["rotation_matrix"] RinvT = self._compute_rot_mat_inv_trans(R, input_dims=("mode_m", "mode_n")) projections = projections.rename({"mode": "mode_m"}) RinvT = RinvT.rename({"mode_n": "mode"}) @@ -203,16 +214,22 @@ def transform(self, data: AnyDataObject) -> DataArray: # Reorder according to variance # this must be done in one line: i) select modes according to their variance, ii) replace coords with modes from 1 ... n projections = projections.isel( - mode=self.data.idx_modes_sorted.values + mode=self.data["idx_modes_sorted"].values ).assign_coords(mode=projections.mode) + # Scale scores by "pseudo" norms + projections = projections * pseudo_norms + # Adapt the sign of the scores - projections = projections * self.data.modes_sign + projections = projections * self.data["modes_sign"] - # Unstack the projections - projections = self.preprocessor.inverse_transform_scores(projections) return projections + def fit_transform(self, model) -> DataArray: + raise NotImplementedError( + "The fit_transform method is not implemented for the EOFRotator class." + ) + def _compute_rot_mat_inv_trans(self, rotation_matrix, input_dims) -> DataArray: """Compute the inverse transpose of the rotation matrix. @@ -250,7 +267,7 @@ class ComplexEOFRotator(EOFRotator, ComplexEOF): Parameters ---------- - n_modes : int, default=10 + n_modes : int, default=2 Specify the number of modes to be rotated. power : int, default=1 Set the power for the Promax rotation. A ``power`` value of 1 results @@ -261,6 +278,8 @@ class ComplexEOFRotator(EOFRotator, ComplexEOF): rtol : float, default=1e-8 Define the relative tolerance required to achieve convergence and terminate the iterative process. + compute: bool, default=True + Whether to compute the decomposition immediately. 
References ---------- @@ -279,16 +298,23 @@ class ComplexEOFRotator(EOFRotator, ComplexEOF): """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + n_modes: int = 2, + power: int = 1, + max_iter: int = 1000, + rtol: float = 1e-8, + compute: bool = True, + ): + super().__init__( + n_modes=n_modes, power=power, max_iter=max_iter, rtol=rtol, compute=compute + ) self.attrs.update({"model": "Rotated Complex EOF analysis"}) - # Initialize the DataContainer to store the results - self.data: ComplexEOFRotatorDataContainer = ComplexEOFRotatorDataContainer() - - def transform(self, data: AnyDataObject): - # Here we make use of the Method Resolution Order (MRO) to call the - # transform method of the first class in the MRO after `EOFRotator` - # that has a transform method. In this case it will be `ComplexEOF`, - # which will raise an error because it does not have a transform method. - super(EOFRotator, self).transform(data) + def _transform_algorithm(self, data: DataArray) -> DataArray: + # Here we leverage the Method Resolution Order (MRO) to invoke the + # transform method of the first class in the MRO after EOFRotator that + # has a transform method. In this case, it will be ComplexEOF. However, + # please note that `transform` is not implemented for ComplexEOF, so this + # line of code will actually raise an error. + return super(EOFRotator, self)._transform_algorithm(data) diff --git a/xeofs/models/gwpca.py b/xeofs/models/gwpca.py new file mode 100644 index 0000000..1293077 --- /dev/null +++ b/xeofs/models/gwpca.py @@ -0,0 +1,436 @@ +from typing import Sequence, Hashable, Optional, Callable +from typing_extensions import Self + +from sklearn.utils.extmath import randomized_svd + +from xeofs.utils.data_types import DataArray +from xeofs.utils.data_types import Data +from ._base_model import _BaseModel +from ..utils.sanity_checks import validate_input_type +from ..utils.xarray_utils import convert_to_dim_type +from ..utils.constants import ( + VALID_CARTESIAN_X_NAMES, + VALID_CARTESIAN_Y_NAMES, + VALID_LATITUDE_NAMES, + VALID_LONGITUDE_NAMES, +) +from .eof import EOF +import numpy as np +import xarray as xr + +from tqdm import trange +import numba +from numba import prange +from ..utils.distance_metrics import distance_nb, VALID_METRICS +from ..utils.kernels import kernel_weights_nb, VALID_KERNELS + + +class GWPCA(_BaseModel): + """Geographically weighted PCA (GWPCA). + + GWPCA [1]_ uses a geographically weighted approach to perform PCA for + each observation in the dataset based on its local neighbors. + + The neighbors for each observation are determined based on the provided + bandwidth and metric. Each neighbor is weighted based on its distance from + the observation using the provided kernel function. + + Parameters + ---------- + n_modes: int + Number of modes to calculate. + bandwidth: float + Bandwidth of the kernel function. Must be > 0. + metric: str, default="haversine" + Distance metric to use. Great circle distance (`haversine`) is always expressed in kilometers. + All other distance metrics are reported in the unit of the input data. + See scipy.spatial.distance.cdist for a list of available metrics. + kernel: str, default="bisquare" + Kernel function to use. Must be one of ['bisquare', 'gaussian', 'exponential']. + center: bool, default=True + If True, the data is centered by subtracting the mean (feature-wise). + standardize: bool, default=False + If True, the data is divided by the standard deviation (feature-wise). 
+ use_coslat: bool, default=False + If True, the data is weighted by the square root of cosine of latitudes. + sample_name: str, default="sample" + Name of the sample dimension. + feature_name: str, default="feature" + Name of the feature dimension. + + Attributes + ---------- + bandwidth: float + Bandwidth of the kernel function. + metric: str + Distance metric to use. + kernel: str + Kernel function to use. + + Methods: + -------- + fit(X) : Fit the model with input data. + + explained_variance() : Return the explained variance of the local components. + + explained_variance_ratio() : Return the explained variance ratio of the local components. + + largest_locally_weighted_components() : Return the largest locally weighted components. + + + Notes + ----- + GWPCA is computationally expensive since it performs PCA for each sample. This implementation leverages + `numba` to speed up the computation on CPUs. However, for moderate to large datasets, this won't be sufficient. + Currently, GPU support is not implemented. If your dataset is too large to be processed on a CPU, consider + using the R package `GWmodel` [2]_, which provides a GPU implementation of GWPCA. + + References + ---------- + .. [1] Harris, P., Brunsdon, C. & Charlton, M. Geographically weighted principal components analysis. International Journal of Geographical Information Science 25, 1717–1736 (2011). + .. [2] https://cran.r-project.org/web/packages/GWmodel/index.html + + + """ + + def __init__( + self, + n_modes: int, + bandwidth: float, + metric: str = "haversine", + kernel: str = "bisquare", + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + sample_name: str = "sample", + feature_name: str = "feature", + ): + super().__init__( + n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + sample_name=sample_name, + feature_name=feature_name, + ) + + self.attrs.update({"model": "GWPCA"}) + + if not kernel in VALID_KERNELS: + raise ValueError( + f"Invalid kernel: {kernel}. Must be one of {VALID_KERNELS}." + ) + + if not metric in VALID_METRICS: + raise ValueError( + f"Invalid metric: {metric}. Must be one of {VALID_METRICS}." + ) + + if bandwidth <= 0: + raise ValueError(f"Invalid bandwidth: {bandwidth}. Must be > 0.") + + self.bandwidth = bandwidth + self.metric = metric + self.kernel = kernel + + def _fit_algorithm(self, X: DataArray) -> Self: + # Convert Dask arrays + if not isinstance(X.data, np.ndarray): + print( + "Warning: GWPCA currently does not support Dask arrays. Data is being loaded into memory." + ) + X = X.compute() + # 1.
Get sample coordinates + valid_x_names = VALID_CARTESIAN_X_NAMES + VALID_LONGITUDE_NAMES + valid_y_names = VALID_CARTESIAN_Y_NAMES + VALID_LATITUDE_NAMES + n_sample_dims = len(self.sample_dims) + if n_sample_dims == 1: + indexes = self.preprocessor.preconverter.transformers[0].original_indexes + sample_dims = self.preprocessor.renamer.transformers[0].sample_dims_after + xy = None + for dim in sample_dims: + keys = [k for k in indexes[dim].coords.keys()] + x_found = any([k.lower() in valid_x_names for k in keys]) + y_found = any([k.lower() in valid_y_names for k in keys]) + if x_found and y_found: + xy = np.asarray([*indexes[dim].values]) + break + if xy is None: + raise ValueError("Cannot find sample coordinates.") + elif n_sample_dims == 2: + indexes = self.preprocessor.postconverter.transformers[0].original_indexes + xy = np.asarray([*indexes[self.sample_name].values]) + + else: + raise ValueError( + f"GWPCA requires number of sample dimensions to be <= 2, but got {n_sample_dims}." + ) + + # 2. Remove NaN samples from sample indexes + is_no_nan_sample = self.preprocessor.sanitizer.transformers[0].is_valid_sample + xy = xr.DataArray( + xy, + dims=[self.sample_name, "xy"], + coords={ + self.sample_name: is_no_nan_sample[self.sample_name], + "xy": ["x", "y"], + }, + name="index", + ) + + xy = xy[is_no_nan_sample] + + # Iterate over all samples + kwargs = { + "n_modes": self.n_modes, + "metric": self.metric, + "kernel": self.kernel, + "bandwidth": self.bandwidth, + } + components, exp_var, tot_var = xr.apply_ufunc( + _local_pcas, + X, + xy, + input_core_dims=[ + [self.sample_name, self.feature_name], + [self.sample_name, "xy"], + ], + output_core_dims=[ + [self.sample_name, self.feature_name, "mode"], + [self.sample_name, "mode"], + [self.sample_name], + ], + kwargs=kwargs, + dask="forbidden", + ) + components = components.assign_coords( + { + self.sample_name: X[self.sample_name], + self.feature_name: X[self.feature_name], + "mode": np.arange(1, self.n_modes + 1), + } + ) + exp_var = exp_var.assign_coords( + { + self.sample_name: X[self.sample_name], + "mode": np.arange(1, self.n_modes + 1), + } + ) + tot_var = tot_var.assign_coords({self.sample_name: X[self.sample_name]}) + + exp_var_ratio = exp_var / tot_var + + # self.data.add(X, "input_data") + self.data.add(components, "components") + self.data.add(exp_var, "explained_variance") + self.data.add(exp_var_ratio, "explained_variance_ratio") + + self.data.set_attrs(self.attrs) + + return self + + def explained_variance(self): + expvar = self.data["explained_variance"] + return self.preprocessor.inverse_transform_scores(expvar) + + def explained_variance_ratio(self): + expvar = self.data["explained_variance_ratio"] + return self.preprocessor.inverse_transform_scores(expvar) + + def largest_locally_weighted_components(self): + comps = self.data["components"] + idx_max = abs(comps).argmax(self.feature_name) + input_features = self.preprocessor.stacker.transformers[0].coords_out["feature"] + llwc = input_features[idx_max].drop_vars(self.feature_name) + llwc.name = "largest_locally_weighted_components" + return self.preprocessor.inverse_transform_scores(llwc) + + def scores(self): + raise NotImplementedError("GWPCA does not support scores() yet.") + + def _transform_algorithm(self, data: DataArray) -> DataArray: + raise NotImplementedError("GWPCA does not support transform() yet.") + + def _inverse_transform_algorithm(self, data: DataArray) -> DataArray: + raise NotImplementedError("GWPCA does not support inverse_transform() yet.") + + +# 
Additional utility functions for local PCA +# ============================================================================= + + +@numba.njit(fastmath=True, parallel=True) +def _local_pcas(X, xy, n_modes, metric, kernel, bandwidth): + """Perform local PCA on each sample. + + Parameters + ---------- + X: ndarray + Input data with shape (n_samples, n_features) + xy: ndarray + Sample coordinates with shape (n_samples, 2) + n_modes: int + Number of modes to calculate. + metric: str + Distance metric to use. Great circle distance (`haversine`) is always expressed in kilometers. + All other distance metrics are reported in the unit of the input data. + See scipy.spatial.distance.cdist for a list of available metrics. + kernel: str + Kernel function to use. Must be one of ['bisquare', 'gaussian', 'exponential']. + bandwidth: float + Bandwidth of the kernel function. + + Returns + ------- + ndarray + Array of local components with shape (n_samples, n_features, n_modes) + ndarray + Array of local explained variance with shape (n_samples, n_modes) + ndarray + Array of total variance with shape (n_samples,) + + """ + n_samples = X.shape[0] + n_features = X.shape[1] + Vs = np.empty((n_samples, n_features, n_modes)) + exp_var = np.empty((n_samples, n_modes)) + tot_var = np.empty(n_samples) + for i in prange(n_samples): + dist = distance_nb(xy, xy[i], metric=metric) + weights = kernel_weights_nb(dist, bandwidth, kernel) + valid_data = weights > 0 + + weights = weights[valid_data] + x = X[valid_data] + + wmean = _wmean_axis0(x, weights) + x -= wmean + + sqrt_weights = np.sqrt(weights) + x = _weigh_columns(x, sqrt_weights) + + Ui, si, ViT = np.linalg.svd(x, full_matrices=False) + # Renormalize singular values + si = si**2 / weights.sum() + ti = si.sum() + + si = si[:n_modes] + ViT = ViT[:n_modes] + Vi = ViT.T + + Vs[i] = Vi + exp_var[i, : len(si)] = si + tot_var[i] = ti + + return Vs, exp_var, tot_var + + +@numba.njit(fastmath=True) +def _wmean_axis0(X, weights): + """Compute weighted mean along axis 0. + + Numba version of np.average. Note that np.average is supported by Numba, + but is restricted to `X` and `weights` having the same shape. + """ + wmean = np.empty(X.shape[1]) + for i in prange(X.shape[1]): + wmean[i] = np.average(X[:, i], weights=weights) + return wmean + + +@numba.njit(fastmath=True) +def _weigh_columns(x, weights): + """Weigh columns of x by weights. + + Numba version of broadcasting. + + Parameters + ---------- + x: ndarray + Input data with shape (n_samples, n_features) + weights: ndarray + Weights with shape (n_samples,) + + Returns + ------- + x_weighted: ndarray + Weighted data with shape (n_samples, n_features) + """ + x_weighted = np.zeros_like(x) + for i in range(x.shape[1]): + x_weighted[:, i] = x[:, i] * weights + return x_weighted + + +@numba.guvectorize( + [ + ( + numba.float32[:, :], + numba.float32[:, :], + numba.float32[:], + numba.int32[:], + numba.int32, + numba.float32, + numba.int32, + numba.float32[:, :], + numba.float32[:], + numba.float32, + ) + ], + # In order to specify the output dimension which has not been defined in the input dimensions + # one has to use a dummy variable (see Numba #2797 https://github.com/numba/numba/issues/2797) + "(n,m),(n,o),(o),(n_out),(),(),()->(m,n_out),(n_out),()", +) +def local_pca_vectorized( + data, XY, xy, n_out, metric, bandwidth, kernel, comps, expvar, totvar +): + """Perform local PCA + + Numba vectorized version of local_pca. 
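# Plain-NumPy sketch of the per-location weighting scheme that the numba helpers above
# implement (distances -> kernel weights -> weighted mean and SVD). Euclidean distance and a
# bisquare kernel are assumed here purely for illustration.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))         # (n_samples, n_features)
xy = rng.uniform(0, 10, size=(100, 2))    # sample coordinates
i, bandwidth, n_modes = 0, 3.0, 2

dist = np.linalg.norm(xy - xy[i], axis=1)
w = np.where(dist < bandwidth, (1 - (dist / bandwidth) ** 2) ** 2, 0.0)  # bisquare kernel
x, w = X[w > 0], w[w > 0]
x = x - np.average(x, axis=0, weights=w)  # weighted centering
x = x * np.sqrt(w)[:, None]               # weigh rows before the SVD
_, s, Vt = np.linalg.svd(x, full_matrices=False)
local_expvar = (s**2 / w.sum())[:n_modes] # local explained variance
local_components = Vt[:n_modes].T         # (n_features, n_modes)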
+ + Parameters + ---------- + data: ndarray + Input data with shape (n_samples, n_features) + XY: ndarray + Sample coordinates with shape (n_samples, 2) + xy: ndarray + Coordinates of the sample to perform PCA on with shape (2,) + n_out: ndarray + Number of modes to calculate. (see comment above; workaround for Numba #2797) + metric: int + Numba only accepts int/floats; so metric str has to be converted first e.g. by a simple dictionary (not implemented yet) + see Numba #4404 (https://github.com/numba/numba/issues/4404) + bandwidth: float + Bandwidth of the kernel function. + kernel: int + Numba only accepts int/floats; so kernel str has to be converted first e.g. by a simple dictionary (not implemented yet) + see Numba #4404 (https://github.com/numba/numba/issues/4404) + comps: ndarray + Array of local components with shape (n_features, n_modes) + expvar: ndarray + Array of local explained variance with shape (n_modes) + totvar: ndarray + Array of total variance with shape (1) + + + """ + distance = distance_nb(XY, xy, metric=metric) + weights = kernel_weights_nb(distance, bandwidth, kernel) + is_positive_weight = weights > 0 + X = data[is_positive_weight] + weights = weights[is_positive_weight] + + wmean = _wmean_axis0(X, weights) + X -= wmean + + sqrt_weights = np.sqrt(weights) + X = _weigh_columns(X, sqrt_weights) + + U, s, Vt = np.linalg.svd(X, full_matrices=False) + Vt = Vt[: n_out.shape[0], :] + lbda = s**2 / weights.sum() + for i in range(n_out.shape[0]): + expvar[i] = lbda[i] + comps[:, i] = Vt[i, :] + totvar = lbda.sum() diff --git a/xeofs/models/mca.py b/xeofs/models/mca.py index 42f1037..f4e26b6 100644 --- a/xeofs/models/mca.py +++ b/xeofs/models/mca.py @@ -1,18 +1,14 @@ -from typing import Tuple +from typing import Tuple, Optional, Sequence, Dict +from typing_extensions import Self import numpy as np import xarray as xr -from dask.diagnostics.progress import ProgressBar from ._base_cross_model import _BaseCrossModel from .decomposer import Decomposer -from ..utils.data_types import AnyDataObject, DataArray -from ..data_container.mca_data_container import ( - MCADataContainer, - ComplexMCADataContainer, -) +from ..utils.data_types import DataObject, DataArray from ..utils.statistics import pearson_correlation -from ..utils.xarray_utils import hilbert_transform +from ..utils.hilbert_transform import hilbert_transform from ..utils.dimension_renamer import DimensionRenamer @@ -23,18 +19,14 @@ class MCA(_BaseCrossModel): Parameters ---------- - n_modes: int, default=10 + n_modes: int, default=2 Number of modes to calculate. + center: bool, default=True + Whether to center the input data. standardize: bool, default=False Whether to standardize the input data. use_coslat: bool, default=False Whether to use cosine of latitude for scaling. - use_weights: bool, default=False - Whether to use additional weights. - solver: {"auto", "full", "randomized"}, default="auto" - Solver to use for the SVD computation. - solver_kwargs: dict, default={} - Additional keyword arguments passed to the SVD solver. n_pca_modes: int, default=None The number of principal components to retain during the PCA preprocessing step applied to both data sets prior to executing MCA. @@ -43,6 +35,18 @@ class MCA(_BaseCrossModel): only the specified number of principal components. This reduction in dimensionality can be especially beneficial when dealing with high-dimensional data, where computing the cross-covariance matrix can become computationally intensive or in scenarios where multicollinearity is a concern. 
+ compute: bool, default=True + Whether to compute the decomposition immediately. + sample_name: str, default="sample" + Name of the new sample dimension. + feature_name: str, default="feature" + Name of the new feature dimension. + solver: {"auto", "full", "randomized"}, default="auto" + Solver to use for the SVD computation. + random_state: int, default=None + Seed for the random number generator. + solver_kwargs: dict, default={} + Additional keyword arguments passed to the SVD solver. Notes ----- @@ -62,56 +66,68 @@ class MCA(_BaseCrossModel): """ - def __init__(self, solver_kwargs={}, **kwargs): - super().__init__(solver_kwargs=solver_kwargs, **kwargs) + def __init__( + self, + n_modes: int = 2, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + n_pca_modes: Optional[int] = None, + compute: bool = True, + sample_name: str = "sample", + feature_name: str = "feature", + solver: str = "auto", + random_state: Optional[int] = None, + solver_kwargs: Dict = {}, + ): + super().__init__( + n_modes=n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + n_pca_modes=n_pca_modes, + compute=compute, + sample_name=sample_name, + feature_name=feature_name, + solver=solver, + random_state=random_state, + solver_kwargs=solver_kwargs, + ) self.attrs.update({"model": "MCA"}) - # Initialize the DataContainer to store the results - self.data: MCADataContainer = MCADataContainer() - def _compute_cross_covariance_matrix(self, X1, X2): """Compute the cross-covariance matrix of two data objects. Note: It is assumed that the data objects are centered. """ - if X1.sample.size != X2.sample.size: + sample_name = self.sample_name + n_samples = X1.coords[sample_name].size + if X1.coords[sample_name].size != X2.coords[sample_name].size: err_msg = "The two data objects must have the same number of samples." 
raise ValueError(err_msg) - return xr.dot(X1.conj(), X2, dims="sample") / (X1.sample.size - 1) + return xr.dot(X1.conj(), X2, dims=sample_name) / (n_samples - 1) - def fit( + def _fit_algorithm( self, - data1: AnyDataObject, - data2: AnyDataObject, - dim, - weights1=None, - weights2=None, - ): - # Preprocess the data - data1_processed: DataArray = self.preprocessor1.fit_transform( - data1, dim, weights1 - ) - data2_processed: DataArray = self.preprocessor2.fit_transform( - data2, dim, weights2 - ) + data1: DataArray, + data2: DataArray, + ) -> Self: + sample_name = self.sample_name + feature_name = self.feature_name # Initialize the SVD decomposer - decomposer = Decomposer( - n_modes=self._params["n_modes"], - solver=self._params["solver"], - **self._solver_kwargs, - ) + decomposer = Decomposer(n_modes=self._params["n_modes"], **self._solver_kwargs) # Perform SVD on PCA-reduced data if (self.pca1 is not None) and (self.pca2 is not None): # Fit the PCA models - self.pca1.fit(data1_processed, "sample") - self.pca2.fit(data2_processed, "sample") + self.pca1.fit(data1, dim=sample_name) + self.pca2.fit(data2, dim=sample_name) # Get the PCA scores - pca_scores1 = self.pca1.data.scores * self.pca1.data.singular_values - pca_scores2 = self.pca2.data.scores * self.pca2.data.singular_values + pca_scores1 = self.pca1.data["scores"] * self.pca1.singular_values() + pca_scores2 = self.pca2.data["scores"] * self.pca2.singular_values() # Compute the cross-covariance matrix of the PCA scores pca_scores1 = pca_scores1.rename({"mode": "feature1"}) pca_scores2 = pca_scores2.rename({"mode": "feature2"}) @@ -122,8 +138,9 @@ def fit( V1 = decomposer.U_ # left singular vectors (feature1 x mode) V2 = decomposer.V_ # right singular vectors (feature2 x mode) - V1pre = self.pca1.data.components # left PCA eigenvectors (feature x mode) - V2pre = self.pca2.data.components # right PCA eigenvectors (feature x mode) + # left and right PCA eigenvectors (feature x mode) + V1pre = self.pca1.data["components"] + V2pre = self.pca2.data["components"] # Compute the singular vectors V1pre = V1pre.rename({"mode": "feature1"}) @@ -134,14 +151,12 @@ def fit( # Perform SVD directly on data else: # Rename feature and associated dimensions of data objects to avoid index conflicts - dim_renamer1 = DimensionRenamer("feature", "1") - dim_renamer2 = DimensionRenamer("feature", "2") - data1_processed_temp = dim_renamer1.fit_transform(data1_processed) - data2_processed_temp = dim_renamer2.fit_transform(data2_processed) + dim_renamer1 = DimensionRenamer(feature_name, "1") + dim_renamer2 = DimensionRenamer(feature_name, "2") + data1_temp = dim_renamer1.fit_transform(data1) + data2_temp = dim_renamer2.fit_transform(data2) # Compute the cross-covariance matrix - cov_matrix = self._compute_cross_covariance_matrix( - data1_processed_temp, data2_processed_temp - ) + cov_matrix = self._compute_cross_covariance_matrix(data1_temp, data2_temp) # Perform the SVD decomposer.fit(cov_matrix, dims=("feature1", "feature2")) @@ -167,33 +182,37 @@ def fit( idx_sorted_modes.coords.update(squared_covariance.coords) # Project the data onto the singular vectors - scores1 = xr.dot(data1_processed, singular_vectors1, dims="feature") / norm1 - scores2 = xr.dot(data2_processed, singular_vectors2, dims="feature") / norm2 - - self.data.set_data( - input_data1=data1_processed, - input_data2=data2_processed, - components1=singular_vectors1, - components2=singular_vectors2, - scores1=scores1, - scores2=scores2, - squared_covariance=squared_covariance, - 
total_squared_covariance=total_squared_covariance, - idx_modes_sorted=idx_sorted_modes, - norm1=norm1, - norm2=norm2, - ) + scores1 = xr.dot(data1, singular_vectors1, dims=feature_name) / norm1 + scores2 = xr.dot(data2, singular_vectors2, dims=feature_name) / norm2 + + self.data.add(name="input_data1", data=data1, allow_compute=False) + self.data.add(name="input_data2", data=data2, allow_compute=False) + self.data.add(name="components1", data=singular_vectors1) + self.data.add(name="components2", data=singular_vectors2) + self.data.add(name="scores1", data=scores1) + self.data.add(name="scores2", data=scores2) + self.data.add(name="squared_covariance", data=squared_covariance) + self.data.add(name="total_squared_covariance", data=total_squared_covariance) + self.data.add(name="idx_modes_sorted", data=idx_sorted_modes) + self.data.add(name="norm1", data=norm1) + self.data.add(name="norm2", data=norm2) + # Assign analysis-relevant meta data self.data.set_attrs(self.attrs) + return self + + def transform( + self, data1: Optional[DataObject] = None, data2: Optional[DataObject] = None + ) -> Sequence[DataArray]: + """Get the expansion coefficients of "unseen" data. - def transform(self, **kwargs): - """Project new unseen data onto the singular vectors. + The expansion coefficients are obtained by projecting data onto the singular vectors. Parameters ---------- - data1: xr.DataArray or list of xarray.DataArray + data1: DataArray | Dataset | List[DataArray] Left input data. Must be provided if `data2` is not provided. - data2: xr.DataArray or list of xarray.DataArray + data2: DataArray | Dataset | List[DataArray] Right input data. Must be provided if `data1` is not provided. Returns @@ -204,26 +223,25 @@ def transform(self, **kwargs): Right scores. """ + return super().transform(data1, data2) + + def _transform_algorithm( + self, data1: Optional[DataArray] = None, data2: Optional[DataArray] = None + ) -> Sequence[DataArray]: results = [] - if "data1" in kwargs.keys(): - # Preprocess input data - data1 = kwargs["data1"] - data1 = self.preprocessor1.transform(data1) + if data1 is not None: # Project data onto singular vectors - comps1 = self.data.components1 - norm1 = self.data.norm1 + comps1 = self.data["components1"] + norm1 = self.data["norm1"] scores1 = xr.dot(data1, comps1) / norm1 # Inverse transform scores scores1 = self.preprocessor1.inverse_transform_scores(scores1) results.append(scores1) - if "data2" in kwargs.keys(): - # Preprocess input data - data2 = kwargs["data2"] - data2 = self.preprocessor2.transform(data2) + if data2 is not None: # Project data onto singular vectors - comps2 = self.data.components2 - norm2 = self.data.norm2 + comps2 = self.data["components2"] + norm2 = self.data["norm2"] scores2 = xr.dot(data2, comps2) / norm2 # Inverse transform scores scores2 = self.preprocessor2.inverse_transform_scores(scores2) @@ -252,16 +270,16 @@ def inverse_transform(self, mode): """ # Singular vectors - comps1 = self.data.components1.sel(mode=mode) - comps2 = self.data.components2.sel(mode=mode) + comps1 = self.data["components1"].sel(mode=mode) + comps2 = self.data["components2"].sel(mode=mode) # Scores = projections - scores1 = self.data.scores1.sel(mode=mode) - scores2 = self.data.scores2.sel(mode=mode) + scores1 = self.data["scores1"].sel(mode=mode) + scores2 = self.data["scores2"].sel(mode=mode) # Norms - norm1 = self.data.norm1.sel(mode=mode) - norm2 = self.data.norm2.sel(mode=mode) + norm1 = self.data["norm1"].sel(mode=mode) + norm2 = self.data["norm2"].sel(mode=mode) # 
Reconstruct the data data1 = xr.dot(scores1, comps1.conj() * norm1, dims="mode") @@ -284,7 +302,7 @@ def squared_covariance(self): squared singular values of the covariance matrix. """ - return self.data.squared_covariance + return self.data["squared_covariance"] def squared_covariance_fraction(self): """Calculate the squared covariance fraction (SCF). @@ -298,11 +316,31 @@ def squared_covariance_fraction(self): where `m` is the total number of modes and :math:`\\sigma_i` is the `ith` singular value of the covariance matrix. """ - return self.data.squared_covariance_fraction + return self.data["squared_covariance"] / self.data["total_squared_covariance"] def singular_values(self): """Get the singular values of the cross-covariance matrix.""" - return self.data.singular_values + singular_values = xr.apply_ufunc( + np.sqrt, + self.data["squared_covariance"], + dask="allowed", + vectorize=False, + keep_attrs=True, + ) + singular_values.name = "singular_values" + return singular_values + + def total_covariance(self) -> DataArray: + """Get the total covariance. + + This measure follows the definition of Cheng and Dunkerton (1995). + Note that this measure is not an invariant in MCA. + + """ + tot_cov = self.singular_values().sum() + tot_cov.attrs.update(self.singular_values().attrs) + tot_cov.name = "total_covariance" + return tot_cov def covariance_fraction(self): """Get the covariance fraction (CF). @@ -331,7 +369,8 @@ def covariance_fraction(self): """ # Check how sensitive the CF is to the number of modes - svals = self.data.singular_values + svals = self.singular_values() + tot_var = self.total_covariance() cf = svals[0] / svals.cumsum() change_per_mode = cf.shift({"mode": 1}) - cf change_in_cf_in_last_mode = change_per_mode.isel(mode=-1) @@ -339,7 +378,10 @@ def covariance_fraction(self): print( f"Warning: CF is sensitive to the number of modes retained. Please increase `n_modes` for a better estimate." ) - return self.data.covariance_fraction + cov_frac = svals / tot_var + cov_frac.name = "covariance_fraction" + cov_frac.attrs.update(svals.attrs) + return cov_frac def components(self): """Return the singular vectors of the left and right field. @@ -416,11 +458,11 @@ def homogeneous_patterns(self, correction=None, alpha=0.05): Right p-values.
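# Illustrative relation between the covariance diagnostics defined above, using a made-up
# vector of singular values of the cross-covariance matrix (values are arbitrary).
import numpy as np

sigma = np.array([4.0, 2.0, 1.0])                    # singular_values()
squared_covariance = sigma**2
scf = squared_covariance / squared_covariance.sum()  # squared_covariance_fraction()
total_covariance = sigma.sum()                       # total_covariance(), depends on n_modes retained
cf = sigma / total_covariance                        # covariance_fraction()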
""" - input_data1 = self.data.input_data1 - input_data2 = self.data.input_data2 + input_data1 = self.data["input_data1"] + input_data2 = self.data["input_data2"] - scores1 = self.data.scores1 - scores2 = self.data.scores2 + scores1 = self.data["scores1"] + scores2 = self.data["scores2"] hom_pat1, pvals1 = pearson_correlation( input_data1, scores1, correction=correction, alpha=alpha @@ -429,18 +471,18 @@ def homogeneous_patterns(self, correction=None, alpha=0.05): input_data2, scores2, correction=correction, alpha=alpha ) - hom_pat1 = self.preprocessor1.inverse_transform_components(hom_pat1) - hom_pat2 = self.preprocessor2.inverse_transform_components(hom_pat2) - - pvals1 = self.preprocessor1.inverse_transform_components(pvals1) - pvals2 = self.preprocessor2.inverse_transform_components(pvals2) - hom_pat1.name = "left_homogeneous_patterns" hom_pat2.name = "right_homogeneous_patterns" pvals1.name = "pvalues_of_left_homogeneous_patterns" pvals2.name = "pvalues_of_right_homogeneous_patterns" + hom_pat1 = self.preprocessor1.inverse_transform_components(hom_pat1) + hom_pat2 = self.preprocessor2.inverse_transform_components(hom_pat2) + + pvals1 = self.preprocessor1.inverse_transform_components(pvals1) + pvals2 = self.preprocessor2.inverse_transform_components(pvals2) + return (hom_pat1, hom_pat2), (pvals1, pvals2) def heterogeneous_patterns(self, correction=None, alpha=0.05): @@ -478,11 +520,11 @@ def heterogeneous_patterns(self, correction=None, alpha=0.05): The desired family-wise error rate. Not used if `correction` is None. """ - input_data1 = self.data.input_data1 - input_data2 = self.data.input_data2 + input_data1 = self.data["input_data1"] + input_data2 = self.data["input_data2"] - scores1 = self.data.scores1 - scores2 = self.data.scores2 + scores1 = self.data["scores1"] + scores2 = self.data["scores2"] patterns1, pvals1 = pearson_correlation( input_data1, scores2, correction=correction, alpha=alpha @@ -491,18 +533,18 @@ def heterogeneous_patterns(self, correction=None, alpha=0.05): input_data2, scores1, correction=correction, alpha=alpha ) - patterns1 = self.preprocessor1.inverse_transform_components(patterns1) - patterns2 = self.preprocessor2.inverse_transform_components(patterns2) - - pvals1 = self.preprocessor1.inverse_transform_components(pvals1) - pvals2 = self.preprocessor2.inverse_transform_components(pvals2) - patterns1.name = "left_heterogeneous_patterns" patterns2.name = "right_heterogeneous_patterns" pvals1.name = "pvalues_of_left_heterogeneous_patterns" pvals2.name = "pvalues_of_right_heterogeneous_patterns" + patterns1 = self.preprocessor1.inverse_transform_components(patterns1) + patterns2 = self.preprocessor2.inverse_transform_components(patterns2) + + pvals1 = self.preprocessor1.inverse_transform_components(pvals1) + pvals2 = self.preprocessor2.inverse_transform_components(pvals2) + return (patterns1, patterns2), (pvals1, pvals2) @@ -519,14 +561,8 @@ class ComplexMCA(MCA): Parameters ---------- - n_modes: int, default=10 + n_modes: int, default=2 Number of modes to calculate. - standardize: bool, default=False - Whether to standardize the input data. - use_coslat: bool, default=False - Whether to use cosine of latitude for scaling. - use_weights: bool, default=False - Whether to use additional weights. padding : str, optional Specifies the method used for padding the data prior to applying the Hilbert transform. This can help to mitigate the effect of spectral leakage. @@ -538,6 +574,30 @@ class ComplexMCA(MCA): A smaller value (e.g. 
0.05) is recommended for data with high variability, while a larger value (e.g. 0.2) is recommended for data with low variability. Default is 0.2. + center: bool, default=True + Whether to center the input data. + standardize: bool, default=False + Whether to standardize the input data. + use_coslat: bool, default=False + Whether to use cosine of latitude for scaling. + n_pca_modes: int, default=None + The number of principal components to retain during the PCA preprocessing + step applied to both data sets prior to executing MCA. + If set to None, PCA preprocessing will be bypassed, and the MCA will be performed on the original datasets. + Specifying an integer value greater than 0 for `n_pca_modes` will trigger the PCA preprocessing, retaining + only the specified number of principal components. This reduction in dimensionality can be especially beneficial + when dealing with high-dimensional data, where computing the cross-covariance matrix can become computationally + intensive or in scenarios where multicollinearity is a concern. + compute: bool, default=True + Whether to compute the decomposition immediately. + sample_name: str, default="sample" + Name of the new sample dimension. + feature_name: str, default="feature" + Name of the new feature dimension. + solver: {"auto", "full", "randomized"}, default="auto" + Solver to use for the SVD computation. + random_state: int, optional + Random state for randomized SVD solver. solver_kwargs: dict, default={} Additional keyword arguments passed to the SVD solver. @@ -563,121 +623,102 @@ class ComplexMCA(MCA): """ - def __init__(self, padding="exp", decay_factor=0.2, **kwargs): - super().__init__(**kwargs) - self.attrs.update({"model": "Complex MCA"}) - self._params.update({"padding": padding, "decay_factor": decay_factor}) - - # Initialize the DataContainer to store the results - self.data: ComplexMCADataContainer = ComplexMCADataContainer() - - def fit( + def __init__( self, - data1: AnyDataObject, - data2: AnyDataObject, - dim, - weights1=None, - weights2=None, + n_modes: int = 2, + padding: str = "exp", + decay_factor: float = 0.2, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + n_pca_modes: Optional[int] = None, + compute: bool = True, + sample_name: str = "sample", + feature_name: str = "feature", + solver: str = "auto", + random_state: Optional[int] = None, + solver_kwargs: Dict = {}, ): - """Fit the model. - - Parameters - ---------- - data1: xr.DataArray or list of xarray.DataArray - Left input data. - data2: xr.DataArray or list of xarray.DataArray - Right input data. - dim: tuple - Tuple specifying the sample dimensions. The remaining dimensions - will be treated as feature dimensions. - weights1: xr.DataArray or xr.Dataset or None, default=None - If specified, the left input data will be weighted by this array. - weights2: xr.DataArray or xr.Dataset or None, default=None - If specified, the right input data will be weighted by this array.
- - """ - - data1_processed: DataArray = self.preprocessor1.fit_transform( - data1, dim, weights1 - ) - data2_processed: DataArray = self.preprocessor2.fit_transform( - data2, dim, weights2 + super().__init__( + n_modes=n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + n_pca_modes=n_pca_modes, + compute=compute, + sample_name=sample_name, + feature_name=feature_name, + solver=solver, + random_state=random_state, + solver_kwargs=solver_kwargs, ) + self.attrs.update({"model": "Complex MCA"}) + self._params.update({"padding": padding, "decay_factor": decay_factor}) - # Apply Hilbert transform: - padding = self._params["padding"] - decay_factor = self._params["decay_factor"] - data1_processed = hilbert_transform( - data1_processed, - dim="sample", - padding=padding, - decay_factor=decay_factor, - ) - data2_processed = hilbert_transform( - data2_processed, - dim="sample", - padding=padding, - decay_factor=decay_factor, - ) + def _fit_algorithm(self, data1: DataArray, data2: DataArray) -> Self: + sample_name = self.sample_name + feature_name = self.feature_name + + # Settings for Hilbert transform + hilbert_kwargs = { + "padding": self._params["padding"], + "decay_factor": self._params["decay_factor"], + } # Initialize the SVD decomposer - decomposer = Decomposer( - n_modes=self._params["n_modes"], - solver=self._params["solver"], - **self._solver_kwargs, - ) + decomposer = Decomposer(n_modes=self._params["n_modes"], **self._solver_kwargs) # Perform SVD on PCA-reduced data if (self.pca1 is not None) and (self.pca2 is not None): # Fit the PCA models - self.pca1.fit(data1_processed, "sample") - self.pca2.fit(data2_processed, "sample") + self.pca1.fit(data1, sample_name) + self.pca2.fit(data2, sample_name) # Get the PCA scores - pca_scores1 = self.pca1.data.scores * self.pca1.data.singular_values - pca_scores2 = self.pca2.data.scores * self.pca2.data.singular_values + pca_scores1 = self.pca1.data["scores"] * self.pca1.singular_values() + pca_scores2 = self.pca2.data["scores"] * self.pca2.singular_values() # Apply hilbert transform pca_scores1 = hilbert_transform( - pca_scores1, - dim="sample", - padding=padding, - decay_factor=decay_factor, + pca_scores1, dims=(sample_name, "mode"), **hilbert_kwargs ) pca_scores2 = hilbert_transform( - pca_scores2, - dim="sample", - padding=padding, - decay_factor=decay_factor, + pca_scores2, dims=(sample_name, "mode"), **hilbert_kwargs ) # Compute the cross-covariance matrix of the PCA scores - pca_scores1 = pca_scores1.rename({"mode": "feature"}) - pca_scores2 = pca_scores2.rename({"mode": "feature"}) + pca_scores1 = pca_scores1.rename({"mode": "feature_temp1"}) + pca_scores2 = pca_scores2.rename({"mode": "feature_temp2"}) cov_matrix = self._compute_cross_covariance_matrix(pca_scores1, pca_scores2) # Perform the SVD - decomposer.fit(cov_matrix, dims=("feature1", "feature2")) - V1 = decomposer.U_ # left singular vectors (feature1 x mode) - V2 = decomposer.V_ # right singular vectors (feature2 x mode) + decomposer.fit(cov_matrix, dims=("feature_temp1", "feature_temp2")) + V1 = decomposer.U_ # left singular vectors (feature_temp1 x mode) + V2 = decomposer.V_ # right singular vectors (feature_temp2 x mode) - V1pre = self.pca1.data.components # left PCA eigenvectors (feature x mode) - V2pre = self.pca2.data.components # right PCA eigenvectors (feature x mode) + # left and right PCA eigenvectors (feature_name x mode) + V1pre = self.pca1.data["components"] + V2pre = self.pca2.data["components"] # Compute the singular vectors - V1pre = 
V1pre.rename({"mode": "feature1"}) - V2pre = V2pre.rename({"mode": "feature2"}) - singular_vectors1 = xr.dot(V1pre, V1, dims="feature1") - singular_vectors2 = xr.dot(V2pre, V2, dims="feature2") + V1pre = V1pre.rename({"mode": "feature_temp1"}) + V2pre = V2pre.rename({"mode": "feature_temp2"}) + singular_vectors1 = xr.dot(V1pre, V1, dims="feature_temp1") + singular_vectors2 = xr.dot(V2pre, V2, dims="feature_temp2") # Perform SVD directly on data else: + # Perform Hilbert transform + data1 = hilbert_transform( + data1, dims=(sample_name, feature_name), **hilbert_kwargs + ) + data2 = hilbert_transform( + data2, dims=(sample_name, feature_name), **hilbert_kwargs + ) # Rename feature and associated dimensions of data objects to avoid index conflicts - dim_renamer1 = DimensionRenamer("feature", "1") - dim_renamer2 = DimensionRenamer("feature", "2") - data1_processed_temp = dim_renamer1.fit_transform(data1_processed) - data2_processed_temp = dim_renamer2.fit_transform(data2_processed) + dim_renamer1 = DimensionRenamer(feature_name, "1") + dim_renamer2 = DimensionRenamer(feature_name, "2") + data1_temp = dim_renamer1.fit_transform(data1) + data2_temp = dim_renamer2.fit_transform(data2) # Compute the cross-covariance matrix - cov_matrix = self._compute_cross_covariance_matrix( - data1_processed_temp, data2_processed_temp - ) + cov_matrix = self._compute_cross_covariance_matrix(data1_temp, data2_temp) # Perform the SVD decomposer.fit(cov_matrix, dims=("feature1", "feature2")) @@ -703,26 +744,26 @@ def fit( idx_sorted_modes.coords.update(squared_covariance.coords) # Project the data onto the singular vectors - scores1 = xr.dot(data1_processed, singular_vectors1) / norm1 - scores2 = xr.dot(data2_processed, singular_vectors2) / norm2 - - self.data.set_data( - input_data1=data1_processed, - input_data2=data2_processed, - components1=singular_vectors1, - components2=singular_vectors2, - scores1=scores1, - scores2=scores2, - squared_covariance=squared_covariance, - total_squared_covariance=total_squared_covariance, - idx_modes_sorted=idx_sorted_modes, - norm1=norm1, - norm2=norm2, - ) + scores1 = xr.dot(data1, singular_vectors1) / norm1 + scores2 = xr.dot(data2, singular_vectors2) / norm2 + + self.data.add(name="input_data1", data=data1, allow_compute=False) + self.data.add(name="input_data2", data=data2, allow_compute=False) + self.data.add(name="components1", data=singular_vectors1) + self.data.add(name="components2", data=singular_vectors2) + self.data.add(name="scores1", data=scores1) + self.data.add(name="scores2", data=scores2) + self.data.add(name="squared_covariance", data=squared_covariance) + self.data.add(name="total_squared_covariance", data=total_squared_covariance) + self.data.add(name="idx_modes_sorted", data=idx_sorted_modes) + self.data.add(name="norm1", data=norm1) + self.data.add(name="norm2", data=norm2) + # Assign analysis relevant meta data self.data.set_attrs(self.attrs) + return self - def components_amplitude(self) -> Tuple[AnyDataObject, AnyDataObject]: + def components_amplitude(self) -> Tuple[DataObject, DataObject]: """Compute the amplitude of the components. The amplitude of the components are defined as @@ -735,21 +776,24 @@ def components_amplitude(self) -> Tuple[AnyDataObject, AnyDataObject]: Returns ------- - AnyDataObject + DataObject Amplitude of the left components. - AnyDataObject + DataObject Amplitude of the left components. 
""" - comps1 = self.data.components_amplitude1 - comps2 = self.data.components_amplitude2 + comps1 = abs(self.data["components1"]) + comps1.name = "left_components_amplitude" + + comps2 = abs(self.data["components2"]) + comps2.name = "right_components_amplitude" comps1 = self.preprocessor1.inverse_transform_components(comps1) comps2 = self.preprocessor2.inverse_transform_components(comps2) return (comps1, comps2) - def components_phase(self) -> Tuple[AnyDataObject, AnyDataObject]: + def components_phase(self) -> Tuple[DataObject, DataObject]: """Compute the phase of the components. The phase of the components are defined as @@ -762,14 +806,17 @@ def components_phase(self) -> Tuple[AnyDataObject, AnyDataObject]: Returns ------- - AnyDataObject + DataObject Phase of the left components. - AnyDataObject + DataObject Phase of the right components. """ - comps1 = self.data.components_phase1 - comps2 = self.data.components_phase2 + comps1 = xr.apply_ufunc(np.angle, self.data["components1"], keep_attrs=True) + comps1.name = "left_components_phase" + + comps2 = xr.apply_ufunc(np.angle, self.data["components2"], keep_attrs=True) + comps2.name = "right_components_phase" comps1 = self.preprocessor1.inverse_transform_components(comps1) comps2 = self.preprocessor2.inverse_transform_components(comps2) @@ -795,8 +842,11 @@ def scores_amplitude(self) -> Tuple[DataArray, DataArray]: Amplitude of the right scores. """ - scores1 = self.data.scores_amplitude1 - scores2 = self.data.scores_amplitude2 + scores1 = abs(self.data["scores1"]) + scores2 = abs(self.data["scores2"]) + + scores1.name = "left_scores_amplitude" + scores2.name = "right_scores_amplitude" scores1 = self.preprocessor1.inverse_transform_scores(scores1) scores2 = self.preprocessor2.inverse_transform_scores(scores2) @@ -821,15 +871,18 @@ def scores_phase(self) -> Tuple[DataArray, DataArray]: Phase of the right scores. """ - scores1 = self.data.scores_phase1 - scores2 = self.data.scores_phase2 + scores1 = xr.apply_ufunc(np.angle, self.data["scores1"], keep_attrs=True) + scores2 = xr.apply_ufunc(np.angle, self.data["scores2"], keep_attrs=True) + + scores1.name = "left_scores_phase" + scores2.name = "right_scores_phase" scores1 = self.preprocessor1.inverse_transform_scores(scores1) scores2 = self.preprocessor2.inverse_transform_scores(scores2) return (scores1, scores2) - def transform(self, data1: AnyDataObject, data2: AnyDataObject): + def transform(self, data1: DataObject, data2: DataObject): raise NotImplementedError("Complex MCA does not support transform method.") def homogeneous_patterns(self, correction=None, alpha=0.05): diff --git a/xeofs/models/mca_rotator.py b/xeofs/models/mca_rotator.py index da381e4..d73a93b 100644 --- a/xeofs/models/mca_rotator.py +++ b/xeofs/models/mca_rotator.py @@ -6,10 +6,7 @@ from .mca import MCA, ComplexMCA from ..utils.rotation import promax from ..utils.data_types import DataArray -from ..data_container.mca_rotator_data_container import ( - MCARotatorDataContainer, - ComplexMCARotatorDataContainer, -) +from ..data_container import DataContainer from .._version import __version__ @@ -40,6 +37,8 @@ class MCARotator(MCA): conserving the squared covariance under rotation. This allows estimation of mode importance after rotation. If False, the combined vectors are loaded with the square root of the singular values, following the method described by Cheng & Dunkerton [1]_. + compute : bool, default=True + Whether to compute the decomposition immediately. 
References ---------- @@ -62,7 +61,9 @@ def __init__( max_iter: int = 1000, rtol: float = 1e-8, squared_loadings: bool = False, + compute: bool = True, ): + self._compute = compute # Define model parameters self._params = { "n_modes": n_modes, @@ -83,8 +84,8 @@ def __init__( } ) - # Initialize the DataContainer to hold the rotated solution - self.data: MCARotatorDataContainer = MCARotatorDataContainer() + # Define data container to store the rotated solution + self.data = DataContainer() def _compute_rot_mat_inv_trans(self, rotation_matrix, input_dims) -> xr.DataArray: """Compute the inverse transpose of the rotation matrix. @@ -125,6 +126,9 @@ def fit(self, model: MCA | ComplexMCA): self.preprocessor1 = model.preprocessor1 self.preprocessor2 = model.preprocessor2 + sample_name = self.model.sample_name + feature_name = self.model.feature_name + n_modes = self._params["n_modes"] power = self._params["power"] max_iter = self._params["max_iter"] @@ -142,8 +146,8 @@ def fit(self, model: MCA | ComplexMCA): # or weighted with the singular values ("squared loadings"), as opposed to the square root of the singular values. # In doing so, the squared covariance remains conserved under rotation, allowing for the estimation of the # modes' importance. - norm1 = self.model.data.norm1.sel(mode=slice(1, n_modes)) - norm2 = self.model.data.norm2.sel(mode=slice(1, n_modes)) + norm1 = self.model.data["norm1"].sel(mode=slice(1, n_modes)) + norm2 = self.model.data["norm2"].sel(mode=slice(1, n_modes)) if use_squared_loadings: # Squared loadings approach conserving squared covariance scaling = norm1 * norm2 @@ -151,24 +155,19 @@ def fit(self, model: MCA | ComplexMCA): # Cheng & Dunkerton approach conserving covariance scaling = np.sqrt(norm1 * norm2) - comps1 = self.model.data.components1.sel(mode=slice(1, n_modes)) - comps2 = self.model.data.components2.sel(mode=slice(1, n_modes)) - loadings = xr.concat([comps1, comps2], dim="feature") * scaling + comps1 = self.model.data["components1"].sel(mode=slice(1, n_modes)) + comps2 = self.model.data["components2"].sel(mode=slice(1, n_modes)) + loadings = xr.concat([comps1, comps2], dim=feature_name) * scaling # Rotate loadings - rot_loadings, rot_matrix, phi_matrix = xr.apply_ufunc( - promax, - loadings, - power, - input_core_dims=[["feature", "mode"], []], - output_core_dims=[ - ["feature", "mode"], - ["mode_m", "mode_n"], - ["mode_m", "mode_n"], - ], - kwargs={"max_iter": max_iter, "rtol": rtol}, - dask="allowed", + promax_kwargs = {"power": power, "max_iter": max_iter, "rtol": rtol} + rot_loadings, rot_matrix, phi_matrix = promax( + loadings=loadings, + feature_dim=feature_name, + compute=self._compute, + **promax_kwargs ) + # Assign coordinates to the rotation/correlation matrices rot_matrix = rot_matrix.assign_coords( mode_m=np.arange(1, rot_matrix.mode_m.size + 1), @@ -180,18 +179,20 @@ def fit(self, model: MCA | ComplexMCA): ) # Rotated (loaded) singular vectors - comps1_rot = rot_loadings.isel(feature=slice(0, comps1.coords["feature"].size)) + comps1_rot = rot_loadings.isel( + {feature_name: slice(0, comps1.coords[feature_name].size)} + ) comps2_rot = rot_loadings.isel( - feature=slice(comps1.coords["feature"].size, None) + {feature_name: slice(comps1.coords[feature_name].size, None)} ) # Normalization factor of singular vectors norm1_rot = xr.apply_ufunc( np.linalg.norm, comps1_rot, - input_core_dims=[["feature", "mode"]], + input_core_dims=[[feature_name, "mode"]], output_core_dims=[["mode"]], - exclude_dims={"feature"}, + exclude_dims={feature_name}, 
kwargs={"axis": 0}, vectorize=False, dask="allowed", @@ -199,9 +200,9 @@ def fit(self, model: MCA | ComplexMCA): norm2_rot = xr.apply_ufunc( np.linalg.norm, comps2_rot, - input_core_dims=[["feature", "mode"]], + input_core_dims=[[feature_name, "mode"]], output_core_dims=[["mode"]], - exclude_dims={"feature"}, + exclude_dims={feature_name}, kwargs={"axis": 0}, vectorize=False, dask="allowed", @@ -244,8 +245,8 @@ def fit(self, model: MCA | ComplexMCA): ) # Rotate scores using rotation matrix - scores1 = self.model.data.scores1.sel(mode=slice(1, n_modes)) - scores2 = self.model.data.scores2.sel(mode=slice(1, n_modes)) + scores1 = self.model.data["scores1"].sel(mode=slice(1, n_modes)) + scores2 = self.model.data["scores2"].sel(mode=slice(1, n_modes)) RinvT = self._compute_rot_mat_inv_trans( rot_matrix, input_dims=("mode_m", "mode_n") @@ -266,9 +267,9 @@ def fit(self, model: MCA | ComplexMCA): ) # Ensure consitent signs for deterministic output - idx_max_value = abs(rot_loadings).argmax("feature").compute() + idx_max_value = abs(rot_loadings).argmax(feature_name).compute() modes_sign = xr.apply_ufunc( - np.sign, rot_loadings.isel(feature=idx_max_value), dask="allowed" + np.sign, rot_loadings.isel({feature_name: idx_max_value}), dask="allowed" ) # Drop all dimensions except 'mode' so that the index is clean for dim, coords in modes_sign.coords.items(): @@ -280,22 +281,27 @@ def fit(self, model: MCA | ComplexMCA): scores2_rot = scores2_rot * modes_sign # Create data container - self.data.set_data( - input_data1=self.model.data.input_data1, - input_data2=self.model.data.input_data2, - components1=comps1_rot, - components2=comps2_rot, - scores1=scores1_rot, - scores2=scores2_rot, - squared_covariance=squared_covariance, - total_squared_covariance=self.model.data.total_squared_covariance, - idx_modes_sorted=idx_modes_sorted, - norm1=norm1_rot, - norm2=norm2_rot, - rotation_matrix=rot_matrix, - phi_matrix=phi_matrix, - modes_sign=modes_sign, + self.data.add( + name="input_data1", data=self.model.data["input_data1"], allow_compute=False ) + self.data.add( + name="input_data2", data=self.model.data["input_data2"], allow_compute=False + ) + self.data.add(name="components1", data=comps1_rot) + self.data.add(name="components2", data=comps2_rot) + self.data.add(name="scores1", data=scores1_rot) + self.data.add(name="scores2", data=scores2_rot) + self.data.add(name="squared_covariance", data=squared_covariance) + self.data.add( + name="total_squared_covariance", + data=self.model.data["total_squared_covariance"], + ) + self.data.add(name="idx_modes_sorted", data=idx_modes_sorted) + self.data.add(name="norm1", data=norm1_rot) + self.data.add(name="norm2", data=norm2_rot) + self.data.add(name="rotation_matrix", data=rot_matrix) + self.data.add(name="phi_matrix", data=phi_matrix) + self.data.add(name="modes_sign", data=modes_sign) # Assign analysis-relevant meta data self.data.set_attrs(self.attrs) @@ -305,9 +311,9 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: Parameters ---------- - data1 : DataArray | Dataset | DataArraylist + data1 : DataArray | Dataset | List[DataArray] Data to be projected onto the rotated singular vectors of the first dataset. - data2 : DataArray | Dataset | DataArraylist + data2 : DataArray | Dataset | List[DataArray] Data to be projected onto the rotated singular vectors of the second dataset. Returns @@ -321,7 +327,7 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: raise ValueError("No data provided. 
Please provide data1 and/or data2.") n_modes = self._params["n_modes"] - rot_matrix = self.data.rotation_matrix + rot_matrix = self.data["rotation_matrix"] RinvT = self._compute_rot_mat_inv_trans( rot_matrix, input_dims=("mode_m", "mode_n") ) @@ -332,8 +338,8 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: if "data1" in kwargs.keys(): data1 = kwargs["data1"] # Select the (non-rotated) singular vectors of the first dataset - comps1 = self.model.data.components1.sel(mode=slice(1, n_modes)) - norm1 = self.model.data.norm1.sel(mode=slice(1, n_modes)) + comps1 = self.model.data["components1"].sel(mode=slice(1, n_modes)) + norm1 = self.model.data["norm1"].sel(mode=slice(1, n_modes)) # Preprocess the data data1 = self.preprocessor1.transform(data1) @@ -345,10 +351,10 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: projections1 = xr.dot(projections1, RinvT, dims="mode_m") # Reorder according to variance projections1 = projections1.isel( - mode=self.data.idx_modes_sorted.values + mode=self.data["idx_modes_sorted"].values ).assign_coords(mode=projections1.mode) # Adapt the sign of the scores - projections1 = projections1 * self.data.modes_sign + projections1 = projections1 * self.data["modes_sign"] # Unstack the projections projections1 = self.preprocessor1.inverse_transform_scores(projections1) @@ -358,8 +364,8 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: if "data2" in kwargs.keys(): data2 = kwargs["data2"] # Select the (non-rotated) singular vectors of the second dataset - comps2 = self.model.data.components2.sel(mode=slice(1, n_modes)) - norm2 = self.model.data.norm2.sel(mode=slice(1, n_modes)) + comps2 = self.model.data["components2"].sel(mode=slice(1, n_modes)) + norm2 = self.model.data["norm2"].sel(mode=slice(1, n_modes)) # Preprocess the data data2 = self.preprocessor2.transform(data2) @@ -371,10 +377,10 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]: projections2 = xr.dot(projections2, RinvT, dims="mode_m") # Reorder according to variance projections2 = projections2.isel( - mode=self.data.idx_modes_sorted.values + mode=self.data["idx_modes_sorted"].values ).assign_coords(mode=projections2.mode) # Determine the sign of the scores - projections2 = projections2 * self.data.modes_sign + projections2 = projections2 * self.data["modes_sign"] # Unstack the projections projections2 = self.preprocessor2.inverse_transform_scores(projections2) @@ -438,9 +444,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.attrs.update({"model": "Complex Rotated MCA"}) - # Initialize the DataContainer to hold the rotated solution - self.data: ComplexMCARotatorDataContainer = ComplexMCARotatorDataContainer() - def transform(self, **kwargs): # Here we make use of the Method Resolution Order (MRO) to call the # transform method of the first class in the MRO after `MCARotator` diff --git a/xeofs/models/opa.py b/xeofs/models/opa.py index 0222ed7..6db9916 100644 --- a/xeofs/models/opa.py +++ b/xeofs/models/opa.py @@ -1,4 +1,5 @@ -from typing import Optional +from typing import Optional, Dict +from typing_extensions import Self import xarray as xr import numpy as np @@ -6,14 +7,14 @@ from ._base_model import _BaseModel from .eof import EOF from .decomposer import Decomposer -from ..data_container.opa_data_container import OPADataContainer -from ..utils.data_types import AnyDataObject, DataArray +from ..utils.data_types import DataObject, DataArray class OPA(_BaseModel): """Optimal Persistence Analysis (OPA). 
- OPA identifies the optimal persistence patterns (OPP) with the + OPA identifies the optimal persistence patterns or + optimally persistent patterns (OPP) with the largest decorrelation time in a time-varying field. Introduced by DelSole in 2001 [1]_, and further developed in 2006 [2]_, it's a method used to find patterns whose time series show strong persistence over time. @@ -24,8 +25,24 @@ class OPA(_BaseModel): Number of optimal persistence patterns (OPP) to be computed. tau_max : int Maximum time lag for the computation of the covariance matrix. + center : bool, default=True + Whether to center the input data. + standardize : bool, default=False + Whether to standardize the input data. + use_coslat : bool, default=False + Whether to use cosine of latitude for scaling. n_pca_modes : int Number of modes to be computed in the pre-processing step using EOF. + compute : bool, default=True + Whether to compute the decomposition immediately. + sample_name : str, default="sample" + Name of the sample dimension. + feature_name : str, default="feature" + Name of the feature dimension. + solver : {"auto", "full", "randomized"}, default="auto" + Solver to use for the SVD computation. + solver_kwargs : dict, default={} + Additional keyword arguments to pass to the solver. References ---------- @@ -38,36 +55,61 @@ class OPA(_BaseModel): >>> model = OPA(n_modes=10, tau_max=50, n_pca_modes=100) >>> model.fit(data, dim=("time")) - Retrieve the optimal perstitence patterns (OPP) and their time series: + Retrieve the optimally persistent patterns (OPP) and their time series: >>> opp = model.components() >>> opp_ts = model.scores() - Retrieve the decorrelation time of the optimal persistence patterns (OPP): + Retrieve the decorrelation time of the OPPs: >>> decorrelation_time = model.decorrelation_time() """ - def __init__(self, n_modes, tau_max, n_pca_modes, **kwargs): + def __init__( + self, + n_modes: int, + tau_max: int, + center: bool = True, + standardize: bool = False, + use_coslat: bool = False, + n_pca_modes: int = 100, + compute: bool = True, + sample_name: str = "sample", + feature_name: str = "feature", + solver: str = "auto", + random_state: Optional[int] = None, + solver_kwargs: Dict = {}, + ): if n_modes > n_pca_modes: raise ValueError( f"n_modes must be smaller or equal to n_pca_modes (n_modes={n_modes}, n_pca_modes={n_pca_modes})" ) - super().__init__(n_modes=n_modes, **kwargs) + super().__init__( + n_modes=n_modes, + center=center, + standardize=standardize, + use_coslat=use_coslat, + compute=compute, + sample_name=sample_name, + feature_name=feature_name, + solver=solver, + random_state=random_state, + solver_kwargs=solver_kwargs, + ) self.attrs.update({"model": "OPA"}) self._params.update({"tau_max": tau_max, "n_pca_modes": n_pca_modes}) - # Initialize the DataContainer to store the results - self.data: OPADataContainer = OPADataContainer() - def _Ctau(self, X, tau: int) -> DataArray: """Compute the time-lage covariance matrix C(tau) of the data X.""" + sample_name = self.preprocessor.sample_name X0 = X.copy(deep=True) - Xtau = X.shift(sample=-tau).dropna("sample") + Xtau = X.shift({sample_name: -tau}).dropna(sample_name) X0 = X0.rename({"mode": "feature1"}) Xtau = Xtau.rename({"mode": "feature2"}) - return xr.dot(X0, Xtau, dims=["sample"]) / (Xtau.sample.size - 1) + + n_samples = Xtau[sample_name].size + return xr.dot(X0, Xtau, dims=[sample_name]) / (n_samples - 1) @staticmethod def _compute_matrix_inverse(X, dims): @@ -81,24 +123,33 @@ def _compute_matrix_inverse(X, dims): 
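The `_Ctau` helper above builds the lag-tau covariance from a 2D (sample x mode) array by shifting the series by tau, dropping the trailing NaN samples produced by the shift, and contracting over the sample dimension. A rough stand-alone sketch of the same computation (toy data; the explicit `xr.align` call is added here for clarity and is not taken from the diff):

    import numpy as np
    import xarray as xr

    X = xr.DataArray(
        np.random.rand(100, 5),
        dims=("sample", "mode"),
        coords={"sample": np.arange(100)},
    )
    tau = 3

    X0 = X.rename({"mode": "feature1"})
    Xtau = X.shift(sample=-tau).dropna("sample").rename({"mode": "feature2"})

    # Keep only the overlapping samples before contracting
    X0, Xtau = xr.align(X0, Xtau, join="inner")
    n = Xtau.sample.size
    C_tau = xr.dot(X0, Xtau, dims=["sample"]) / (n - 1)  # -> (feature1 x feature2)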
dask="allowed", ) - def fit(self, data: AnyDataObject, dim, weights: Optional[AnyDataObject] = None): - # Preprocess the data - input_data: DataArray = self.preprocessor.fit_transform(data, dim, weights) + def _fit_algorithm(self, data: DataArray) -> Self: + sample_name = self.sample_name + feature_name = self.feature_name # Perform PCA as a pre-processing step - pca = EOF(n_modes=self._params["n_pca_modes"], use_coslat=False) - pca.fit(input_data, dim="sample") - svals = pca.data.singular_values - expvar = pca.data.explained_variance - comps = pca.data.components * svals / np.sqrt(expvar) + pca = EOF( + n_modes=self._params["n_pca_modes"], + standardize=False, + use_coslat=False, + sample_name=self.sample_name, + feature_name=self.feature_name, + solver=self._params["solver"], + compute=self._params["compute"], + random_state=self._params["random_state"], + solver_kwargs=self._solver_kwargs, + ) + pca.fit(data, dim=sample_name) + n_samples = data.coords[sample_name].size + comps = pca.data["components"] * np.sqrt(n_samples - 1) # -> comps (feature x mode) - scores = pca.data.scores * np.sqrt(expvar) + scores = pca.data["scores"] / np.sqrt(n_samples - 1) # -> scores (sample x mode) # Compute the covariance matrix with zero time lag C0 = self._Ctau(scores, 0) # -> C0 (feature1 x feature2) - C0inv = self._compute_matrix_inverse(C0, dims=("feature1", "feature2")) + # C0inv = self._compute_matrix_inverse(C0, dims=("feature1", "feature2")) # -> C0inv (feature2 x feature1) M = 0.5 * C0 # -> M (feature1 x feature2) @@ -120,7 +171,12 @@ def fit(self, data: AnyDataObject, dim, weights: Optional[AnyDataObject] = None) # using a symmtric matrix given in # A. Hannachi (2021), Patterns Identification and # Data Mining in Weather and Climate, Equation (8.20) - decomposer = Decomposer(n_modes=C0.shape[0], flip_signs=False, solver="full") + decomposer = Decomposer( + n_modes=C0.shape[0], + flip_signs=False, + compute=self._params["compute"], + solver="full", + ) decomposer.fit(C0, dims=("feature1", "feature2")) C0_sqrt = decomposer.U_ * np.sqrt(decomposer.s_) # -> C0_sqrt (feature1 x mode) @@ -177,35 +233,48 @@ def fit(self, data: AnyDataObject, dim, weights: Optional[AnyDataObject] = None) # -> W (feature x mode2) # Rename dimensions - U = U.rename({"feature1": "feature"}) # -> (feature x mode) + U = U.rename({"feature1": feature_name}) # -> (feature x mode) V = V.rename({"mode2": "mode"}) # -> (feature x mode) W = W.rename({"mode2": "mode"}) # -> (feature x mode) P = P.rename({"mode2": "mode"}) # -> (sample x mode) + scores = scores.rename({"mode": feature_name}) # -> (sample x feature) - # Store the results - self.data.set_data( - input_data=scores.rename({"mode": "feature"}), - components=W, - scores=P, - filter_patterns=V, - decorrelation_time=lbda, + # Compute the norms of the scores + norms = xr.apply_ufunc( + np.linalg.norm, + P, + input_core_dims=[["sample"]], + vectorize=False, + dask="allowed", + kwargs={"axis": -1}, ) + + # Store the results + # NOTE: not sure if "scores" should be taken as input data here, "data" may be more correct -> to be verified + self.data.add(name="input_data", data=scores, allow_compute=False) + self.data.add(name="components", data=W, allow_compute=True) + self.data.add(name="scores", data=P, allow_compute=True) + self.data.add(name="norms", data=norms, allow_compute=True) + self.data.add(name="filter_patterns", data=V, allow_compute=True) + self.data.add(name="decorrelation_time", data=lbda, allow_compute=True) + self.data.set_attrs(self.attrs) self._U = U # store U 
for testing purposes of orthogonality self._C0 = C0 # store C0 for testing purposes of orthogonality + return self - def transform(self, data: AnyDataObject): - raise NotImplementedError() + def _transform_algorithm(self, data: DataArray) -> DataArray: + raise NotImplementedError("OPA does not (yet) support transform()") - def inverse_transform(self, mode): - raise NotImplementedError() + def _inverse_transform_algorithm(self, mode) -> DataObject: + raise NotImplementedError("OPA does not (yet) support inverse_transform()") - def components(self) -> AnyDataObject: - """Return the optimal persistence pattern (OPP).""" + def components(self) -> DataObject: + """Return the optimally persistent patterns (OPPs).""" return super().components() def scores(self) -> DataArray: - """Return the time series of the optimal persistence pattern (OPP). + """Return the time series of the OPPs. The time series have a maximum decorrelation time that are uncorrelated with each other. """ @@ -213,9 +282,9 @@ def scores(self) -> DataArray: def decorrelation_time(self) -> DataArray: """Return the decorrelation time of the optimal persistence pattern (OPP).""" - return self.data.decorrelation_time + return self.data["decorrelation_time"] - def filter_patterns(self) -> DataArray: + def filter_patterns(self) -> DataObject: """Return the filter patterns.""" - fps = self.data.filter_patterns + fps = self.data["filter_patterns"] return self.preprocessor.inverse_transform_components(fps) diff --git a/xeofs/models/rotator_factory.py b/xeofs/models/rotator_factory.py index e2b91cf..1b685d3 100644 --- a/xeofs/models/rotator_factory.py +++ b/xeofs/models/rotator_factory.py @@ -1,15 +1,7 @@ -import numpy as np -import xarray as xr -from typing import Optional, Union, List, Tuple - -from xeofs.utils.data_types import DataArrayList, XarrayData - from .eof import EOF, ComplexEOF from .mca import MCA, ComplexMCA from .eof_rotator import EOFRotator, ComplexEOFRotator from .mca_rotator import MCARotator, ComplexMCARotator -from ..utils.rotation import promax -from ..utils.data_types import XarrayData, DataArrayList, Dataset, DataArray class RotatorFactory: diff --git a/xeofs/preprocessing/__init__.py b/xeofs/preprocessing/__init__.py index e69de29..54b31f3 100644 --- a/xeofs/preprocessing/__init__.py +++ b/xeofs/preprocessing/__init__.py @@ -0,0 +1,12 @@ +from .scaler import Scaler +from .sanitizer import Sanitizer +from .multi_index_converter import MultiIndexConverter +from .stacker import DataArrayStacker, DataSetStacker + +__all__ = [ + "Scaler", + "Sanitizer", + "MultiIndexConverter", + "DataArrayStacker", + "DataSetStacker", +] diff --git a/xeofs/preprocessing/_base_scaler.py b/xeofs/preprocessing/_base_scaler.py deleted file mode 100644 index 44b19eb..0000000 --- a/xeofs/preprocessing/_base_scaler.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import ABC, abstractmethod - - -class _BaseScaler(ABC): - def __init__(self, with_std=True, with_coslat=False, with_weights=False): - self._params = dict( - with_std=with_std, with_coslat=with_coslat, with_weights=with_weights - ) - - self.mean = None - self.std = None - self.coslat_weights = None - self.weights = None - - @abstractmethod - def fit(self, X, sample_dims, feature_dims, weights=None): - raise NotImplementedError - - @abstractmethod - def transform(self, X): - raise NotImplementedError - - @abstractmethod - def fit_transform(self, X, sample_dims, feature_dims, weights=None): - raise NotImplementedError - - @abstractmethod - def inverse_transform(self, X): - raise 
NotImplementedError - - def get_params(self): - return self._params.copy() diff --git a/xeofs/preprocessing/_base_stacker.py b/xeofs/preprocessing/_base_stacker.py deleted file mode 100644 index 7cd742c..0000000 --- a/xeofs/preprocessing/_base_stacker.py +++ /dev/null @@ -1,152 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Sequence, Hashable, List - -import xarray as xr - - -class _BaseStacker(ABC): - """Abstract base class for stacking data into a 2D array. - - Every multi-dimensional array is be reshaped into a 2D array with the - dimensions (sample x feature). - - - Attributes - ---------- - dims_in_ : tuple - The dimensions of the data used to fit the stacker. - dims_out_ : dict['sample': ..., 'feature': ...] - The dimensions of the stacked data. - coords_in_ : dict - The coordinates of the data used to fit the stacker. - coords_out_ : dict['sample': ..., 'feature': ...] - The coordinates of the stacked data. Typically consist of MultiIndex. - - """ - - def __init__(self): - pass - - def fit( - self, - data, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - ): - """Invoking a `fit` operation for a stacker object isn't practical because it requires stacking the data, - only to ascertain the output dimensions. This step is computationally expensive and unnecessary. - Therefore, instead of using a separate `fit` method, we combine the fit and transform steps - into the `fit_transform` method for efficiency. However, to maintain consistency with other classes - that do utilize a `fit` method, we retain the `fit` method here, albeit unimplemented. - - """ - raise NotImplementedError( - "Stacker does not implement fit method. Use fit_transform instead." - ) - - @abstractmethod - def fit_transform( - self, - data, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - ) -> xr.DataArray: - """Fit the stacker to the data and then transform the data. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - sample_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `sample` dimension. - feature_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `feature` dimension. - - Returns - ------- - DataArray - The reshaped data. - - Raises - ------ - ValueError - If any of the dimensions in `sample_dims` or `feature_dims` are not present in the data. - """ - raise NotImplementedError - - @abstractmethod - def transform(self, data) -> xr.DataArray: - """Reshape the data into a 2D version. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - - Returns - ------- - DataArray - The reshaped data. - - Raises - ------ - ValueError - If the data to be transformed has different dimensions than the data used to fit the stacker. - ValueError - If the data to be transformed has different feature coordinates than the data used to fit the stacker. - ValueError - If the data to be transformed has individual NaNs. - - """ - raise NotImplementedError - - @abstractmethod - def inverse_transform_data(self, data: xr.DataArray): - """Reshape the 2D data (sample x feature) back into its original shape. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - - Returns - ------- - DataArray - The reshaped data. 
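The stacker abstraction being removed here (and referenced elsewhere in this diff via `Stacker`/`StackerFactory` from `.stacker`) boils down to xarray's `stack`/`unstack`: reshape an N-D field into a 2D (sample x feature) matrix and back. A small sketch with invented dimension names, not the library's implementation:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(
        np.random.rand(5, 3, 4),
        dims=("time", "lat", "lon"),
        coords={"lat": [10.0, 20.0, 30.0], "lon": [0.0, 90.0, 180.0, 270.0]},
    )

    # N-D field -> 2D (sample x feature) matrix
    da2d = da.stack(sample=("time",), feature=("lat", "lon"))
    assert set(da2d.dims) == {"sample", "feature"}

    # ... and back to the original shape
    da_nd = da2d.unstack()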
- - """ - raise NotImplementedError - - @abstractmethod - def inverse_transform_components(self, data: xr.DataArray): - """Reshape the 2D data (mode x feature) back into its original shape. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - - Returns - ------- - DataArray - The reshaped data. - - """ - raise NotImplementedError - - @abstractmethod - def inverse_transform_scores(self, data: xr.DataArray): - """Reshape the 2D data (sample x mode) back into its original shape. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - - Returns - ------- - DataArray - The reshaped data. - - """ - raise NotImplementedError diff --git a/xeofs/preprocessing/concatenator.py b/xeofs/preprocessing/concatenator.py new file mode 100644 index 0000000..1eaf893 --- /dev/null +++ b/xeofs/preprocessing/concatenator.py @@ -0,0 +1,118 @@ +from typing import List, Optional +from typing_extensions import Self + +import pandas as pd +import numpy as np +import xarray as xr + +from .transformer import Transformer +from ..utils.data_types import ( + Dims, + DimsList, + DataArray, + DataSet, + Data, + DataVar, + DataList, + DataArrayList, + DataSetList, + DataVarList, +) + + +class Concatenator(Transformer): + """Concatenate a list of DataArrays along the feature dimensions.""" + + def __init__(self, sample_name: str = "sample", feature_name: str = "feature"): + super().__init__(sample_name, feature_name) + + self.stackers = [] + + def fit( + self, + X: List[DataArray], + sample_dims: Optional[Dims] = None, + feature_dims: Optional[DimsList] = None, + ) -> Self: + # Check that all inputs are DataArrays + if not all([isinstance(data, DataArray) for data in X]): + raise ValueError("Input must be a list of DataArrays") + + # Check that all inputs have shape 2 + if not all([len(data.dims) == 2 for data in X]): + raise ValueError("Input DataArrays must have shape 2") + + # Check that all inputs have the same sample_name and feature_name + if not all([data.dims == (self.sample_name, self.feature_name) for data in X]): + raise ValueError("Input DataArrays must have the same dimensions") + + self.n_data = len(X) + + # Set input feature coordinates + self.coords_in = [data.coords[self.feature_name] for data in X] + self.n_features = [coord.size for coord in self.coords_in] + + return self + + def transform(self, X: List[DataArray]) -> DataArray: + # Test whether the input list has same length as the number of stackers + if len(X) != self.n_data: + raise ValueError( + f"Invalid input. Number of DataArrays ({len(X)}) does not match the number of fitted DataArrays ({self.n_data})." 
+ ) + + reindexed_data_list: List[DataArray] = [] + dummy_feature_coords = [] + + idx_range = np.cumsum([0] + self.n_features) + for i, data in enumerate(X): + # Create dummy feature coordinates for DataArray + new_coords = np.arange(idx_range[i], idx_range[i + 1]) + + # Replace original feature coordiantes with dummy coordinates + data = data.drop_vars(self.feature_name) + reindexed = data.assign_coords({self.feature_name: new_coords}) + + # Store dummy feature coordinates + dummy_feature_coords.append(new_coords) + reindexed_data_list.append(reindexed) + + self._dummy_feature_coords = dummy_feature_coords + + X_concat: DataArray = xr.concat(reindexed_data_list, dim=self.feature_name) + self.coords_out = X_concat.coords[self.feature_name] + + return X_concat + + def fit_transform( + self, + X: List[DataArray], + sample_dims: Optional[Dims] = None, + feature_dims: Optional[DimsList] = None, + ) -> DataArray: + return self.fit(X, sample_dims, feature_dims).transform(X) + + def _split_dataarray_into_list(self, data: DataArray) -> List[DataArray]: + feature_name = self.feature_name + data_list: List[DataArray] = [] + + for coords, features in zip(self.coords_in, self._dummy_feature_coords): + # Select the features corresponding to the current DataArray + sub_selection = data.sel({feature_name: features}) + # Replace dummy feature coordinates with original feature coordinates + sub_selection = sub_selection.assign_coords({feature_name: coords}) + data_list.append(sub_selection) + + return data_list + + def inverse_transform_data(self, X: DataArray) -> List[DataArray]: + """Reshape the 2D data (sample x feature) back into its original shape.""" + return self._split_dataarray_into_list(X) + + def inverse_transform_components(self, X: DataArray) -> List[DataArray]: + """Reshape the 2D components (sample x feature) back into its original shape.""" + return self._split_dataarray_into_list(X) + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + """Reshape the 2D scores (sample x mode) back into its original shape.""" + return X diff --git a/xeofs/preprocessing/dimension_renamer.py b/xeofs/preprocessing/dimension_renamer.py new file mode 100644 index 0000000..7824173 --- /dev/null +++ b/xeofs/preprocessing/dimension_renamer.py @@ -0,0 +1,61 @@ +from typing_extensions import Self + +from .transformer import Transformer +from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound + + +class DimensionRenamer(Transformer): + """Rename dimensions of a DataArray or Dataset. + + Parameters + ---------- + base: str + Base string for the new dimension names. + start: int + Start index for the new dimension names. + + """ + + def __init__(self, base="dim", start=0): + super().__init__() + self.base = base + self.start = start + self.dim_mapping = {} + + def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims, **kwargs) -> Self: + self.sample_dims_before = sample_dims + self.feature_dims_before = feature_dims + + self.dim_mapping = { + dim: f"{self.base}{i}" for i, dim in enumerate(X.dims, start=self.start) + } + + self.sample_dims_after: Dims = tuple( + [self.dim_mapping[dim] for dim in self.sample_dims_before] + ) + self.feature_dims_after: Dims = tuple( + [self.dim_mapping[dim] for dim in self.feature_dims_before] + ) + + return self + + def transform(self, X: DataVarBound) -> DataVarBound: + try: + return X.rename(self.dim_mapping) + except ValueError: + raise ValueError("Cannot transform data. 
Dimensions are different.") + + def _inverse_transform(self, X: DataVarBound) -> DataVarBound: + given_dims = set(X.dims) + expected_dims = set(self.dim_mapping.values()) + dims = given_dims.intersection(expected_dims) + return X.rename({v: k for k, v in self.dim_mapping.items() if v in dims}) + + def inverse_transform_data(self, X: DataVarBound) -> DataVarBound: + return self._inverse_transform(X) + + def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: + return self._inverse_transform(X) + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + return self._inverse_transform(X) diff --git a/xeofs/preprocessing/factory.py b/xeofs/preprocessing/factory.py new file mode 100644 index 0000000..19c9a0c --- /dev/null +++ b/xeofs/preprocessing/factory.py @@ -0,0 +1,57 @@ +# import xarray as xr + +# from .scaler import DataArrayScaler, DataSetScaler, DataListScaler +# from .stacker import DataArrayStacker, DataSetStacker, DataListStacker +# from .multi_index_converter import ( +# DataArrayMultiIndexConverter, +# DataSetMultiIndexConverter, +# DataListMultiIndexConverter, +# ) +# from ..utils.data_types import DataObject + + +# class ScalerFactory: +# @staticmethod +# def create_scaler(data: DataObject, **kwargs): +# if isinstance(data, xr.DataArray): +# return DataArrayScaler(**kwargs) +# elif isinstance(data, xr.Dataset): +# return DataSetScaler(**kwargs) +# elif isinstance(data, list) and all( +# isinstance(da, xr.DataArray) for da in data +# ): +# return DataListScaler(**kwargs) +# else: +# raise ValueError("Invalid data type") + + +# class MultiIndexConverterFactory: +# @staticmethod +# def create_converter( +# data: DataObject, **kwargs +# ) -> DataArrayMultiIndexConverter | DataListMultiIndexConverter: +# if isinstance(data, xr.DataArray): +# return DataArrayMultiIndexConverter(**kwargs) +# elif isinstance(data, xr.Dataset): +# return DataSetMultiIndexConverter(**kwargs) +# elif isinstance(data, list) and all( +# isinstance(da, xr.DataArray) for da in data +# ): +# return DataListMultiIndexConverter(**kwargs) +# else: +# raise ValueError("Invalid data type") + + +# class StackerFactory: +# @staticmethod +# def create_stacker(data: DataObject, **kwargs): +# if isinstance(data, xr.DataArray): +# return DataArrayStacker(**kwargs) +# elif isinstance(data, xr.Dataset): +# return DataSetStacker(**kwargs) +# elif isinstance(data, list) and all( +# isinstance(da, xr.DataArray) for da in data +# ): +# return DataListStacker(**kwargs) +# else: +# raise ValueError("Invalid data type") diff --git a/xeofs/preprocessing/list_processor.py b/xeofs/preprocessing/list_processor.py new file mode 100644 index 0000000..d7e910b --- /dev/null +++ b/xeofs/preprocessing/list_processor.py @@ -0,0 +1,105 @@ +from typing import List, TypeVar, Generic, Type, Dict, Any +from typing_extensions import Self + +from .dimension_renamer import DimensionRenamer +from .scaler import Scaler +from .sanitizer import Sanitizer +from .multi_index_converter import MultiIndexConverter +from .stacker import Stacker +from ..utils.data_types import ( + Data, + DataVar, + DataVarBound, + DataArray, + DataSet, + Dims, + DimsList, +) + +T = TypeVar( + "T", + bound=(DimensionRenamer | Scaler | MultiIndexConverter | Stacker | Sanitizer), +) + + +class GenericListTransformer(Generic[T]): + """Apply a Transformer to each of the elements of a list. + + Parameters + ---------- + transformer: Transformer + Transformer class to apply to list elements. + kwargs: dict + Keyword arguments for the transformer. 
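The `GenericListTransformer` introduced here fans a single transformer class out over a list of inputs, fitting one instance per element and reusing that instance for the matching element later on. A stripped-down sketch of the pattern (the `ListApplier` name and the minimal `fit`/`transform` protocol are assumptions for illustration, not the actual class):

    from typing import Generic, List, Type, TypeVar

    T = TypeVar("T")

    class ListApplier(Generic[T]):
        """Fit one transformer instance per list element (illustrative only)."""

        def __init__(self, transformer_cls: Type[T], **kwargs):
            self.transformer_cls = transformer_cls
            self.kwargs = kwargs
            self.fitted: List[T] = []

        def fit(self, items):
            for item in items:
                t = self.transformer_cls(**self.kwargs)  # one instance per element
                t.fit(item)
                self.fitted.append(t)
            return self

        def transform(self, items):
            # Each element is transformed by "its" fitted instance
            return [t.transform(x) for t, x in zip(self.fitted, items)]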
+ """ + + def __init__(self, transformer: Type[T], **kwargs): + self.transformer_class = transformer + self.transformers: List[T] = [] + self.init_kwargs = kwargs + + def fit( + self, + X: List[DataVar], + sample_dims: Dims, + feature_dims: DimsList, + iter_kwargs: Dict[str, List[Any]] = {}, + ) -> Self: + """Fit transformer to each data element in the list. + + Parameters + ---------- + X: List[Data] + List of data elements. + sample_dims: Dims + Sample dimensions. + feature_dims: DimsList + Feature dimensions. + iter_kwargs: Dict[str, List[Any]] + Keyword arguments for the transformer that should be iterated over. + + """ + self._sample_dims = sample_dims + self._feature_dims = feature_dims + self._iter_kwargs = iter_kwargs + + for i, x in enumerate(X): + # Add transformer specific keyword arguments + # For iterable kwargs, use the i-th element of the iterable + kwargs = {k: v[i] for k, v in self._iter_kwargs.items()} + proc: T = self.transformer_class(**self.init_kwargs) + proc.fit(x, sample_dims, feature_dims[i], **kwargs) + self.transformers.append(proc) + return self + + def transform(self, X: List[Data]) -> List[Data]: + X_transformed: List[Data] = [] + for x, proc in zip(X, self.transformers): + X_transformed.append(proc.transform(x)) # type: ignore + return X_transformed + + def fit_transform( + self, + X: List[Data], + sample_dims: Dims, + feature_dims: DimsList, + iter_kwargs: Dict[str, List[Any]] = {}, + ) -> List[Data]: + return self.fit(X, sample_dims, feature_dims, iter_kwargs).transform(X) # type: ignore + + def inverse_transform_data(self, X: List[Data]) -> List[Data]: + X_inverse_transformed: List[Data] = [] + for x, proc in zip(X, self.transformers): + x_inv_trans = proc.inverse_transform_data(x) # type: ignore + X_inverse_transformed.append(x_inv_trans) + return X_inverse_transformed + + def inverse_transform_components(self, X: List[Data]) -> List[Data]: + X_inverse_transformed: List[Data] = [] + for x, proc in zip(X, self.transformers): + x_inv_trans = proc.inverse_transform_components(x) # type: ignore + X_inverse_transformed.append(x_inv_trans) + return X_inverse_transformed + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + return self.transformers[0].inverse_transform_scores(X) diff --git a/xeofs/preprocessing/multi_index_converter.py b/xeofs/preprocessing/multi_index_converter.py new file mode 100644 index 0000000..74c497a --- /dev/null +++ b/xeofs/preprocessing/multi_index_converter.py @@ -0,0 +1,109 @@ +from typing import List, Optional +from typing_extensions import Self +import pandas as pd + +from .transformer import Transformer +from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound + + +class MultiIndexConverter(Transformer): + """Convert MultiIndexes of an ND DataArray or Dataset to regular indexes.""" + + def __init__(self): + super().__init__() + self.original_indexes = {} + self.modified_dimensions = [] + + def fit( + self, + X: Data, + sample_dims: Optional[Dims] = None, + feature_dims: Optional[Dims] = None, + **kwargs + ) -> Self: + # Store original MultiIndexes and replace with simple index + for dim in X.dims: + index = X.indexes[dim] + if isinstance(index, pd.MultiIndex): + self.original_indexes[dim] = X.coords[dim] + self.modified_dimensions.append(dim) + + return self + + def transform(self, X: DataVar) -> DataVar: + X_transformed = X.copy(deep=True) + + # Replace MultiIndexes with simple index + for dim in self.modified_dimensions: + size = X_transformed.coords[dim].size + X_transformed = 
X_transformed.drop_vars(dim) + X_transformed.coords[dim] = range(size) + + return X_transformed + + def _inverse_transform(self, X: DataVarBound) -> DataVarBound: + X_inverse_transformed = X.copy(deep=True) + + # Restore original MultiIndexes + for dim, original_index in self.original_indexes.items(): + if dim in X_inverse_transformed.dims: + X_inverse_transformed.coords[dim] = original_index + # Set indexes to original MultiIndexes + indexes = [ + idx + for idx in self.original_indexes[dim].indexes.keys() + if idx != dim + ] + X_inverse_transformed = X_inverse_transformed.set_index({dim: indexes}) + + return X_inverse_transformed + + def inverse_transform_data(self, X: DataVarBound) -> DataVarBound: + return self._inverse_transform(X) + + def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: + return self._inverse_transform(X) + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + return self._inverse_transform(X) + + +# class DataListMultiIndexConverter(BaseEstimator, TransformerMixin): +# """Converts MultiIndexes to simple indexes and vice versa.""" + +# def __init__(self): +# self.converters: List[MultiIndexConverter] = [] + +# def fit(self, X: List[Data], y=None): +# for x in X: +# converter = MultiIndexConverter() +# converter.fit(x) +# self.converters.append(converter) + +# return self + +# def transform(self, X: List[Data]) -> List[Data]: +# X_transformed: List[Data] = [] +# for x, converter in zip(X, self.converters): +# X_transformed.append(converter.transform(x)) + +# return X_transformed + +# def fit_transform(self, X: List[Data], y=None) -> List[Data]: +# return self.fit(X, y).transform(X) + +# def _inverse_transform(self, X: List[Data]) -> List[Data]: +# X_inverse_transformed: List[Data] = [] +# for x, converter in zip(X, self.converters): +# X_inverse_transformed.append(converter._inverse_transform(x)) + +# return X_inverse_transformed + +# def inverse_transform_data(self, X: List[Data]) -> List[Data]: +# return self._inverse_transform(X) + +# def inverse_transform_components(self, X: List[Data]) -> List[Data]: +# return self._inverse_transform(X) + +# def inverse_transform_scores(self, X: DataArray) -> DataArray: +# return self.converters[0].inverse_transform_scores(X) diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index 9a7137d..8da82bc 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -1,9 +1,55 @@ -from typing import Optional, Sequence, Hashable, List +from typing import Optional, Sequence, Hashable, List, Tuple, Any, Type -from .scaler_factory import ScalerFactory -from .stacker_factory import StackerFactory -from ..utils.xarray_utils import get_dims -from ..utils.data_types import AnyDataObject, DataArray +import numpy as np + +from .list_processor import GenericListTransformer +from .dimension_renamer import DimensionRenamer +from .scaler import Scaler +from .stacker import StackerFactory, Stacker +from .multi_index_converter import MultiIndexConverter +from .sanitizer import Sanitizer +from .concatenator import Concatenator +from ..utils.xarray_utils import ( + get_dims, + unwrap_singleton_list, + process_parameter, + _check_parameter_number, + convert_to_list, +) +from ..utils.data_types import ( + DataArray, + Data, + DataVar, + DataVarBound, + DataList, + Dims, + DimsList, +) + + +def extract_new_dim_names(X: List[DimensionRenamer]) -> Tuple[Dims, DimsList]: + """Extract the new dimension names from a list of DimensionRenamer objects. 
+ + Parameters + ---------- + X : list of DimensionRenamer + List of DimensionRenamer objects. + + Returns + ------- + Dims + Sample dimensions + DimsList + Feature dimenions + + """ + new_sample_dims = [] + new_feature_dims: DimsList = [] + for x in X: + new_sample_dims.append(x.sample_dims_after) + new_feature_dims.append(x.feature_dims_after) + new_sample_dims: Dims = tuple(np.unique(np.asarray(new_sample_dims))) + return new_sample_dims, new_feature_dims class Preprocessor: @@ -17,85 +63,136 @@ class Preprocessor: Parameters ---------- + sample_name : str, default="sample" + Name of the sample dimension. + feature_name : str, default="feature" + Name of the feature dimension. + with_center : bool, default=True + If True, the data is centered by subtracting the mean. with_std : bool, default=True If True, the data is divided by the standard deviation. with_coslat : bool, default=False If True, the data is multiplied by the square root of cosine of latitude weights. with_weights : bool, default=False If True, the data is multiplied by additional user-defined weights. + return_list : bool, default=True + If True, the output is returned as a list of DataArrays. If False, the output is returned as a single DataArray if possible. """ - def __init__(self, with_std=True, with_coslat=False, with_weights=False): - # Define model parameters - self._params = { - "with_std": with_std, - "with_coslat": with_coslat, - "with_weights": with_weights, - } + def __init__( + self, + sample_name: str = "sample", + feature_name: str = "feature", + with_center: bool = True, + with_std: bool = False, + with_coslat: bool = False, + return_list: bool = True, + ): + # Set parameters + self.sample_name = sample_name + self.feature_name = feature_name + self.with_center = with_center + self.with_std = with_std + self.with_coslat = with_coslat + self.return_list = return_list def fit( self, - data: AnyDataObject, - dim: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - weights: Optional[AnyDataObject] = None, + X: List[Data] | Data, + sample_dims: Dims, + weights: Optional[List[Data] | Data] = None, ): - """Just for consistency with the other classes.""" - raise NotImplementedError( - "Preprocessor does not implement fit method. Use fit_transform instead." - ) + self._set_return_list(X) + X = convert_to_list(X) + self.n_data = len(X) + sample_dims, feature_dims = get_dims(X, sample_dims) - def fit_transform( - self, - data: AnyDataObject, - dim: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - weights: Optional[AnyDataObject] = None, - ) -> DataArray: - """Preprocess the data. + # Set sample and feature dimensions + self.dims = { + self.sample_name: sample_dims, + self.feature_name: feature_dims, + } - This will scale and stack the data. + # However, for each DataArray a list of feature dimensions must be provided + _check_parameter_number("feature_dims", feature_dims, self.n_data) - Parameters: - ------------- - data: xr.DataArray or list of xarray.DataArray - Input data. - dim: tuple - Tuple specifying the sample dimensions. The remaining dimensions - will be treated as feature dimensions. - weights: xr.DataArray or xr.Dataset or None, default=None - If specified, the input data will be weighted by this array. 
+ # Ensure that weights are provided as a list + weights = process_parameter("weights", weights, None, self.n_data) - """ - # Set sample and feature dimensions - sample_dims, feature_dims = get_dims(data, sample_dims=dim) - self.dims = {"sample": sample_dims, "feature": feature_dims} + # 1 | Center, scale and weigh the data + scaler_kwargs = { + "with_center": self.with_center, + "with_std": self.with_std, + "with_coslat": self.with_coslat, + } + scaler_ikwargs = { + "weights": weights, + } + self.scaler = GenericListTransformer(Scaler, **scaler_kwargs) + X = self.scaler.fit_transform(X, sample_dims, feature_dims, scaler_ikwargs) - # Scale the data - self.scaler = ScalerFactory.create_scaler(data, **self._params) - data = self.scaler.fit_transform(data, sample_dims, feature_dims, weights) + # 2 | Rename dimensions + self.renamer = GenericListTransformer(DimensionRenamer) + X = self.renamer.fit_transform(X, sample_dims, feature_dims) + sample_dims, feature_dims = extract_new_dim_names(self.renamer.transformers) - # Stack the data - self.stacker = StackerFactory.create_stacker(data) - return self.stacker.fit_transform(data, sample_dims, feature_dims) + # 3 | Convert MultiIndexes (before stacking) + self.preconverter = GenericListTransformer(MultiIndexConverter) + X = self.preconverter.fit_transform(X, sample_dims, feature_dims) - def transform(self, data: AnyDataObject) -> DataArray: - """Project new unseen data onto the components (EOFs/eigenvectors). + # 4 | Stack the data to 2D DataArray + stacker_kwargs = { + "sample_name": self.sample_name, + "feature_name": self.feature_name, + } + stack_type: Type[Stacker] = StackerFactory.create(X[0]) + self.stacker = GenericListTransformer(stack_type, **stacker_kwargs) + X = self.stacker.fit_transform(X, sample_dims, feature_dims) + # 5 | Convert MultiIndexes (after stacking) + self.postconverter = GenericListTransformer(MultiIndexConverter) + X = self.postconverter.fit_transform(X, sample_dims, feature_dims) + # 6 | Remove NaNs + sanitizer_kwargs = { + "sample_name": self.sample_name, + "feature_name": self.feature_name, + } + self.sanitizer = GenericListTransformer(Sanitizer, **sanitizer_kwargs) + X = self.sanitizer.fit_transform(X, sample_dims, feature_dims) - Parameters: - ------------- - data: xr.DataArray or list of xarray.DataArray - Input data. + # 7 | Concatenate into one 2D DataArray + self.concatenator = Concatenator(self.sample_name, self.feature_name) + self.concatenator.fit(X) # type: ignore - Returns: - ---------- - projections: DataArray | Dataset | List[DataArray] - Projections of the new data onto the components. 
+ return self - """ - data = self.scaler.transform(data) - return self.stacker.transform(data) + def transform(self, X: List[Data] | Data) -> DataArray: + X = convert_to_list(X) + + if len(X) != self.n_data: + raise ValueError( + f"number of data objects passed should match number of data objects used for fitting" + f"len(data objects)={len(X)} and " + f"len(data objects used for fitting)={self.n_data}" + ) + + X = self.scaler.transform(X) + X = self.renamer.transform(X) + X = self.preconverter.transform(X) + X = self.stacker.transform(X) + X = self.postconverter.transform(X) + X = self.sanitizer.transform(X) + return self.concatenator.transform(X) # type: ignore + + def fit_transform( + self, + X: List[Data] | Data, + sample_dims: Dims, + weights: Optional[List[Data] | Data] = None, + ) -> DataArray: + return self.fit(X, sample_dims, weights).transform(X) - def inverse_transform_data(self, data: DataArray) -> AnyDataObject: + def inverse_transform_data(self, X: DataArray) -> List[Data] | Data: """Inverse transform the data. Parameters: @@ -109,10 +206,16 @@ def inverse_transform_data(self, data: DataArray) -> AnyDataObject: The inverse transformed data. """ - data = self.stacker.inverse_transform_data(data) - return self.scaler.inverse_transform(data) + X_list = self.concatenator.inverse_transform_data(X) + X_list = self.sanitizer.inverse_transform_data(X_list) # type: ignore + X_list = self.postconverter.inverse_transform_data(X_list) + X_list_ND = self.stacker.inverse_transform_data(X_list) + X_list_ND = self.preconverter.inverse_transform_data(X_list_ND) + X_list_ND = self.renamer.inverse_transform_data(X_list_ND) + X_list_ND = self.scaler.inverse_transform_data(X_list_ND) + return self._process_output(X_list_ND) - def inverse_transform_components(self, data: DataArray) -> AnyDataObject: + def inverse_transform_components(self, X: DataArray) -> List[Data] | Data: """Inverse transform the components. Parameters: @@ -126,9 +229,16 @@ def inverse_transform_components(self, data: DataArray) -> AnyDataObject: The inverse transformed components. """ - return self.stacker.inverse_transform_components(data) + X_list = self.concatenator.inverse_transform_components(X) + X_list = self.sanitizer.inverse_transform_components(X_list) # type: ignore + X_list = self.postconverter.inverse_transform_components(X_list) + X_list_ND = self.stacker.inverse_transform_components(X_list) + X_list_ND = self.preconverter.inverse_transform_components(X_list_ND) + X_list_ND = self.renamer.inverse_transform_components(X_list_ND) + X_list_ND = self.scaler.inverse_transform_components(X_list_ND) + return self._process_output(X_list_ND) - def inverse_transform_scores(self, data: DataArray) -> AnyDataObject: + def inverse_transform_scores(self, X: DataArray) -> DataArray: """Inverse transform the scores. Parameters: @@ -142,4 +252,23 @@ def inverse_transform_scores(self, data: DataArray) -> AnyDataObject: The inverse transformed scores. 
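The reworked `Preprocessor` above chains its steps in a fixed order on the way in (scale, rename dimensions, convert MultiIndexes, stack, convert again, sanitize, concatenate) and unwinds them in reverse on the way out. A toy two-step version of the same idea, with made-up data and only centering plus stacking, just to show that the inverse must run in reverse order:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.rand(10, 3, 4), dims=("time", "lat", "lon"))

    # forward: center, then stack to 2D
    mean = da.mean("time")
    stacked = (da - mean).stack(sample=("time",), feature=("lat", "lon"))

    # inverse: unstack first, then undo the centering
    recovered = stacked.unstack() + mean
    np.testing.assert_allclose(recovered.transpose(*da.dims).values, da.values)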
""" - return self.stacker.inverse_transform_scores(data) + X_list = self.concatenator.inverse_transform_scores(X) + X_list = self.sanitizer.inverse_transform_scores(X_list) + X_list = self.postconverter.inverse_transform_scores(X_list) + X_list_ND = self.stacker.inverse_transform_scores(X_list) + X_list_ND = self.preconverter.inverse_transform_scores(X_list_ND) + X_list_ND = self.renamer.inverse_transform_scores(X_list_ND) + X_list_ND = self.scaler.inverse_transform_scores(X_list_ND) + return X_list_ND + + def _process_output(self, X: List[Data]) -> List[Data] | Data: + if self.return_list: + return X + else: + return unwrap_singleton_list(X) + + def _set_return_list(self, X): + if isinstance(X, (list, tuple)): + self.return_list = True + else: + self.return_list = False diff --git a/xeofs/preprocessing/sanitizer.py b/xeofs/preprocessing/sanitizer.py new file mode 100644 index 0000000..f9e05a3 --- /dev/null +++ b/xeofs/preprocessing/sanitizer.py @@ -0,0 +1,112 @@ +from typing import Optional +from typing_extensions import Self +import xarray as xr + +from .transformer import Transformer +from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar + + +class Sanitizer(Transformer): + """ + Removes NaNs from the feature dimension of a 2D DataArray. + + """ + + def __init__(self, sample_name="sample", feature_name="feature"): + super().__init__(sample_name=sample_name, feature_name=feature_name) + + def _check_input_type(self, X) -> None: + if not isinstance(X, xr.DataArray): + raise ValueError("Input must be an xarray DataArray") + + def _check_input_dims(self, X) -> None: + if set(X.dims) != set([self.sample_name, self.feature_name]): + raise ValueError( + "Input must have dimensions ({:}, {:})".format( + self.sample_name, self.feature_name + ) + ) + + def _check_input_coords(self, X) -> None: + if not X.coords[self.feature_name].identical(self.feature_coords): + raise ValueError( + "Cannot transform data. Feature coordinates are different." + ) + + def fit( + self, + X: Data, + sample_dims: Optional[Dims] = None, + feature_dims: Optional[Dims] = None, + **kwargs + ) -> Self: + # Check if input is a DataArray + self._check_input_type(X) + + # Check if input has the correct dimensions + self._check_input_dims(X) + + self.feature_coords = X.coords[self.feature_name] + self.sample_coords = X.coords[self.sample_name] + + # Identify NaN locations + self.is_valid_feature = ~X.isnull().all(self.sample_name).compute() + # NOTE: We must also consider the presence of valid samples. For + # instance, when PCA is applied with "longitude" and "latitude" as + # sample dimensions, certain grid cells may be masked (e.g., due to + # ocean areas). To ensure correct reconstruction of scores, + # we need to identify the sample positions of NaNs in the fitted + # dataset. Keep in mind that when transforming new data, + # we have to recheck for valid samples, as the new dataset may have + # different samples. 
+ X_valid = X.sel({self.feature_name: self.is_valid_feature}) + self.is_valid_sample = ~X_valid.isnull().all(self.feature_name).compute() + + return self + + def transform(self, X: DataArray) -> DataArray: + # Check if input is a DataArray + self._check_input_type(X) + + # Check if input has the correct dimensions + self._check_input_dims(X) + + # Check if input has the correct coordinates + self._check_input_coords(X) + + # Remove NaN entries; only consider full-dimensional NaNs + # We already know valid features from the fitted dataset + X = X.isel({self.feature_name: self.is_valid_feature}) + # However, we need to recheck for valid samples, as the new dataset may + # have different samples + is_valid_sample = ~X.isnull().all(self.feature_name).compute() + X = X.isel({self.sample_name: is_valid_sample}) + + return X + + def inverse_transform_data(self, X: DataArray) -> DataArray: + # Reindex only if feature coordinates are different + coords_are_equal = X.coords[self.feature_name].identical(self.feature_coords) + + if coords_are_equal: + return X + else: + return X.reindex({self.feature_name: self.feature_coords.values}) + + def inverse_transform_components(self, X: DataArray) -> DataArray: + # Reindex only if feature coordinates are different + coords_are_equal = X.coords[self.feature_name].identical(self.feature_coords) + + if coords_are_equal: + return X + else: + return X.reindex({self.feature_name: self.feature_coords.values}) + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + # Reindex only if sample coordinates are different + coords_are_equal = X.coords[self.sample_name].identical(self.sample_coords) + + if coords_are_equal: + return X + else: + return X.reindex({self.sample_name: self.sample_coords.values}) diff --git a/xeofs/preprocessing/scaler.py b/xeofs/preprocessing/scaler.py index 4eec3cc..c7eb920 100644 --- a/xeofs/preprocessing/scaler.py +++ b/xeofs/preprocessing/scaler.py @@ -1,28 +1,15 @@ -from typing import List, Union, Tuple, Dict, Optional, TypeVar, Any, Sequence, Hashable +from typing import Optional +from typing_extensions import Self import numpy as np import xarray as xr +from .transformer import Transformer +from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound +from ..utils.xarray_utils import compute_sqrt_cos_lat_weights, feature_ones_like -from ._base_scaler import _BaseScaler -from ..utils.constants import VALID_LATITUDE_NAMES -from ..utils.sanity_checks import ( - assert_single_dataarray, - assert_single_dataset, - assert_list_dataarrays, - ensure_tuple, -) -from ..utils.data_types import ( - DataArray, - Dataset, - DataArrayList, - ModelDims, - SingleDataObject, -) -from ..utils.xarray_utils import compute_sqrt_cos_lat_weights - - -class _SingleDataScaler(_BaseScaler): + +class Scaler(Transformer): """Scale the data along sample dimensions. Scaling includes (i) removing the mean and, optionally, (ii) dividing by the standard deviation, @@ -31,420 +18,315 @@ class _SingleDataScaler(_BaseScaler): Parameters ---------- + with_center : bool, default=True + If True, the data is centered by subtracting the mean. with_std : bool, default=True If True, the data is divided by the standard deviation. with_coslat : bool, default=False If True, the data is multiplied by the square root of cosine of latitude weights. - with_weights : bool, default=False - If True, the data is multiplied by additional user-defined weights. - + weights : DataArray | Dataset, optional + Weights to be applied to the data. 
Must have the same dimensions as the data. + If None, no weights are applied. """ - def _verify_input(self, data: SingleDataObject, name: str): - raise NotImplementedError - - def _compute_sqrt_cos_lat_weights( - self, data: SingleDataObject, dim - ) -> SingleDataObject: - """Compute the square root of cosine of latitude weights. - - Parameters - ---------- - data : SingleDataObject - Data to be scaled. - dim : sequence of hashable - Dimensions along which the data is considered to be a feature. - - Returns - ------- - SingleDataObject - Square root of cosine of latitude weights. + def __init__( + self, + with_center: bool = True, + with_std: bool = False, + with_coslat: bool = False, + ): + super().__init__() + self.with_center = with_center + self.with_std = with_std + self.with_coslat = with_coslat - """ - self._verify_input(data, "data") + def _verify_input(self, X, name: str): + if not isinstance(X, (xr.DataArray, xr.Dataset)): + raise TypeError(f"{name} must be an xarray DataArray or Dataset") - weights = compute_sqrt_cos_lat_weights(data, dim) - weights.name = "coslat_weights" + def _process_weights(self, X: DataVarBound, weights) -> DataVarBound: + if weights is None: + wghts: DataVarBound = feature_ones_like(X, self.feature_dims) + else: + wghts: DataVarBound = weights - return weights + return wghts def fit( self, - data: SingleDataObject, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[SingleDataObject] = None, - ): + X: DataVar, + sample_dims: Dims, + feature_dims: Dims, + weights: Optional[DataVar] = None, + ) -> Self: """Fit the scaler to the data. Parameters ---------- - data : SingleDataObject + X : DataArray | Dataset Data to be scaled. sample_dims : sequence of hashable Dimensions along which the data is considered to be a sample. feature_dims : sequence of hashable Dimensions along which the data is considered to be a feature. - weights : SingleDataObject, optional + weights : DataArray | Dataset, optional Weights to be applied to the data. Must have the same dimensions as the data. If None, no weights are applied. 
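Regarding the `with_coslat` option described above: the area weighting referred to throughout this diff multiplies the data by the square root of the cosine of latitude, so that (co)variances end up weighted by cos(latitude). A sketch of such weights (the `lat` dimension name is just an example; the diff itself imports a `compute_sqrt_cos_lat_weights` helper for this, whose exact internals are not shown here):

    import numpy as np
    import xarray as xr

    lat = xr.DataArray(
        np.array([-60.0, -30.0, 0.0, 30.0, 60.0]),
        dims="lat",
        coords={"lat": [-60.0, -30.0, 0.0, 30.0, 60.0]},
    )

    # sqrt(cos(latitude)) weights, clipped to [0, 1] as a guard against
    # small negative values from rounding near the poles
    coslat_weights = np.sqrt(np.cos(np.deg2rad(lat))).clip(0, 1)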
""" # Check input types - self._verify_input(data, "data") - if weights is not None: - self._verify_input(weights, "weights") - - sample_dims = ensure_tuple(sample_dims) - feature_dims = ensure_tuple(feature_dims) + self._verify_input(X, "data") + self.sample_dims = sample_dims + self.feature_dims = feature_dims # Store sample and feature dimensions for later use - self.dims: ModelDims = {"sample": sample_dims, "feature": feature_dims} + self.dims = {"sample": sample_dims, "feature": feature_dims} + + params = self.get_params() # Scaling parameters are computed along sample dimensions - self.mean: SingleDataObject = data.mean(sample_dims).compute() + if params["with_center"]: + self.mean_: DataVar = X.mean(self.sample_dims).compute() - if self._params["with_std"]: - self.std: SingleDataObject = data.std(sample_dims).compute() + if params["with_std"]: + self.std_: DataVar = X.std(self.sample_dims).compute() - if self._params["with_coslat"]: - self.coslat_weights: SingleDataObject = self._compute_sqrt_cos_lat_weights( - data, feature_dims + if params["with_coslat"]: + self.coslat_weights_: DataVar = compute_sqrt_cos_lat_weights( + data=X, feature_dims=self.feature_dims ).compute() - if self._params["with_weights"]: - if weights is None: - raise ValueError("Weights must be provided when with_weights is True") - self.weights: SingleDataObject = weights.compute() + # Convert None weights to ones + self.weights_: DataVar = self._process_weights(X, weights).compute() - def transform(self, data: SingleDataObject) -> SingleDataObject: + return self + + def transform(self, X: DataVarBound) -> DataVarBound: """Scale the data. Parameters ---------- - data : SingleDataObject + data : DataArray | Dataset Data to be scaled. Returns ------- - SingleDataObject + DataArray | Dataset Scaled data. """ - self._verify_input(data, "data") + self._verify_input(X, "X") + + params = self.get_params() - data = data - self.mean + if params["with_center"]: + X = X - self.mean_ + if params["with_std"]: + X = X / self.std_ + if params["with_coslat"]: + X = X * self.coslat_weights_ - if self._params["with_std"]: - data = data / self.std - if self._params["with_coslat"]: - data = data * self.coslat_weights - if self._params["with_weights"]: - data = data * self.weights - return data + X = X * self.weights_ + return X def fit_transform( self, - data: SingleDataObject, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[SingleDataObject] = None, - ) -> SingleDataObject: - """Fit the scaler to the data and scale it. - - Parameters - ---------- - data : SingleDataObject - Data to be scaled. - sample_dims : sequence of hashable - Dimensions along which the data is considered to be a sample. - feature_dims : sequence of hashable - Dimensions along which the data is considered to be a feature. - weights : SingleDataObject, optional - Weights to be applied to the data. Must have the same dimensions as the data. - If None, no weights are applied. - - Returns - ------- - SingleDataObject - Scaled data. - - """ - self.fit(data, sample_dims, feature_dims, weights) - return self.transform(data) - - def inverse_transform(self, data: SingleDataObject) -> SingleDataObject: + X: DataVarBound, + sample_dims: Dims, + feature_dims: Dims, + weights: Optional[DataVarBound] = None, + ) -> DataVarBound: + return self.fit(X, sample_dims, feature_dims, weights).transform(X) + + def inverse_transform_data(self, X: DataVarBound) -> DataVarBound: """Unscale the data. 
Parameters ---------- - data : SingleDataObject + X : DataArray | DataSet Data to be unscaled. Returns ------- - SingleDataObject + DataArray | DataSet Unscaled data. """ - self._verify_input(data, "data") - - if self._params["with_weights"]: - data = data / self.weights - if self._params["with_coslat"]: - data = data / self.coslat_weights - if self._params["with_std"]: - data = data * self.std - - data = data + self.mean - - return data - - -class SingleDataArrayScaler(_SingleDataScaler): - def _verify_input(self, data: DataArray, name: str): - """Verify that the input data is a DataArray. - - Parameters - ---------- - data : xarray.Dataset - Data to be checked. - - """ - assert_single_dataarray(data, name) - - def _compute_sqrt_cos_lat_weights(self, data: DataArray, dim) -> DataArray: - return super()._compute_sqrt_cos_lat_weights(data, dim) - - def fit( - self, - data: DataArray, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[DataArray] = None, - ): - super().fit(data, sample_dims, feature_dims, weights) - - def transform(self, data: DataArray) -> DataArray: - return super().transform(data) - - def fit_transform( - self, - data: DataArray, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[DataArray] = None, - ) -> DataArray: - return super().fit_transform(data, sample_dims, feature_dims, weights) - - def inverse_transform(self, data: DataArray) -> DataArray: - return super().inverse_transform(data) - - -class SingleDatasetScaler(_SingleDataScaler): - def _verify_input(self, data: Dataset, name: str): - """Verify that the input data is a Dataset. - - Parameters - ---------- - data : xarray.Dataset - Data to be checked. - - """ - assert_single_dataset(data, name) - - def _compute_sqrt_cos_lat_weights(self, data: Dataset, dim) -> Dataset: - return super()._compute_sqrt_cos_lat_weights(data, dim) - - def fit( - self, - data: Dataset, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[Dataset] = None, - ): - super().fit(data, sample_dims, feature_dims, weights) - - def transform(self, data: Dataset) -> Dataset: - return super().transform(data) - - def fit_transform( - self, - data: Dataset, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - weights: Optional[Dataset] = None, - ) -> Dataset: - return super().fit_transform(data, sample_dims, feature_dims, weights) - - def inverse_transform(self, data: Dataset) -> Dataset: - return super().inverse_transform(data) - - -class ListDataArrayScaler(_BaseScaler): - """Scale a list of xr.DataArray along sample dimensions. - - Scaling includes (i) removing the mean and, optionally, (ii) dividing by the standard deviation, - (iii) multiplying by the square root of cosine of latitude weights (area weighting; coslat weighting), - and (iv) multiplying by additional user-defined weights. - - Parameters - ---------- - with_std : bool, default=True - If True, the data is divided by the standard deviation. - with_coslat : bool, default=False - If True, the data is multiplied by the square root of cosine of latitude weights. - with_weights : bool, default=False - If True, the data is multiplied by additional user-defined weights. 
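
The hunks above replace the old per-type scalers with a single `Scaler` that exposes `with_center`, `with_std` and `with_coslat` switches and accepts optional weights directly in `fit`. A minimal usage sketch of that interface, assuming the class is importable from `xeofs.preprocessing.scaler` (the module patched here) and that `with_coslat` can locate a latitude coordinate named `lat` in degrees; the data array, its shape and its dimension names are dummy values for illustration only:

import numpy as np
import xarray as xr

from xeofs.preprocessing.scaler import Scaler  # module patched in this hunk

rng = np.random.default_rng(0)
X = xr.DataArray(
    rng.standard_normal((10, 4, 5)),
    dims=("time", "lat", "lon"),
    coords={"lat": np.linspace(-60.0, 60.0, 4), "lon": np.arange(5.0)},
    name="t2m",
)

# Center, standardize and apply sqrt(cos(lat)) area weighting; statistics are taken along the sample dims
scaler = Scaler(with_center=True, with_std=True, with_coslat=True)
X_scaled = scaler.fit_transform(X, sample_dims=("time",), feature_dims=("lat", "lon"))

# inverse_transform_data undoes the weighting, the scaling and the centering
X_back = scaler.inverse_transform_data(X_scaled)
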
- - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.scalers = [] - - def _verify_input(self, data: DataArrayList, name: str): - """Verify that the input data is a list of DataArrays. - - Parameters - ---------- - data : list of xarray.DataArray - Data to be checked. - - """ - assert_list_dataarrays(data, name) - - def fit( - self, - data: DataArrayList, - sample_dims: Hashable | Sequence[Hashable], - feature_dims_list: List[Hashable | Sequence[Hashable]], - weights=None, - ): - """Fit the scaler to the data. - - Parameters - ---------- - data : list of xarray.DataArray - Data to be scaled. - sample_dims : hashable or sequence of hashable - Dimensions along which the data is considered to be a sample. - feature_dims_list : list of hashable or list of sequence of hashable - List of dimensions along which the data is considered to be a feature. - weights : list of xarray.DataArray, optional - List of weights to be applied to the data. Must have the same dimensions as the data. - - """ - self._verify_input(data, "data") - - # Check input - if not isinstance(feature_dims_list, list): - err_message = "feature dims must be a list of the feature dimensions of each DataArray, " - err_message += 'e.g. [("lon", "lat"), ("lon")]' - raise TypeError(err_message) - - sample_dims = ensure_tuple(sample_dims) - feature_dims = [ensure_tuple(fdims) for fdims in feature_dims_list] - - # Sample dimensions are the same for all data arrays - # Feature dimensions may be different for each data array - self.dims: ModelDims = {"sample": sample_dims, "feature": feature_dims} - - # However, for each DataArray a list of feature dimensions must be provided - if len(data) != len(feature_dims): - err_message = ( - "Number of data arrays and feature dimensions must be the same. " - ) - err_message += f"Got {len(data)} data arrays and {len(feature_dims)} feature dimensions" - raise ValueError(err_message) - - self.weights = weights - # If no weights are provided, create a list of None - if self.weights is None: - self.weights = [None] * len(data) - # Check that number of weights is the same as number of data arrays - if self._params["with_weights"]: - if len(data) != len(self.weights): - err_message = "Number of data arrays and weights must be the same. " - err_message += ( - f"Got {len(data)} data arrays and {len(self.weights)} weights" - ) - raise ValueError(err_message) - - for da, wghts, fdims in zip(data, self.weights, feature_dims): - # Create SingleDataArrayScaler object for each data array - params = self.get_params() - scaler = SingleDataArrayScaler(**params) - scaler.fit(da, sample_dims=sample_dims, feature_dims=fdims, weights=wghts) - self.scalers.append(scaler) - - def transform(self, da_list: DataArrayList) -> DataArrayList: - """Scale the data. - - Parameters - ---------- - da_list : list of xarray.DataArray - Data to be scaled. - - Returns - ------- - list of xarray.DataArray - Scaled data. - - """ - self._verify_input(da_list, "da_list") - - da_list_transformed = [] - for scaler, da in zip(self.scalers, da_list): - da_list_transformed.append(scaler.transform(da)) - return da_list_transformed - - def fit_transform( - self, - data: DataArrayList, - sample_dims: Hashable | Sequence[Hashable], - feature_dims_list: List[Hashable | Sequence[Hashable]], - weights=None, - ) -> DataArrayList: - """Fit the scaler to the data and scale it. - - Parameters - ---------- - data : list of xr.DataArray - Data to be scaled. 
- sample_dims : hashable or sequence of hashable - Dimensions along which the data is considered to be a sample. - feature_dims_list : list of hashable or list of sequence of hashable - List of dimensions along which the data is considered to be a feature. - weights : list of xr.DataArray, optional - List of weights to be applied to the data. Must have the same dimensions as the data. - - Returns - ------- - list of xarray.DataArray - Scaled data. - - """ - self.fit(data, sample_dims, feature_dims_list, weights) - return self.transform(data) - - def inverse_transform(self, da_list: DataArrayList) -> DataArrayList: - """Unscale the data. - - Parameters - ---------- - da_list : list of xarray.DataArray - Data to be scaled. - - Returns - ------- - list of xarray.DataArray - Scaled data. - - """ - self._verify_input(da_list, "da_list") - - da_list_transformed = [] - for scaler, da in zip(self.scalers, da_list): - da_list_transformed.append(scaler.inverse_transform(da)) - return da_list_transformed + self._verify_input(X, "X") + + params = self.get_params() + X = X / self.weights_ + if params["with_coslat"]: + X = X / self.coslat_weights_ + if params["with_std"]: + X = X * self.std_ + if params["with_center"]: + X = X + self.mean_ + + return X + + def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: + return X + + def inverse_transform_scores(self, X: DataArray) -> DataArray: + return X + + +# class DataListScaler(Scaler): +# """Scale a list of xr.DataArray along sample dimensions. + +# Scaling includes (i) removing the mean and, optionally, (ii) dividing by the standard deviation, +# (iii) multiplying by the square root of cosine of latitude weights (area weighting; coslat weighting), +# and (iv) multiplying by additional user-defined weights. + +# Parameters +# ---------- +# with_std : bool, default=True +# If True, the data is divided by the standard deviation. +# with_coslat : bool, default=False +# If True, the data is multiplied by the square root of cosine of latitude weights. +# with_weights : bool, default=False +# If True, the data is multiplied by additional user-defined weights. + +# """ + +# def __init__(self, with_std=False, with_coslat=False): +# super().__init__(with_std=with_std, with_coslat=with_coslat) +# self.scalers = [] + +# def _verify_input(self, data, name: str): +# """Verify that the input data is a list of DataArrays. + +# Parameters +# ---------- +# data : list of xarray.DataArray +# Data to be checked. + +# """ +# if not isinstance(data, list): +# raise TypeError(f"{name} must be a list of xarray DataArrays or Datasets") +# if not all(isinstance(da, (xr.DataArray, xr.Dataset)) for da in data): +# raise TypeError(f"{name} must be a list of xarray DataArrays or Datasets") + +# def fit( +# self, +# data: List[Data], +# sample_dims: Dims, +# feature_dims_list: DimsList, +# weights: Optional[List[Data] | Data] = None, +# ) -> Self: +# """Fit the scaler to the data. + +# Parameters +# ---------- +# data : list of xarray.DataArray +# Data to be scaled. +# sample_dims : hashable or sequence of hashable +# Dimensions along which the data is considered to be a sample. +# feature_dims_list : list of hashable or list of sequence of hashable +# List of dimensions along which the data is considered to be a feature. +# weights : list of xarray.DataArray, optional +# List of weights to be applied to the data. Must have the same dimensions as the data. 
+ +# """ +# self._verify_input(data, "data") + +# # Check input +# if not isinstance(feature_dims_list, list): +# err_message = "feature dims must be a list of the feature dimensions of each DataArray, " +# err_message += 'e.g. [("lon", "lat"), ("lon")]' +# raise TypeError(err_message) + +# # Sample dimensions are the same for all data arrays +# # Feature dimensions may be different for each data array +# self.dims = {"sample": sample_dims, "feature": feature_dims_list} + +# # However, for each DataArray a list of feature dimensions must be provided +# _check_parameter_number("feature_dims", feature_dims_list, len(data)) + +# # If no weights are provided, create a list of None +# self.weights = process_parameter("weights", weights, None, len(data)) + +# params = self.get_params() + +# for da, wghts, fdims in zip(data, self.weights, feature_dims_list): +# # Create Scaler object for each data array +# scaler = Scaler(**params) +# scaler.fit(da, sample_dims=sample_dims, feature_dims=fdims, weights=wghts) +# self.scalers.append(scaler) + +# return self + +# def transform(self, da_list: List[Data]) -> List[Data]: +# """Scale the data. + +# Parameters +# ---------- +# da_list : list of xarray.DataArray +# Data to be scaled. + +# Returns +# ------- +# list of xarray.DataArray +# Scaled data. + +# """ +# self._verify_input(da_list, "da_list") + +# da_list_transformed = [] +# for scaler, da in zip(self.scalers, da_list): +# da_list_transformed.append(scaler.transform(da)) +# return da_list_transformed + +# def fit_transform( +# self, +# data: List[Data], +# sample_dims: Dims, +# feature_dims_list: DimsList, +# weights: Optional[List[Data] | Data] = None, +# ) -> List[Data]: +# """Fit the scaler to the data and scale it. + +# Parameters +# ---------- +# data : list of xr.DataArray +# Data to be scaled. +# sample_dims : hashable or sequence of hashable +# Dimensions along which the data is considered to be a sample. +# feature_dims_list : list of hashable or list of sequence of hashable +# List of dimensions along which the data is considered to be a feature. +# weights : list of xr.DataArray, optional +# List of weights to be applied to the data. Must have the same dimensions as the data. + +# Returns +# ------- +# list of xarray.DataArray +# Scaled data. + +# """ +# self.fit(data, sample_dims, feature_dims_list, weights) +# return self.transform(data) + +# def inverse_transform_data(self, da_list: List[Data]) -> List[Data]: +# """Unscale the data. + +# Parameters +# ---------- +# da_list : list of xarray.DataArray +# Data to be scaled. + +# Returns +# ------- +# list of xarray.DataArray +# Scaled data. 
+ +# """ +# self._verify_input(da_list, "da_list") + +# da_list_transformed = [] +# for scaler, da in zip(self.scalers, da_list): +# da_list_transformed.append(scaler.inverse_transform_data(da)) +# return da_list_transformed + +# def inverse_transform_components(self, da_list: List[Data]) -> List[Data]: +# return da_list diff --git a/xeofs/preprocessing/scaler_factory.py b/xeofs/preprocessing/scaler_factory.py deleted file mode 100644 index 15564b2..0000000 --- a/xeofs/preprocessing/scaler_factory.py +++ /dev/null @@ -1,20 +0,0 @@ -import xarray as xr - -from ._base_scaler import _BaseScaler -from .scaler import SingleDataArrayScaler, SingleDatasetScaler, ListDataArrayScaler -from ..utils.data_types import AnyDataObject - - -class ScalerFactory: - @staticmethod - def create_scaler(data: AnyDataObject, **kwargs) -> _BaseScaler: - if isinstance(data, xr.DataArray): - return SingleDataArrayScaler(**kwargs) - elif isinstance(data, xr.Dataset): - return SingleDatasetScaler(**kwargs) - elif isinstance(data, list) and all( - isinstance(da, xr.DataArray) for da in data - ): - return ListDataArrayScaler(**kwargs) - else: - raise ValueError("Invalid data type") diff --git a/xeofs/preprocessing/stacker.py b/xeofs/preprocessing/stacker.py index d51c4ad..3f599b5 100644 --- a/xeofs/preprocessing/stacker.py +++ b/xeofs/preprocessing/stacker.py @@ -1,182 +1,179 @@ -from typing import List, Sequence, Hashable, Tuple +from abc import abstractmethod +from typing import List, Optional, Type +from typing_extensions import Self import numpy as np import pandas as pd import xarray as xr -from xeofs.utils.data_types import DataArray +from .transformer import Transformer +from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound +from ..utils.sanity_checks import convert_to_dim_type -from ._base_stacker import _BaseStacker -from ..utils.data_types import ( - DataArray, - DataArrayList, - Dataset, - SingleDataObject, - AnyDataObject, -) -from ..utils.sanity_checks import ensure_tuple +class Stacker(Transformer): + """Converts a DataArray of any dimensionality into a 2D structure. -class SingleDataStacker(_BaseStacker): - def __init__(self): - super().__init__() + Attributes + ---------- + sample_dims : Sequence[Hashable] + The dimensions of the data that will be stacked along the `sample` dimension. + feature_dims : Sequence[Hashable] + The dimensions of the data that will be stacked along the `feature` dimension. + sample_name : str + The name of the sample dimension. + feature_name : str + The name of the feature dimension. + dims_in : Tuple[str] + The dimensions of the input data. + dims_out : Tuple[str] + The dimensions of the output data. + dims_mapping : Dict[str, Tuple[str]] + The mapping between the input and output dimensions. + coords_in : Dict[str, xr.Coordinates] + The coordinates of the input data. + coords_out : Dict[str, xr.Coordinates] + The coordinates of the output data. 
+ """ + + def __init__( + self, + sample_name: str = "sample", + feature_name: str = "feature", + ): + super().__init__(sample_name, feature_name) - def _validate_matching_dimensions(self, data: SingleDataObject): + self.dims_in = tuple() + self.dims_out = tuple((sample_name, feature_name)) + self.dims_mapping = {} + self.dims_mapping.update({d: tuple() for d in self.dims_out}) + + self.coords_in = {} + self.coords_out = {} + + def _validate_matching_dimensions(self, X: Data): """Verify that the dimensions of the data are consistent with the dimensions used to fit the stacker.""" # Test whether sample and feature dimensions are present in data array - expected_dims = set(self.dims_out_["sample"] + self.dims_out_["feature"]) - given_dims = set(data.dims) + expected_sample_dims = set(self.dims_mapping[self.sample_name]) + expected_feature_dims = set(self.dims_mapping[self.feature_name]) + expected_dims = expected_sample_dims | expected_feature_dims + given_dims = set(X.dims) if not (expected_dims == given_dims): raise ValueError( f"One or more dimensions in {expected_dims} are not present in data." ) - def _validate_matching_feature_coords(self, data: SingleDataObject): + def _validate_matching_feature_coords(self, X: Data): """Verify that the feature coordinates of the data are consistent with the feature coordinates used to fit the stacker.""" + feature_dims = self.dims_mapping[self.feature_name] coords_are_equal = [ - data.coords[dim].equals(self.coords_in_[dim]) - for dim in self.dims_out_["feature"] + X.coords[dim].equals(self.coords_in[dim]) for dim in feature_dims ] if not all(coords_are_equal): raise ValueError( "Data to be transformed has different coordinates than the data used to fit." ) - def _reorder_dims(self, data): - """Reorder dimensions to original order; catch ('mode') dimensions via ellipsis""" - order_input_dims = [ - valid_dim for valid_dim in self.dims_in_ if valid_dim in data.dims - ] - return data.transpose(..., *order_input_dims) - - def _stack(self, data: SingleDataObject, sample_dims, feature_dims) -> DataArray: - """Reshape a SingleDataObject to 2D DataArray.""" - raise NotImplementedError - - def _unstack(self, data: SingleDataObject) -> SingleDataObject: - """Unstack `sample` and `feature` dimension of an DataArray to its original dimensions. + def _validate_dimension_names(self, sample_dims, feature_dims): + if len(sample_dims) > 1: + if self.sample_name in sample_dims: + raise ValueError( + f"Name of sample dimension ({self.sample_name}) is already present in data. Please use another name." + ) + if len(feature_dims) > 1: + if self.feature_name in feature_dims: + raise ValueError( + f"Name of feature dimension ({self.feature_name}) is already present in data. Please use another name." + ) + + def _validate_indices(self, X: Data): + """Check that the indices of the data are no MultiIndex""" + if any([isinstance(index, pd.MultiIndex) for index in X.indexes.values()]): + raise ValueError(f"Cannot stack data containing a MultiIndex.") + + def _sanity_check(self, X: Data, sample_dims, feature_dims): + self._validate_dimension_names(sample_dims, feature_dims) + self._validate_indices(X) + + @abstractmethod + def _stack(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> DataArray: + """Stack data to 2D. Parameters ---------- data : DataArray - The data to be unstacked. + The data to be reshaped. + sample_dims : Hashable or Sequence[Hashable] + The dimensions of the data that will be stacked along the `sample` dimension. 
+ feature_dims : Hashable or Sequence[Hashable] + The dimensions of the data that will be stacked along the `feature` dimension. Returns ------- - data_unstacked : DataArray - The unstacked data. + data_stacked : DataArray + The reshaped 2d-data. """ - raise NotImplementedError() - def _reindex_dim( - self, data: SingleDataObject, stacked_dim: str - ) -> SingleDataObject: - """Reindex data to original coordinates in case that some features at the boundaries were dropped + @abstractmethod + def _unstack(self, X: DataArray) -> Data: + """Unstack 2D DataArray to its original dimensions. Parameters ---------- data : DataArray - The data to be reindex. - stacked_dim : str ['sample', 'feature'] - The dimension to be reindexed. + The data to be unstacked. Returns ------- - DataArray - The reindexed data. - + data_unstacked : DataArray + The unstacked data. """ - # check if coordinates in self.coords have different length from data.coords - # if so, reindex data.coords to self.coords - # input_dim : dimensions of input data - # stacked_dim : dimensions of model data i.e. sample or feature - dims_in = self.dims_out_[stacked_dim] - for dim in dims_in: - if self.coords_in_[dim].size != data.coords[dim].size: - data = data.reindex({dim: self.coords_in_[dim]}, copy=False) - return data + def _reorder_dims(self, X: DataVarBound) -> DataVarBound: + """Reorder dimensions to original order; catch ('mode') dimensions via ellipsis""" + order_input_dims = [ + valid_dim for valid_dim in self.dims_in if valid_dim in X.dims + ] + if order_input_dims != X.dims: + X = X.transpose(..., *order_input_dims) + return X - def fit_transform( - self, - data: SingleDataObject, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - ) -> DataArray: - """Fit the stacker and transform data to 2D. + def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> Self: + """Fit the stacker. Parameters ---------- data : DataArray The data to be reshaped. - sample_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `sample` dimension. - feature_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `feature` dimension. Returns ------- - DataArray - The reshaped data. - - Raises - ------ - ValueError - If any of the dimensions in `sample_dims` or `feature_dims` are not present in the data. - ValueError - If data to be transformed has individual NaNs. - ValueError - If data is empty + self : DataArrayStacker + The fitted stacker. 
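
The base `Stacker.fit` shown here only records the dimension mapping and runs the sanity checks. A short sketch of the name-clash guard in `_validate_dimension_names`, using the concrete `DataArrayStacker` defined further down in this diff; the array and the chosen dimension names are made up for illustration:

import numpy as np
import xarray as xr

from xeofs.preprocessing.stacker import DataArrayStacker

# 'sample' is both an input dimension and the stacker's default sample_name
X = xr.DataArray(np.zeros((3, 4, 5)), dims=("sample", "lat", "lon"))

stacker = DataArrayStacker()
try:
    stacker.fit(X, sample_dims=("sample", "lat"), feature_dims=("lon",))
except ValueError as err:
    print(err)  # multi-dimensional sample_dims must not reuse the reserved name 'sample'
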
""" - - sample_dims = ensure_tuple(sample_dims) - feature_dims = ensure_tuple(feature_dims) - - # The two sets `sample_dims` and `feature_dims` are disjoint/mutually exclusive - if not (set(sample_dims + feature_dims) == set(data.dims)): - raise ValueError( - f"One or more dimensions in {sample_dims + feature_dims} are not present in data dimensions: {data.dims}" - ) - - # Set in/out dimensions - self.dims_in_ = data.dims - self.dims_out_ = {"sample": sample_dims, "feature": feature_dims} - - # Set in/out coordinates - self.coords_in_ = {dim: data.coords[dim] for dim in data.dims} - - # Stack data - da: DataArray = self._stack( - data, self.dims_out_["sample"], self.dims_out_["feature"] + self.sample_dims = sample_dims + self.feature_dims = feature_dims + self.dims_mapping.update( + { + self.sample_name: sample_dims, + self.feature_name: feature_dims, + } ) - # Remove NaN samples/features - da = da.dropna("feature", how="all") - da = da.dropna("sample", how="all") - - self.coords_out_ = { - "sample": da.coords["sample"], - "feature": da.coords["feature"], - } - - # Ensure that no NaNs are present in the data - if da.isnull().any(): - raise ValueError( - "Isolated NaNs are present in the data. Please remove them before fitting the model." - ) + self._sanity_check(X, sample_dims, feature_dims) - # Ensure that data is not empty - if da.size == 0: - raise ValueError("Data is empty.") + # Set dimensions and coordinates + self.dims_in = X.dims + self.coords_in = {dim: X.coords[dim] for dim in X.dims} - return da + return self - def transform(self, data: SingleDataObject) -> DataArray: - """Transform new "unseen" data to 2D version. + def transform(self, X: Data) -> DataArray: + """Reshape DataArray to 2D. Parameters ---------- - data : DataArray + X : DataArray The data to be reshaped. Returns @@ -190,94 +187,97 @@ def transform(self, data: SingleDataObject) -> DataArray: If the data to be transformed has different dimensions than the data used to fit the stacker. ValueError If the data to be transformed has different coordinates than the data used to fit the stacker. - ValueError - If the data to be transformed has individual NaNs. - ValueError - If data is empty """ # Test whether sample and feature dimensions are present in data array - self._validate_matching_dimensions(data) + self._validate_matching_dimensions(X) # Check if data to be transformed has the same feature coordinates as the data used to fit the stacker - self._validate_matching_feature_coords(data) + self._validate_matching_feature_coords(X) - # Stack data and remove NaN features + # Stack data + sample_dims = self.dims_mapping[self.sample_name] + feature_dims = self.dims_mapping[self.feature_name] da: DataArray = self._stack( - data, self.dims_out_["sample"], self.dims_out_["feature"] + X, sample_dims=sample_dims, feature_dims=feature_dims ) - da = da.dropna("feature", how="all") - da = da.dropna("sample", how="all") - - # Ensure that no NaNs are present in the data - if da.isnull().any(): - raise ValueError( - "Isolated NaNs are present in the data. Please remove them before fitting the model." 
- ) - - # Ensure that data is not empty - if da.size == 0: - raise ValueError("Data is empty.") + # Set out coordinates + self.coords_out.update( + { + self.sample_name: da.coords[self.sample_name], + self.feature_name: da.coords[self.feature_name], + } + ) return da + def fit_transform( + self, + X: DataVar, + sample_dims: Dims, + feature_dims: Dims, + ) -> DataArray: + return self.fit(X, sample_dims, feature_dims).transform(X) -class SingleDataArrayStacker(SingleDataStacker): - """Converts a DataArray of any dimensionality into a 2D structure. + def inverse_transform_data(self, X: DataArray) -> Data: + """Reshape the 2D data (sample x feature) back into its original dimensions. - This operation generates a reshaped DataArray with two distinct dimensions: 'sample' and 'feature'. + Parameters + ---------- + X : DataArray + The data to be reshaped. - The handling of NaNs is specific: if they are found to populate an entire dimension (be it 'sample' or 'feature'), - they are temporarily removed during transformations and subsequently reinstated. - However, the presence of isolated NaNs will trigger an error. + Returns + ------- + DataArray + The reshaped data. - """ + """ + Xnd = self._unstack(X) + Xnd = self._reorder_dims(Xnd) + return Xnd - @staticmethod - def _validate_dimensions(sample_dims: Tuple[str], feature_dims: Tuple[str]): - """Verify the dimensions are correctly specified. - For example, valid input dimensions (sample, feature) are: + def inverse_transform_components(self, X: DataArray) -> Data: + """Reshape the 2D components (sample x feature) back into its original dimensions. - (("year", "month"), ("lon", "lat")), - (("year",), ("lat", "lon")), - (("year", "month"), ("lon",)), - (("year",), ("lon",)), - (("sample",), ("feature",)), <-- special case only valid for DataArrays + Parameters + ---------- + data : DataArray + The data to be reshaped. - """ + Returns + ------- + DataArray + The reshaped data. - # Check for `sample` and `feature` special cases - if sample_dims == ("sample",) and feature_dims != ("feature",): - err_msg = """Due to the internal logic of this package, - when using the 'sample' dimension in sample_dims, it should only be - paired with the 'feature' dimension in feature_dims. Please rename or remove - other dimensions.""" - raise ValueError(err_msg) + """ + Xnd = self._unstack(X) + Xnd = self._reorder_dims(Xnd) + return Xnd - if feature_dims == ("feature",) and sample_dims != ("sample",): - err_msg = """Invalid combination: 'feature' dimension in feature_dims should only - be paired with 'sample' dimension in sample_dims.""" - raise ValueError(err_msg) + def inverse_transform_scores(self, data: DataArray) -> DataArray: + """Reshape the 2D scores (sample x feature) back into its original dimensions. - if "sample" in sample_dims and len(sample_dims) > 1: - err_msg = """Invalid combination: 'sample' dimension should not be combined with other - dimensions in sample_dims.""" - raise ValueError(err_msg) + Parameters + ---------- + data : DataArray + The data to be reshaped. - if "feature" in feature_dims and len(feature_dims) > 1: - err_msg = """Invalid combination: 'feature' dimension should not be combined with other - dimensions in feature_dims.""" - raise ValueError(err_msg) + Returns + ------- + DataArray + The reshaped data. 
- if "sample" in feature_dims: - err_msg = """Invalid combination: 'sample' dimension should not appear in feature_dims.""" - raise ValueError(err_msg) + """ + data = self._unstack(data) # type: ignore + data = self._reorder_dims(data) + return data - if "feature" in sample_dims: - err_msg = """Invalid combination: 'feature' dimension should not appear in sample_dims.""" - raise ValueError(err_msg) - def _stack(self, data: DataArray, sample_dims, feature_dims) -> DataArray: +class DataArrayStacker(Stacker): + def _stack( + self, data: DataArray, sample_dims: Dims, feature_dims: Dims + ) -> DataArray: """Reshape a DataArray to 2D. Parameters @@ -294,40 +294,46 @@ def _stack(self, data: DataArray, sample_dims, feature_dims) -> DataArray: data_stacked : DataArray The reshaped 2d-data. """ - self._validate_dimensions(sample_dims, feature_dims) + sample_name = self.sample_name + feature_name = self.feature_name + # 3 cases: # 1. uni-dimensional with correct feature/sample name ==> do nothing # 2. uni-dimensional with name different from feature/sample ==> rename # 3. multi-dimensinoal with names different from feature/sample ==> stack - # - FEATURE - - if len(feature_dims) == 1: + # - SAMPLE - + if len(sample_dims) == 1: # Case 1 - if feature_dims[0] == "feature": + if sample_dims[0] == sample_name: pass # Case 2 else: - data = data.rename({feature_dims[0]: "feature"}) + data = data.rename({sample_dims[0]: sample_name}) # Case 3 else: - data = data.stack(feature=feature_dims) + data = data.stack({sample_name: sample_dims}) - # - SAMPLE - - if len(sample_dims) == 1: + # - FEATURE - + if len(feature_dims) == 1: # Case 1 - if sample_dims[0] == "sample": + if feature_dims[0] == feature_name: pass # Case 2 else: - data = data.rename({sample_dims[0]: "sample"}) + data = data.rename({feature_dims[0]: feature_name}) # Case 3 else: - data = data.stack(sample=sample_dims) + data = data.stack({feature_name: feature_dims}) + + # Reorder dimensions to be always (sample, feature) + if data.dims == (feature_name, sample_name): + data = data.transpose(sample_name, feature_name) - return data.transpose("sample", "feature") + return data def _unstack(self, data: DataArray) -> DataArray: - """Unstack `sample` and `feature` dimension of an DataArray to its original dimensions. + """Unstack 2D DataArray to its original dimensions. Parameters ---------- @@ -339,113 +345,54 @@ def _unstack(self, data: DataArray) -> DataArray: data_unstacked : DataArray The unstacked data. 
""" + sample_name = self.sample_name + feature_name = self.feature_name + # pass if feature/sample dimensions do not exist in data - if "feature" in data.dims: + if feature_name in data.dims: # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack - if len(self.dims_out_["feature"]) == 1: - if self.dims_out_["feature"][0] != "feature": - data = data.rename({"feature": self.dims_out_["feature"][0]}) + if len(self.dims_mapping[feature_name]) == 1: + if self.dims_mapping[feature_name][0] != feature_name: + data = data.rename( + {feature_name: self.dims_mapping[feature_name][0]} + ) else: - data = data.unstack("feature") + data = data.unstack(feature_name) - if "sample" in data.dims: + if sample_name in data.dims: # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack - if len(self.dims_out_["sample"]) == 1: - if self.dims_out_["sample"][0] != "sample": - data = data.rename({"sample": self.dims_out_["sample"][0]}) + if len(self.dims_mapping[sample_name]) == 1: + if self.dims_mapping[sample_name][0] != sample_name: + data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) else: - data = data.unstack("sample") - - # Reorder dimensions to original order - data = self._reorder_dims(data) + data = data.unstack(sample_name) - return data - - def _reindex_dim(self, data: DataArray, stacked_dim: str) -> DataArray: - return super()._reindex_dim(data, stacked_dim) - - def fit_transform( - self, - data: DataArray, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable], - ) -> DataArray: - return super().fit_transform(data, sample_dims, feature_dims) - - def transform(self, data: DataArray) -> DataArray: - return super().transform(data) - - def inverse_transform_data(self, data: DataArray) -> DataArray: - """Reshape the 2D data (sample x feature) back into its original shape.""" - - data = self._unstack(data) - - # Reindex data to original coordinates in case that some features at the boundaries were dropped - data = self._reindex_dim(data, "feature") - data = self._reindex_dim(data, "sample") - - return data - - def inverse_transform_components(self, data: DataArray) -> DataArray: - """Reshape the 2D data (mode x feature) back into its original shape.""" - - data = self._unstack(data) - - # Reindex data to original coordinates in case that some features at the boundaries were dropped - data = self._reindex_dim(data, "feature") - - return data - - def inverse_transform_scores(self, data: DataArray) -> DataArray: - """Reshape the 2D data (sample x mode) back into its original shape.""" - - data = self._unstack(data) - - # Scores are not to be reindexed since they new data typically has different sample coordinates - # than the original data used for fitting the model + else: + pass return data -class SingleDatasetStacker(SingleDataStacker): - """Converts a Dataset of any dimensionality into a 2D structure. - - This operation generates a reshaped Dataset with two distinct dimensions: 'sample' and 'feature'. - - The handling of NaNs is specific: if they are found to populate an entire dimension (be it 'sample' or 'feature'), - they are temporarily removed during transformations and subsequently reinstated. - However, the presence of isolated NaNs will trigger an error. - - """ - - @staticmethod - def _validate_dimensions(sample_dims: Tuple[str], feature_dims: Tuple[str]): - """Verify the dimensions are correctly specified. 
- - For example, valid input dimensions (sample, feature) are: - - (("year", "month"), ("lon", "lat")), - (("year",), ("lat", "lon")), - (("year", "month"), ("lon",)), - (("year",), ("lon",)), - - - Invalid examples are: - any combination that contains 'sample' and/or 'feature' dimension - - """ - if "sample" in sample_dims or "sample" in feature_dims: - err_msg = ( - "The dimension 'sample' is reserved for internal used. Please rename." - ) - raise ValueError(err_msg) - if "feature" in sample_dims or "feature" in feature_dims: - err_msg = ( - "The dimension 'feature' is reserved for internal used. Please rename." +class DataSetStacker(Stacker): + """Converts a Dataset of any dimensionality into a 2D structure.""" + + def _validate_dimension_names(self, sample_dims, feature_dims): + if len(sample_dims) > 1: + if self.sample_name in sample_dims: + raise ValueError( + f"Name of sample dimension ({self.sample_name}) is already present in data. Please use another name." + ) + if len(feature_dims) >= 1: + if self.feature_name in feature_dims: + raise ValueError( + f"Name of feature dimension ({self.feature_name}) is already present in data. Please use another name." + ) + else: + raise ValueError( + f"Datasets without feature dimension are currently not supported. Please convert your Dataset to a DataArray first, e.g. by using `to_array()`." ) - raise ValueError(err_msg) - def _stack(self, data: Dataset, sample_dims, feature_dims) -> DataArray: + def _stack(self, data: DataSet, sample_dims, feature_dims) -> DataArray: """Reshape a Dataset to 2D. Parameters @@ -459,258 +406,107 @@ def _stack(self, data: Dataset, sample_dims, feature_dims) -> DataArray: Returns ------- - data_stacked : DataArray | Dataset + data_stacked : DataArray The reshaped 2d-data. """ - self._validate_dimensions(sample_dims, feature_dims) - # 2 cases: - # 1. uni-dimensional with name different from feature/sample ==> rename - # 2. multi-dimensinoal with names different from feature/sample ==> stack + sample_name = self.sample_name + feature_name = self.feature_name - # - FEATURE - - # Convert Dataset -> DataArray, stacking all non-sample dimensions to feature dimension, including data variables - # Case 1 & 2 - da = data.to_stacked_array(new_dim="feature", sample_dims=sample_dims) + # 3 cases: + # 1. uni-dimensional with correct feature/sample name ==> do nothing + # 2. uni-dimensional with name different from feature/sample ==> rename + # 3. multi-dimensinoal with names different from feature/sample ==> stack - # Rename if sample dimensions is one dimensional, otherwise stack - # Case 1 + # - SAMPLE - if len(sample_dims) == 1: - da = da.rename({sample_dims[0]: "sample"}) - # Case 2 + # Case 1 + if sample_dims[0] == sample_name: + pass + # Case 2 + else: + data = data.rename({sample_dims[0]: sample_name}) + # Case 3 else: - da = da.stack(sample=sample_dims) + data = data.stack({sample_name: sample_dims}) - return da.transpose("sample", "feature") + # - FEATURE - + # Convert Dataset -> DataArray, stacking all non-sample dimensions to feature dimension, including data variables + err_msg = f"Feature dimension {feature_dims[0]} already exists in data. Please choose another feature dimension name." 
+ # Case 2 & 3 + if (len(feature_dims) == 1) & (feature_dims[0] == feature_name): + raise ValueError(err_msg) + else: + try: + da = data.to_stacked_array( + new_dim=feature_name, sample_dims=(self.sample_name,) + ) + except ValueError: + raise ValueError(err_msg) - def _unstack_data(self, data: DataArray) -> Dataset: + # Reorder dimensions to be always (sample, feature) + if da.dims == (feature_name, sample_name): + da = da.transpose(sample_name, feature_name) + + return da + + def _unstack_data(self, data: DataArray) -> DataSet: """Unstack `sample` and `feature` dimension of an DataArray to its original dimensions.""" - if len(self.dims_out_["sample"]) == 1: - data = data.rename({"sample": self.dims_out_["sample"][0]}) - ds: Dataset = data.to_unstacked_dataset("feature", "variable").unstack() + sample_name = self.sample_name + feature_name = self.feature_name + has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 + + if has_only_one_sample_dim: + data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) + + ds: DataSet = data.to_unstacked_dataset(feature_name, "variable").unstack() ds = self._reorder_dims(ds) return ds - def _unstack_components(self, data: DataArray) -> Dataset: - ds: Dataset = data.to_unstacked_dataset("feature", "variable").unstack() + def _unstack_components(self, data: DataArray) -> DataSet: + feature_name = self.feature_name + ds: DataSet = data.to_unstacked_dataset(feature_name, "variable").unstack() ds = self._reorder_dims(ds) return ds def _unstack_scores(self, data: DataArray) -> DataArray: - if len(self.dims_out_["sample"]) == 1: - data = data.rename({"sample": self.dims_out_["sample"][0]}) + sample_name = self.sample_name + has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 + + if has_only_one_sample_dim: + data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) + data = data.unstack() data = self._reorder_dims(data) return data - def _reindex_dim(self, data: Dataset, model_dim: str) -> Dataset: - return super()._reindex_dim(data, model_dim) - - def fit_transform( - self, - data: Dataset, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - ) -> xr.DataArray: - return super().fit_transform(data, sample_dims, feature_dims) - - def transform(self, data: Dataset) -> DataArray: - return super().transform(data) - - def inverse_transform_data(self, data: DataArray) -> Dataset: + def inverse_transform_data(self, X: DataArray) -> DataSet: """Reshape the 2D data (sample x feature) back into its original shape.""" - data_ds: Dataset = self._unstack_data(data) + X_ds: DataSet = self._unstack_data(X) + return X_ds - # Reindex data to original coordinates in case that some features at the boundaries were dropped - data_ds = self._reindex_dim(data_ds, "feature") - data_ds = self._reindex_dim(data_ds, "sample") + def inverse_transform_components(self, X: DataArray) -> DataSet: + """Reshape the 2D components (sample x feature) back into its original shape.""" + X_ds: DataSet = self._unstack_components(X) + return X_ds - return data_ds + def inverse_transform_scores(self, X: DataArray) -> DataArray: + """Reshape the 2D scores (sample x feature) back into its original shape.""" + X = self._unstack_scores(X) + return X - def inverse_transform_components(self, data: DataArray) -> Dataset: - """Reshape the 2D data (mode x feature) back into its original shape.""" - data_ds: Dataset = self._unstack_components(data) - # Reindex data to original coordinates in 
case that some features at the boundaries were dropped - data_ds = self._reindex_dim(data_ds, "feature") - - return data_ds - - def inverse_transform_scores(self, data: DataArray) -> DataArray: - """Reshape the 2D data (sample x mode) back into its original shape.""" - data = self._unstack_scores(data) - - # Scores are not to be reindexed since they new data typically has different sample coordinates - # than the original data used for fitting the model - - return data - - -class ListDataArrayStacker(_BaseStacker): - """Converts a list of DataArrays of any dimensionality into a 2D structure. - - This operation generates a reshaped DataArray with two distinct dimensions: 'sample' and 'feature'. - - The handling of NaNs is specific: if they are found to populate an entire dimension (be it 'sample' or 'feature'), - they are temporarily removed during transformations and subsequently reinstated. - However, the presence of isolated NaNs will trigger an error. - - At a minimum, the `sample` dimension must be present in all DataArrays. The `feature` dimension can be different - for each DataArray and must be specified as a list of dimensions. - - """ +class StackerFactory: + """Factory class for creating stackers.""" def __init__(self): - self.stackers = [] - - def fit_transform( - self, - data: DataArrayList, - sample_dims: Hashable | Sequence[Hashable], - feature_dims: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], - ) -> DataArray: - """Fit the stacker to the data. - - Parameters - ---------- - data : DataArray - The data to be reshaped. - sample_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `sample` dimension. - feature_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `feature` dimension. - - """ - # Check input - if not isinstance(feature_dims, list): - raise TypeError( - "feature dims must be a list of the feature dimensions of each DataArray" - ) - - sample_dims = ensure_tuple(sample_dims) - feature_dims = [ensure_tuple(fdims) for fdims in feature_dims] - - if len(data) != len(feature_dims): - err_message = ( - "Number of data arrays and feature dimensions must be the same. 
" - ) - err_message += f"Got {len(data)} data arrays and {len(feature_dims)} feature dimensions" - raise ValueError(err_message) + pass - # Set in/out dimensions - self.dims_in_ = [da.dims for da in data] - self.dims_out_ = {"sample": sample_dims, "feature": feature_dims} - - # Set in/out coordinates - self.coords_in_ = [ - {dim: coords for dim, coords in da.coords.items()} for da in data - ] - - for da, fdims in zip(data, feature_dims): - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, sample_dims, fdims) - self.stackers.append(stacker) - - stacked_data_list = [] - idx_coords_size = [] - dummy_feature_coords = [] - - # Stack individual DataArrays - for da, fdims in zip(data, feature_dims): - stacker = SingleDataArrayStacker() - da_stacked = stacker.fit_transform(da, sample_dims, fdims) - idx_coords_size.append(da_stacked.coords["feature"].size) - stacked_data_list.append(da_stacked) - - # Create dummy feature coordinates for each DataArray - idx_range = np.cumsum([0] + idx_coords_size) - for i in range(len(idx_range) - 1): - dummy_feature_coords.append(np.arange(idx_range[i], idx_range[i + 1])) - - # Replace original feature coordiantes with dummy coordinates - for i, data in enumerate(stacked_data_list): - data = data.drop("feature") # type: ignore - stacked_data_list[i] = data.assign_coords(feature=dummy_feature_coords[i]) # type: ignore - - self._dummy_feature_coords = dummy_feature_coords - - stacked_data_list = xr.concat(stacked_data_list, dim="feature") - - self.coords_out_ = { - "sample": stacked_data_list.coords["sample"], - "feature": stacked_data_list.coords["feature"], - } - return stacked_data_list - - def transform(self, data: DataArrayList) -> DataArray: - """Reshape the data into a 2D version. - - Parameters - ---------- - data: list of DataArrays - The data to be reshaped. - - Returns - ------- - DataArray - The reshaped 2D data. 
- - """ - stacked_data_list = [] - - # Stack individual DataArrays - for i, (stacker, da) in enumerate(zip(self.stackers, data)): - stacked_data = stacker.transform(da) - stacked_data = stacked_data.drop("feature") - # Replace original feature coordiantes with dummy coordinates - stacked_data.coords.update({"feature": self._dummy_feature_coords[i]}) - stacked_data_list.append(stacked_data) - - return xr.concat(stacked_data_list, dim="feature") - - def inverse_transform_data(self, data: DataArray) -> DataArrayList: - """Reshape the 2D data (sample x feature) back into its original shape.""" - dalist = [] - for stacker, features in zip(self.stackers, self._dummy_feature_coords): - # Select the features corresponding to the current DataArray - subda = data.sel(feature=features) - # Replace dummy feature coordinates with original feature coordinates - subda = subda.assign_coords(feature=stacker.coords_out_["feature"]) - - # In case of MultiIndex we have to set the index to the feature dimension again - if isinstance(subda.indexes["feature"], pd.MultiIndex): - subda = subda.set_index(feature=stacker.dims_out_["feature"]) - else: - # NOTE: This is a workaround for the case where the feature dimension is a tuple of length 1 - # the problem is described here: https://github.com/pydata/xarray/discussions/7958 - subda = subda.rename(feature=stacker.dims_out_["feature"][0]) - - # Inverse transform the data using the corresponding stacker - subda = stacker.inverse_transform_data(subda) - dalist.append(subda) - return dalist - - def inverse_transform_components(self, data: DataArray) -> DataArrayList: - """Reshape the 2D data (mode x feature) back into its original shape.""" - dalist = [] - for stacker, features in zip(self.stackers, self._dummy_feature_coords): - # Select the features corresponding to the current DataArray - subda = data.sel(feature=features) - # Replace dummy feature coordinates with original feature coordinates - subda = subda.assign_coords(feature=stacker.coords_out_["feature"]) - - # In case of MultiIndex we have to set the index to the feature dimension again - if isinstance(subda.indexes["feature"], pd.MultiIndex): - subda = subda.set_index(feature=stacker.dims_out_["feature"]) - else: - # NOTE: This is a workaround for the case where the feature dimension is a tuple of length 1 - # the problem is described here: https://github.com/pydata/xarray/discussions/7958 - subda = subda.rename(feature=stacker.dims_out_["feature"][0]) - - # Inverse transform the data using the corresponding stacker - subda = stacker.inverse_transform_components(subda) - dalist.append(subda) - return dalist - - def inverse_transform_scores(self, data: DataArray) -> DataArray: - """Reshape the 2D data (sample x mode) back into its original shape.""" - return self.stackers[0].inverse_transform_scores(data) + @staticmethod + def create(data: Data) -> Type[DataArrayStacker] | Type[DataSetStacker]: + """Create a stacker for the given data.""" + if isinstance(data, xr.DataArray): + return DataArrayStacker + elif isinstance(data, xr.Dataset): + return DataSetStacker + else: + raise TypeError(f"Invalid data type {type(data)}.") diff --git a/xeofs/preprocessing/stacker_factory.py b/xeofs/preprocessing/stacker_factory.py deleted file mode 100644 index ce73b25..0000000 --- a/xeofs/preprocessing/stacker_factory.py +++ /dev/null @@ -1,20 +0,0 @@ -import xarray as xr - -from ._base_stacker import _BaseStacker -from .stacker import SingleDataArrayStacker, SingleDatasetStacker, ListDataArrayStacker -from 
..utils.data_types import AnyDataObject - - -class StackerFactory: - @staticmethod - def create_stacker(data: AnyDataObject, **kwargs) -> _BaseStacker: - if isinstance(data, xr.DataArray): - return SingleDataArrayStacker(**kwargs) - elif isinstance(data, xr.Dataset): - return SingleDatasetStacker(**kwargs) - elif isinstance(data, list) and all( - isinstance(da, xr.DataArray) for da in data - ): - return ListDataArrayStacker(**kwargs) - else: - raise ValueError("Invalid data type") diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py new file mode 100644 index 0000000..26e33d5 --- /dev/null +++ b/xeofs/preprocessing/transformer.py @@ -0,0 +1,68 @@ +from typing import Optional +from typing_extensions import Self +from abc import abstractmethod + +from sklearn.base import BaseEstimator, TransformerMixin + +from ..utils.data_types import Dims, DataVar, DataArray, DataSet, Data, DataVarBound + + +class Transformer(BaseEstimator, TransformerMixin): + """ + Abstract base class to transform an xarray DataArray/Dataset. + + """ + + def __init__( + self, + sample_name: str = "sample", + feature_name: str = "feature", + ): + self.sample_name = sample_name + self.feature_name = feature_name + + @abstractmethod + def fit( + self, + X: Data, + sample_dims: Optional[Dims] = None, + feature_dims: Optional[Dims] = None, + **kwargs + ) -> Self: + """Fit transformer to data. + + Parameters: + ------------- + X: xr.DataArray | xr.Dataset + Input data. + sample_dims: Sequence[Hashable], optional + Sample dimensions. + feature_dims: Sequence[Hashable], optional + Feature dimensions. + """ + pass + + @abstractmethod + def transform(self, X: Data) -> Data: + return X + + def fit_transform( + self, + X: Data, + sample_dims: Optional[Dims] = None, + feature_dims: Optional[Dims] = None, + **kwargs + ) -> Data: + return self.fit(X, sample_dims, feature_dims, **kwargs).transform(X) + + @abstractmethod + def inverse_transform_data(self, X: Data) -> Data: + return X + + @abstractmethod + def inverse_transform_components(self, X: Data) -> Data: + return X + + @abstractmethod + def inverse_transform_scores(self, X: DataArray) -> DataArray: + return X diff --git a/xeofs/utils/constants.py b/xeofs/utils/constants.py index a09a3c3..aab85e4 100644 --- a/xeofs/utils/constants.py +++ b/xeofs/utils/constants.py @@ -10,6 +10,9 @@ "LAT", ] +VALID_LONGITUDE_NAMES = ["lon", "lons", "longitude", "longitudes"] +VALID_CARTESIAN_X_NAMES = ["x", "x_coord"] +VALID_CARTESIAN_Y_NAMES = ["y", "y_coord"] MULTIPLE_TESTS = [ "bonferroni", @@ -23,3 +26,6 @@ "fdr_tsbh", "fdr_tsbky", ] + + +AVG_EARTH_RADIUS = 6371.0 # in km diff --git a/xeofs/utils/data_types.py b/xeofs/utils/data_types.py index 379ff2a..4e77a82 100644 --- a/xeofs/utils/data_types.py +++ b/xeofs/utils/data_types.py @@ -1,18 +1,33 @@ -from typing import List, TypeAlias, TypedDict, Optional, Tuple, TypeVar +from typing import ( + List, + TypeAlias, + Sequence, + Tuple, + TypeVar, + Hashable, +) +import dask.array as da import xarray as xr +from xarray.core import dataarray as xr_dataarray +from xarray.core import dataset as xr_dataset + +DataArray: TypeAlias = xr_dataarray.DataArray +DataSet: TypeAlias = xr_dataset.Dataset +Data: TypeAlias = DataArray | DataSet +DataVar = TypeVar("DataVar", DataArray, DataSet) +DataVarBound = TypeVar("DataVarBound", bound=Data) -DataArray: TypeAlias = xr.DataArray -Dataset: TypeAlias = xr.Dataset DataArrayList: TypeAlias = List[DataArray] -SingleDataObject = TypeVar("SingleDataObject", DataArray, Dataset) 
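
The deleted factory modules are superseded by `StackerFactory.create`, which now returns the stacker class rather than an instance, and every preprocessor shares the sklearn-style `Transformer` base added in `transformer.py`. A sketch of that flow, with dummy data and the module paths shown in the diff:

import numpy as np
import xarray as xr

from xeofs.preprocessing.stacker import StackerFactory

X = xr.DataArray(np.random.rand(5, 3), dims=("time", "lon"))

# The factory picks DataArrayStacker for a DataArray and DataSetStacker for a Dataset
StackerCls = StackerFactory.create(X)
stacker = StackerCls(sample_name="sample", feature_name="feature")
X2d = stacker.fit_transform(X, sample_dims=("time",), feature_dims=("lon",))
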
-AnyDataObject = TypeVar("AnyDataObject", DataArray, Dataset, DataArrayList) +DataSetList: TypeAlias = List[DataSet] +DataList: TypeAlias = List[Data] +DataVarList: TypeAlias = List[DataVar] + + +DaskArray: TypeAlias = da.Array # type: ignore +DataObject: TypeAlias = DataArray | DataSet | DataList -XarrayData: TypeAlias = DataArray | Dataset -# Model dimensions are always 2-dimensional: sample and feature -Dims: TypeAlias = Tuple[str] +Dims: TypeAlias = Sequence[Hashable] +DimsTuple: TypeAlias = Tuple[Dims, ...] DimsList: TypeAlias = List[Dims] -SampleDims: TypeAlias = Dims -FeatureDims: TypeAlias = Dims | DimsList -# can be either like ('lat', 'lon') (1 DataArray) or (('lat', 'lon'), ('lon')) (multiple DataArrays) -ModelDims = TypedDict("ModelDims", {"sample": SampleDims, "feature": FeatureDims}) +DimsListTuple: TypeAlias = Tuple[DimsList, ...] diff --git a/xeofs/utils/distance_metrics.py b/xeofs/utils/distance_metrics.py new file mode 100644 index 0000000..1b80009 --- /dev/null +++ b/xeofs/utils/distance_metrics.py @@ -0,0 +1,134 @@ +import numpy as np +import numba +from numba import prange +from scipy.spatial.distance import cdist + +from .constants import AVG_EARTH_RADIUS + +VALID_METRICS = ["euclidean", "haversine"] + + +def distance_matrix_bc(A, B, metric="haversine"): + """Compute a distance matrix between two arrays using broadcasting. + + Parameters + ---------- + A: 2D darray + Array of longitudes and latitudes with shape (N, 2) + B: 2D darray + Array of longitudes and latitudes with shape (M, 2) + metric: str + Distance metric to use. Great circle distance (`haversine`) is always expressed in kilometers. + All other distance metrics are reported in the unit of the input data. + See scipy.spatial.distance.cdist for a list of available metrics. + + Returns + ------- + distance: 2D darray + Distance matrix with shape (N, M) + + + """ + if metric == "haversine": + return _haversine_distance_bc(A, B) + else: + return cdist(XA=A, XB=B, metric=metric) + + +def _haversine_distance_bc(lonlats1, lonlats2): + """Compute the great circle distance matrix between two arrays + + This implementation uses numpy broadcasting. + + Parameters + ---------- + lonlats1: 2D darray + Array of longitudes and latitudes with shape (N, 2) + lonlats2: 2D darray + Array of longitudes and latitudes with shape (M, 2) + + Returns + ------- + distance: 2D darray + Great circle distance matrix with shape (N, M) in kilometers + + """ + # Convert to radians + lonlats1 = np.radians(lonlats1) + lonlats2 = np.radians(lonlats2) + + # Extract longitudes and latitudes + lon1, lat1 = lonlats1[:, 0], lonlats1[:, 1] + lon2, lat2 = lonlats2[:, 0], lonlats2[:, 1] + + # Compute differences in longitudes and latitudes + dlon = lon2 - lon1[:, np.newaxis] + dlat = lat2 - lat1[:, np.newaxis] + + # Haversine formula + a = ( + np.sin(dlat / 2) ** 2 + + np.cos(lat1)[..., None] * np.cos(lat2) * np.sin(dlon / 2) ** 2 + ) + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + distance = AVG_EARTH_RADIUS * c + + return distance + + +@numba.njit(fastmath=True) +def distance_nb(A, b, metric="euclidean"): + if metric == "euclidean": + return _euclidian_distance_nb(A, b) + elif metric == "haversine": + return _haversine_distance_nb(A, b) + else: + raise ValueError( + f"Invalid metric: {metric}. Must be one of ['euclidean', 'haversine']." + ) + + +@numba.njit(fastmath=True) +def _euclidian_distance_nb(A, b): + """Compute the Euclidian distance between two arrays. + + This implementation uses numba. 
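
The new `distance_metrics` module provides both a broadcast NumPy path and numba kernels. A quick sanity check of the broadcast great-circle distance; the coordinates are arbitrary test points, and `haversine` results are in kilometres per the docstring:

import numpy as np

from xeofs.utils.distance_metrics import distance_matrix_bc

# (lon, lat) pairs in degrees
pts_a = np.array([[0.0, 0.0], [10.0, 0.0]])
pts_b = np.array([[0.0, 0.0], [0.0, 90.0]])

D = distance_matrix_bc(pts_a, pts_b, metric="haversine")
print(D.shape)       # (2, 2)
print(round(D[0, 1]))  # ~10008 km, a quarter of Earth's circumference

# Any other metric name is passed straight through to scipy.spatial.distance.cdist
E = distance_matrix_bc(pts_a, pts_b, metric="euclidean")
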
+ + Parameters + ---------- + A: 2D array + Array of shape (N, P) + b: 1D array + Array of shape (P,) + + Returns + ------- + distance: 1D array + Distance matrix with shape (N,) + + """ + dist = np.zeros(A.shape[0]) + for r in prange(A.shape[0]): + d = 0 + for c in range(A.shape[1]): + d += (b[c] - A[r, c]) ** 2 + dist[r] = d + return np.sqrt(dist) + + +@numba.njit(fastmath=True) +def _haversine_distance_nb(A, b): + # Convert to radians + A = np.radians(A) + b = np.radians(b) + + # Compute differences in longitudes and latitudes + dlon = b[0] - A[:, 0] + dlat = b[1] - A[:, 1] + + # Haversine formula + a = np.sin(dlat / 2) ** 2 + np.cos(A[:, 1]) * np.cos(b[1]) * np.sin(dlon / 2) ** 2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + distance = AVG_EARTH_RADIUS * c + + return distance diff --git a/xeofs/utils/hilbert_transform.py b/xeofs/utils/hilbert_transform.py new file mode 100644 index 0000000..7f0d5f2 --- /dev/null +++ b/xeofs/utils/hilbert_transform.py @@ -0,0 +1,114 @@ +import numpy as np +import xarray as xr +from scipy.signal import hilbert # type: ignore +from .data_types import DataArray + + +def hilbert_transform( + data: DataArray, dims, padding: str = "exp", decay_factor: float = 0.2 +) -> DataArray: + """Hilbert transform with optional padding to mitigate spectral leakage. + + Parameters: + ------------ + data: DataArray + Input data. + dim: str + Dimension along which to apply the Hilbert transform. + padding: str + Padding type. Can be 'exp' or None. + decay_factor: float + Decay factor of the exponential function. + + Returns: + --------- + data: DataArray + Hilbert transform of the input data. + + """ + return xr.apply_ufunc( + _hilbert_transform_with_padding, + data, + input_core_dims=[dims], + output_core_dims=[dims], + kwargs={"padding": padding, "decay_factor": decay_factor}, + dask="parallelized", + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + +def _hilbert_transform_with_padding(y, padding: str = "exp", decay_factor: float = 0.2): + """Hilbert transform with optional padding to mitigate spectral leakage. + + Parameters: + ------------ + y: np.ndarray + Input array. + padding: str + Padding type. Can be 'exp' or None. + decay_factor: float + Decay factor of the exponential function. + + Returns: + --------- + y: np.ndarray + Hilbert transform of the input array. + + """ + n_samples = y.shape[0] + + if padding == "exp": + y = _pad_exp(y, decay_factor=decay_factor) + + y = hilbert(y, axis=0) + + if padding == "exp": + y = y[n_samples : 2 * n_samples] + + # Padding can introduce a shift in the mean of the imaginary part + # of the Hilbert transform. Correct for this shift. + y = y - y.mean(axis=0) # type: ignore + + return y + + +def _pad_exp(y, decay_factor: float = 0.2): + """Pad the input array with an exponential decay function. + + The start and end of the input array are padded with an exponential decay + function falling to a reference line given by a linear fit of the data array. + + Parameters: + ------------ + y: np.ndarray + Input array. + decay_factor: float + Decay factor of the exponential function. + + Returns: + --------- + y_ext: np.ndarray + Padded array. 
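
The `hilbert_transform` helper added above wraps `scipy.signal.hilbert` in `xr.apply_ufunc` with optional exponential padding. A minimal sketch, assuming both core dimensions are passed so that the transform acts along the first of them (the sample-like axis), which is how the `axis=0` logic in the private helper reads; the toy signal is illustrative only:

import numpy as np
import xarray as xr

from xeofs.utils.hilbert_transform import hilbert_transform

t = np.arange(200)
X = xr.DataArray(
    np.sin(2 * np.pi * t / 50)[:, None] * np.ones((1, 3)),
    dims=("time", "space"),
)

Xc = hilbert_transform(X, dims=["time", "space"], padding="exp", decay_factor=0.2)
print(Xc.dtype)  # complex128: real part ~ the (mean-removed) input, imaginary part its quadrature
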
+ + """ + x = np.arange(y.shape[0]) + x_ext = np.arange(-x.size, 2 * x.size) + + coefs = np.polynomial.polynomial.polyfit(x, y, deg=1) + yfit = np.polynomial.polynomial.polyval(x, coefs).T + yfit_ext = np.polynomial.polynomial.polyval(x_ext, coefs).T + + y_ano = y - yfit + + amp_pre = np.take(y_ano, 0, axis=0)[:, None] + amp_pos = np.take(y_ano, -1, axis=0)[:, None] + + exp_ext = np.exp(-x / x.size / decay_factor) + exp_ext_reverse = exp_ext[::-1] + + pad_pre = amp_pre * exp_ext_reverse + pad_pos = amp_pos * exp_ext + + y_ext = np.concatenate([pad_pre.T, y_ano, pad_pos.T], axis=0) + y_ext += yfit_ext + return y_ext diff --git a/xeofs/utils/kernels.py b/xeofs/utils/kernels.py new file mode 100644 index 0000000..a01b5a3 --- /dev/null +++ b/xeofs/utils/kernels.py @@ -0,0 +1,34 @@ +import numpy as np +import numba + +VALID_KERNELS = ["bisquare", "gaussian", "exponential"] + + +@numba.njit(fastmath=True) +def kernel_weights_nb(distance, bandwidth, kernel): + if kernel == "bisquare": + return _bisquare_nb(distance, bandwidth) + elif kernel == "gaussian": + return _gaussian_nb(distance, bandwidth) + elif kernel == "exponential": + return _exponential_nb(distance, bandwidth) + else: + raise ValueError( + f"Invalid kernel: {kernel}. Must be one of ['bisquare', 'gaussian', 'exponential']." + ) + + +@numba.njit(fastmath=True) +def _bisquare_nb(distance, bandwidth): + weights = (1 - (distance / bandwidth) ** 2) ** 2 + return np.where(distance <= bandwidth, weights, 0) + + +@numba.njit(fastmath=True) +def _gaussian_nb(distance, bandwidth): + return np.exp(-0.5 * (distance / bandwidth) ** 2) + + +@numba.njit(fastmath=True) +def _exponential_nb(distance, bandwidth): + return np.exp(-0.5 * (distance / bandwidth)) diff --git a/xeofs/utils/rotation.py b/xeofs/utils/rotation.py index 462975b..8f57d9d 100644 --- a/xeofs/utils/rotation.py +++ b/xeofs/utils/rotation.py @@ -1,15 +1,107 @@ -""" Implementation of VARIMAX and PROMAX rotation. """ - -# ============================================================================= -# Imports -# ============================================================================= import numpy as np +import xarray as xr + +from .data_types import DataArray + + +def promax(loadings: DataArray, feature_dim, compute=True, **kwargs): + rotated, rot_mat, phi_mat = xr.apply_ufunc( + _promax, + loadings, + input_core_dims=[[feature_dim, "mode"]], + output_core_dims=[ + [feature_dim, "mode"], + ["mode_m", "mode_n"], + ["mode_m", "mode_n"], + ], + kwargs=kwargs, + dask="allowed", + ) + if compute: + rotated = rotated.compute() + rot_mat = rot_mat.compute() + phi_mat = phi_mat.compute() + return rotated, rot_mat, phi_mat + + +def _promax(X: np.ndarray, power: int = 1, max_iter: int = 1000, rtol: float = 1e-8): + """ + Perform (oblique) Promax rotation. + + This implementation also works for complex numbers. + + Parameters + ---------- + X : np.ndarray + 2D matrix to be rotated. Must have shape ``p x m`` containing + p features and m modes. + power : int + Rotation parameter defining the power the Varimax solution is raised + to. For ``power=1``, this is equivalent to the Varimax solution + (the default is 1). + max_iter: int + Maximum number of iterations for finding the rotation matrix + (the default is 1000). + rtol: + The relative tolerance for the rotation process to achieve + (the default is 1e-8). + Returns + ------- + Xrot : np.ndarray + 2D matrix containing the rotated modes. + rot_mat : np.ndarray + Rotation matrix of shape ``m x m`` with m being number of modes. 
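
kernels.py above defines the numba-jitted weighting kernels (presumably for the GWPCA additions elsewhere in this change). A plain-NumPy mirror, with the @numba.njit decorators dropped purely for illustration, shows how the three kernels taper weights with distance relative to the bandwidth:

import numpy as np

def kernel_weights(distance, bandwidth, kernel="bisquare"):
    # NumPy-only mirror of kernel_weights_nb
    d = np.asarray(distance, dtype=float)
    if kernel == "bisquare":
        w = (1 - (d / bandwidth) ** 2) ** 2
        return np.where(d <= bandwidth, w, 0.0)  # compact support: zero beyond the bandwidth
    if kernel == "gaussian":
        return np.exp(-0.5 * (d / bandwidth) ** 2)
    if kernel == "exponential":
        return np.exp(-0.5 * (d / bandwidth))
    raise ValueError(f"Invalid kernel: {kernel}")

d = np.array([0.0, 50.0, 100.0, 200.0])
for name in ("bisquare", "gaussian", "exponential"):
    print(name, np.round(kernel_weights(d, bandwidth=100.0, kernel=name), 3))
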
+ phi : np.ndarray + Correlation matrix of PCs of shape ``m x m`` with m being number + of modes. For Varimax solution (``power=1``), the correlation matrix + is diagonal i.e. the modes are uncorrelated. -# ============================================================================= -# VARIMAX -# ============================================================================= -def varimax(X: np.ndarray, gamma: float = 1, max_iter: int = 1000, rtol: float = 1e-8): + """ + X = X.copy() + + # Perform varimax rotation + X, rot_mat = _varimax(X=X, max_iter=max_iter, rtol=rtol) + + # Pre-normalization by communalities (sum of squared rows) + h = np.sqrt(np.sum(X * X.conj(), axis=1)) + # Add a stabilizer to avoid zero communalities + eps = 1e-9 + X = (1.0 / (h + eps))[:, np.newaxis] * X + + # Max-normalisation of columns + Xnorm = X / np.max(abs(X), axis=0) + + # "Procustes" equation + P = Xnorm * np.abs(Xnorm) ** (power - 1) + + # Fit linear regression model of "Procrustes" equation + # see Richman 1986 for derivation + L = np.linalg.inv(X.conj().T @ X) @ X.conj().T @ P + + # calculate diagonal of inverse square + try: + sigma_inv = np.diag(np.diag(np.linalg.inv(L.conj().T @ L))) + except np.linalg.LinAlgError: + sigma_inv = np.diag(np.diag(np.linalg.pinv(L.conj().T @ L))) + + # transform and calculate inner products + L = L @ np.sqrt(sigma_inv) + Xrot = X @ L + + # Post-normalization based on Kaiser + Xrot = h[:, np.newaxis] * Xrot + + rot_mat = rot_mat @ L + + # Correlation matrix + L_inv = np.linalg.inv(L) + phi = L_inv @ L_inv.conj().T + + return Xrot, rot_mat, phi + + +def _varimax(X: np.ndarray, gamma: float = 1, max_iter: int = 1000, rtol: float = 1e-8): """ Perform (orthogonal) Varimax rotation. @@ -85,83 +177,3 @@ def varimax(X: np.ndarray, gamma: float = 1, max_iter: int = 1000, rtol: float = # Rotate Xrot = X @ R return Xrot, R - - -# ============================================================================= -# PROMAX -# ============================================================================= -def promax(X: np.ndarray, power: int = 1, max_iter: int = 1000, rtol: float = 1e-8): - """ - Perform (oblique) Promax rotation. - - This implementation also works for complex numbers. - - Parameters - ---------- - X : np.ndarray - 2D matrix to be rotated. Must have shape ``p x m`` containing - p features and m modes. - power : int - Rotation parameter defining the power the Varimax solution is raised - to. For ``power=1``, this is equivalent to the Varimax solution - (the default is 1). - max_iter: int - Maximum number of iterations for finding the rotation matrix - (the default is 1000). - rtol: - The relative tolerance for the rotation process to achieve - (the default is 1e-8). - - Returns - ------- - Xrot : np.ndarray - 2D matrix containing the rotated modes. - rot_mat : np.ndarray - Rotation matrix of shape ``m x m`` with m being number of modes. - phi : np.ndarray - Correlation matrix of PCs of shape ``m x m`` with m being number - of modes. For Varimax solution (``power=1``), the correlation matrix - is diagonal i.e. the modes are uncorrelated. 
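
The rotation refactor keeps the same Promax algorithm but renames the NumPy-level routines to _promax/_varimax and adds the xarray-level promax() wrapper built on apply_ufunc. A quick sanity-check sketch of the private helper (assuming its import path stays xeofs.utils.rotation): for power=1 the solution reduces to Varimax, so the returned mode correlation matrix phi should be numerically the identity.

import numpy as np
from xeofs.utils.rotation import _promax  # private helper added above; import assumed for illustration

rng = np.random.default_rng(0)
loadings = rng.standard_normal((50, 4))          # p = 50 features, m = 4 modes

Xrot, rot_mat, phi = _promax(loadings, power=1)  # power=1 -> orthogonal (Varimax) solution
print(Xrot.shape, rot_mat.shape, phi.shape)      # (50, 4) (4, 4) (4, 4)
print(np.allclose(phi, np.eye(4), atol=1e-6))    # modes stay uncorrelated for power=1
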
- - """ - X = X.copy() - - # Perform varimax rotation - X, rot_mat = varimax(X=X, max_iter=max_iter, rtol=rtol) - - # Pre-normalization by communalities (sum of squared rows) - h = np.sqrt(np.sum(X * X.conj(), axis=1)) - # Add a stabilizer to avoid zero communalities - eps = 1e-9 - X = (1.0 / (h + eps))[:, np.newaxis] * X - - # Max-normalisation of columns - Xnorm = X / np.max(abs(X), axis=0) - - # "Procustes" equation - P = Xnorm * np.abs(Xnorm) ** (power - 1) - - # Fit linear regression model of "Procrustes" equation - # see Richman 1986 for derivation - L = np.linalg.inv(X.conj().T @ X) @ X.conj().T @ P - - # calculate diagonal of inverse square - try: - sigma_inv = np.diag(np.diag(np.linalg.inv(L.conj().T @ L))) - except np.linalg.LinAlgError: - sigma_inv = np.diag(np.diag(np.linalg.pinv(L.conj().T @ L))) - - # transform and calculate inner products - L = L @ np.sqrt(sigma_inv) - Xrot = X @ L - - # Post-normalization based on Kaiser - Xrot = h[:, np.newaxis] * Xrot - - rot_mat = rot_mat @ L - - # Correlation matrix - L_inv = np.linalg.inv(L) - phi = L_inv @ L_inv.conj().T - - return Xrot, rot_mat, phi diff --git a/xeofs/utils/sanity_checks.py b/xeofs/utils/sanity_checks.py index 5023655..ddc5bb9 100644 --- a/xeofs/utils/sanity_checks.py +++ b/xeofs/utils/sanity_checks.py @@ -2,6 +2,8 @@ import xarray as xr +from xeofs.utils.data_types import Dims + def assert_single_dataarray(da, name): """Check if the given object is a DataArray. @@ -61,7 +63,7 @@ def assert_dataarray_or_dataset(da, name): raise TypeError(f"{name} must be either a DataArray or Dataset") -def ensure_tuple(arg: Any) -> Tuple[str]: +def convert_to_dim_type(arg: Any) -> Dims: # Check for invalid types if not isinstance(arg, (str, tuple, list)): raise TypeError(f"Invalid input type: {type(arg).__name__}") @@ -78,3 +80,16 @@ def ensure_tuple(arg: Any) -> Tuple[str]: return tuple(arg) else: return (arg,) + + +def validate_input_type(X) -> None: + err_msg = "Invalid input type: {:}. 
Expected one of the following: DataArray, Dataset or list of these.".format( + type(X).__name__ + ) + if isinstance(X, (xr.DataArray, xr.Dataset)): + pass + elif isinstance(X, (list, tuple)): + if not all(isinstance(x, (xr.DataArray, xr.Dataset)) for x in X): + raise TypeError(err_msg) + else: + raise TypeError(err_msg) diff --git a/xeofs/utils/xarray_utils.py b/xeofs/utils/xarray_utils.py index 9d209ee..5610b05 100644 --- a/xeofs/utils/xarray_utils.py +++ b/xeofs/utils/xarray_utils.py @@ -1,64 +1,149 @@ -from typing import List, Sequence, Hashable, Tuple +from typing import Sequence, Hashable, Tuple, TypeVar, List, Any import numpy as np import xarray as xr -from scipy.signal import hilbert # type: ignore -from .sanity_checks import ensure_tuple -from .data_types import XarrayData, DataArray, Dataset, SingleDataObject +from .sanity_checks import convert_to_dim_type +from .data_types import ( + Dims, + DimsList, + Data, + DataVar, + DataArray, + DataSet, + DataList, +) from .constants import VALID_LATITUDE_NAMES +T = TypeVar("T") -def compute_sqrt_cos_lat_weights( - data: SingleDataObject, dim: Hashable | Sequence[Hashable] -) -> SingleDataObject: + +def unwrap_singleton_list(input_list: List[T]) -> T | List[T]: + if len(input_list) == 1: + return input_list[0] + else: + return input_list + + +def process_parameter( + parameter_name: str, parameter, default, n_data: int +) -> List[Any]: + if parameter is None: + return convert_to_list(default) * n_data + elif isinstance(parameter, (list, tuple)): + _check_parameter_number(parameter_name, parameter, n_data) + return convert_to_list(parameter) + else: + return convert_to_list(parameter) * n_data + + +def convert_to_list(data: T | List[T] | Tuple[T]) -> List[T]: + if isinstance(data, list): + return data + elif isinstance(data, tuple): + return list(data) + else: + return list([data]) + + +def _check_parameter_number(parameter_name: str, parameter, n_data: int): + if len(parameter) != n_data: + raise ValueError( + f"number of data objects passed should match number of parameter {parameter_name}" + f"len(data objects)={n_data} and " + f"len({parameter_name})={len(parameter)}" + ) + + +def feature_ones_like(data: DataVar, feature_dims: Dims) -> DataVar: + if isinstance(data, xr.DataArray): + valid_dims = set(data.dims) & set(feature_dims) + feature_coords = {dim: data[dim] for dim in valid_dims} + shape = tuple(coords.size for coords in feature_coords.values()) + return xr.DataArray( + np.ones(shape, dtype=float), + dims=tuple(valid_dims), + coords=feature_coords, + ) + elif isinstance(data, xr.Dataset): + return xr.Dataset( + { + var: feature_ones_like(da, feature_dims) + for var, da in data.data_vars.items() + } + ) + else: + raise TypeError( + "Invalid input type: {:}. Expected one of the following: DataArray or Dataset".format( + type(data).__name__ + ) + ) + + +def compute_sqrt_cos_lat_weights(data: DataVar, feature_dims: Dims) -> DataVar: """Compute the square root of cosine of latitude weights. Parameters ---------- - data : xarray.DataArray or xarray.Dataset + data : xr.DataArray | xr.Dataset Data to be scaled. dim : sequence of hashable Dimensions along which the data is considered to be a feature. Returns ------- - xarray.DataArray or xarray.Dataset + xr.DataArray | xr.Dataset Square root of cosine of latitude weights. 
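
compute_sqrt_cos_lat_weights now recurses over Dataset variables and delegates the actual weighting to sqrt_cos_lat_weights. The weighting itself is just the square root of the cosine of latitude, with the cosine clipped to [0, 1] (see _np_sqrt_cos_lat_weights further down); a minimal numeric illustration:

import numpy as np
import xarray as xr

lat = xr.DataArray([0.0, 30.0, 60.0, 89.0], dims="lat", name="lat")
weights = np.sqrt(np.cos(np.deg2rad(lat)).clip(0, 1))  # same formula as _np_sqrt_cos_lat_weights
print(weights.values.round(3))  # [1.    0.931 0.707 0.132]
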
""" - dim = ensure_tuple(dim) - - # Find latitude coordinate - is_lat_coord = np.isin(np.array(dim), VALID_LATITUDE_NAMES) - # Select latitude coordinate and compute coslat weights - lat_coord = np.array(dim)[is_lat_coord] + if isinstance(data, xr.DataArray): + lat_dim = extract_latitude_dimension(feature_dims) - if len(lat_coord) > 1: - raise ValueError( - f"{lat_coord} are ambiguous latitude coordinates. Only ONE of the following is allowed for computing coslat weights: {VALID_LATITUDE_NAMES}" - ) - - if len(lat_coord) == 1: - latitudes = data.coords[lat_coord[0]] + latitudes = data.coords[lat_dim] weights = sqrt_cos_lat_weights(latitudes) # Features that cannot be associated to a latitude receive a weight of 1 - weights = weights.where(weights.notnull(), 1) + # weights = weights.where(weights.notnull(), 1) + weights.name = "coslat_weights" + return weights + elif isinstance(data, xr.Dataset): + return xr.Dataset( + { + var: compute_sqrt_cos_lat_weights(da, feature_dims) + for var, da in data.data_vars.items() + } + ) + else: + raise TypeError( + "Invalid input type: {:}. Expected one of the following: DataArray".format( + type(data).__name__ + ) + ) + + +def extract_latitude_dimension(feature_dims: Dims) -> Hashable: + # Find latitude coordinate + lat_dim = set(feature_dims) & set(VALID_LATITUDE_NAMES) + + if len(lat_dim) == 0: raise ValueError( "No latitude coordinate was found to compute coslat weights. Must be one of the following: {:}".format( VALID_LATITUDE_NAMES ) ) - weights.name = "coslat_weights" - return weights + elif len(lat_dim) == 1: + return lat_dim.pop() + else: + raise ValueError( + f"Found ambiguous latitude dimensions: {lat_dim}. Only ONE of the following is allowed for computing coslat weights: {VALID_LATITUDE_NAMES}" + ) def get_dims( - data: DataArray | Dataset | List[DataArray], - sample_dims: Hashable | Sequence[Hashable] | List[Sequence[Hashable]], -) -> Tuple[Hashable, Hashable]: + data: DataList, + sample_dims: Hashable | Sequence[Hashable], +) -> Tuple[Dims, DimsList]: """Extracts the dimensions of a DataArray or Dataset that are not included in the sample dimensions. Parameters: @@ -77,22 +162,17 @@ def get_dims( """ # Check for invalid types - if isinstance(data, (xr.DataArray, xr.Dataset)): - sample_dims = ensure_tuple(sample_dims) - feature_dims = _get_feature_dims(data, sample_dims) - - elif isinstance(data, list): - sample_dims = ensure_tuple(sample_dims) - feature_dims = [_get_feature_dims(da, sample_dims) for da in data] + if isinstance(data, list): + sample_dims = convert_to_dim_type(sample_dims) + feature_dims: DimsList = [_get_feature_dims(da, sample_dims) for da in data] + return sample_dims, feature_dims else: err_message = f"Invalid input type: {type(data).__name__}. Expected one of " - err_message += f"of the following: DataArray, Dataset or list of DataArrays." + err_message += f"of the following: list of DataArrays or Datasets." raise TypeError(err_message) - return sample_dims, feature_dims # type: ignore - -def _get_feature_dims(data: XarrayData, sample_dims: Tuple[str]) -> Tuple[Hashable]: +def _get_feature_dims(data: DataArray | DataSet, sample_dims: Dims) -> Dims: """Extracts the dimensions of a DataArray that are not included in the sample dimensions. @@ -109,21 +189,20 @@ def _get_feature_dims(data: XarrayData, sample_dims: Tuple[str]) -> Tuple[Hashab Feature dimensions. 
""" - feature_dims = tuple(dim for dim in data.dims if dim not in sample_dims) - return feature_dims + return tuple(dim for dim in data.dims if dim not in sample_dims) -def sqrt_cos_lat_weights(data: SingleDataObject) -> SingleDataObject: +def sqrt_cos_lat_weights(data: DataArray) -> DataArray: """Compute the square root of the cosine of the latitude. Parameters: ------------ - data: xr.DataArray or xr.Dataset + data: xr.DataArray Input data. Returns: --------- - sqrt_cos_lat: xr.DataArray or xr.Dataset + sqrt_cos_lat: xr.DataArray Square root of the cosine of the latitude. """ @@ -155,39 +234,6 @@ def total_variance(data: DataArray, dim) -> DataArray: return data.var(dim, ddof=1).sum() -def hilbert_transform( - data: DataArray, dim, padding="exp", decay_factor=0.2 -) -> DataArray: - """Hilbert transform with optional padding to mitigate spectral leakage. - - Parameters: - ------------ - data: DataArray - Input data. - dim: str - Dimension along which to apply the Hilbert transform. - padding: str - Padding type. Can be 'exp' or None. - decay_factor: float - Decay factor of the exponential function. - - Returns: - --------- - data: DataArray - Hilbert transform of the input data. - - """ - return xr.apply_ufunc( - _hilbert_transform_with_padding, - data, - input_core_dims=[["sample", "feature"]], - output_core_dims=[["sample", "feature"]], - kwargs={"padding": padding, "decay_factor": decay_factor}, - dask="parallelized", - dask_gufunc_kwargs={"allow_rechunk": True}, - ) - - def _np_sqrt_cos_lat_weights(data): """Compute the square root of the cosine of the latitude. @@ -202,81 +248,4 @@ def _np_sqrt_cos_lat_weights(data): Square root of the cosine of the latitude. """ - return np.sqrt(np.cos(np.deg2rad(data))).clip(0, 1) - - -def _hilbert_transform_with_padding(y, padding="exp", decay_factor=0.2): - """Hilbert transform with optional padding to mitigate spectral leakage. - - Parameters: - ------------ - y: np.ndarray - Input array. - padding: str - Padding type. Can be 'exp' or None. - decay_factor: float - Decay factor of the exponential function. - - Returns: - --------- - y: np.ndarray - Hilbert transform of the input array. - - """ - n_samples = y.shape[0] - - if padding == "exp": - y = _pad_exp(y, decay_factor=decay_factor) - - y = hilbert(y, axis=0) - - if padding == "exp": - y = y[n_samples : 2 * n_samples] - - # Padding can introduce a shift in the mean of the imaginary part - # of the Hilbert transform. Correct for this shift. - y = y - y.mean(axis=0) - - return y - - -def _pad_exp(y, decay_factor=0.2): - """Pad the input array with an exponential decay function. - - The start and end of the input array are padded with an exponential decay - function falling to a reference line given by a linear fit of the data array. - - Parameters: - ------------ - y: np.ndarray - Input array. - decay_factor: float - Decay factor of the exponential function. - - Returns: - --------- - y_ext: np.ndarray - Padded array. 
- - """ - x = np.arange(y.shape[0]) - x_ext = np.arange(-x.size, 2 * x.size) - - coefs = np.polynomial.polynomial.polyfit(x, y, deg=1) - yfit = np.polynomial.polynomial.polyval(x, coefs).T - yfit_ext = np.polynomial.polynomial.polyval(x_ext, coefs).T - - y_ano = y - yfit - - amp_pre = np.take(y_ano, 0, axis=0)[:, None] - amp_pos = np.take(y_ano, -1, axis=0)[:, None] - - exp_ext = np.exp(-x / x.size / decay_factor) - exp_ext_reverse = exp_ext[::-1] - - pad_pre = amp_pre * exp_ext_reverse - pad_pos = amp_pos * exp_ext - - y_ext = np.concatenate([pad_pre.T, y_ano, pad_pos.T], axis=0) - y_ext += yfit_ext - return y_ext + return np.sqrt(np.cos(np.deg2rad(data)).clip(0, 1)) diff --git a/xeofs/validation/bootstrapper.py b/xeofs/validation/bootstrapper.py index 416c3db..8006793 100644 --- a/xeofs/validation/bootstrapper.py +++ b/xeofs/validation/bootstrapper.py @@ -7,9 +7,7 @@ from tqdm import trange from ..models import EOF -from ..data_container.eof_bootstrapper_data_container import ( - EOFBootstrapperDataContainer, -) +from ..data_container import DataContainer from ..utils.data_types import DataArray from .._version import __version__ @@ -34,6 +32,9 @@ def __init__(self, n_bootstraps=20, seed=None): } ) + # Initialize the DataContainer to store the results + self.data = DataContainer() + @abstractmethod def fit(self, model): """Bootstrap a given model.""" @@ -50,16 +51,15 @@ def __init__(self, n_bootstraps=20, seed=None): super().__init__(n_bootstraps=n_bootstraps, seed=seed) self.attrs.update({"model": "Bootstrapped EOF analysis"}) - # Initialize the DataContainer to store the results - self.data: EOFBootstrapperDataContainer = EOFBootstrapperDataContainer() - def fit(self, model: EOF): """Bootstrap a given model.""" self.model = model self.preprocessor = model.preprocessor + sample_name = model.sample_name + feature_name = model.feature_name - input_data = model.data.input_data + input_data = model.data["input_data"] n_samples = input_data.sample.size model_params = model.get_params() @@ -74,33 +74,32 @@ def fit(self, model: EOF): bst_total_variance = [] # type: ignore bst_components = [] # type: ignore bst_scores = [] # type: ignore - bst_idx_modes_sorted = [] # type: ignore for i in trange(n_bootstraps): # Sample with replacement idx_rnd = rng.choice(n_samples, n_samples, replace=True) - bst_data = input_data.isel(sample=idx_rnd) + bst_data = input_data.isel({sample_name: idx_rnd}) + # We need to assign the sample coordinates of the real data + # otherwise the transform() method will raise an error as it + # tries to align the sample coordinates + # with the coordinates of the bootstrapped (permutated) data + bst_data = bst_data.assign_coords({sample_name: input_data[sample_name]}) # Perform EOF analysis with the subsampled data # No scaling because we use the pre-scaled data from the model - bst_model = EOF( - n_modes=n_modes, standardize=False, use_coslat=False, use_weights=False - ) + bst_model = EOF(n_modes=n_modes, standardize=False, use_coslat=False) bst_model.fit(bst_data, dim="sample") # Save results - expvar = bst_model.data.explained_variance - totvar = bst_model.data.total_variance - idx_modes_sorted = bst_model.data.idx_modes_sorted - components = bst_model.data.components + expvar = bst_model.data["explained_variance"] + totvar = bst_model.data["total_variance"] + components = bst_model.data["components"] scores = bst_model.transform(input_data) bst_expvar.append(expvar) bst_total_variance.append(totvar) - bst_idx_modes_sorted.append(idx_modes_sorted) 
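
One subtle fix in xarray_utils.py above: the return of _np_sqrt_cos_lat_weights changed from np.sqrt(np.cos(...)).clip(0, 1) to np.sqrt(np.cos(...).clip(0, 1)), i.e. the cosine is now clipped before taking the square root. This presumably guards against NaNs when floating-point latitudes overshoot ±90 degrees; a minimal check:

import numpy as np

lat = np.array([90.000001])                        # tiny overshoot, cos() is slightly negative
old = np.sqrt(np.cos(np.deg2rad(lat))).clip(0, 1)  # sqrt of a negative value -> nan (plus a RuntimeWarning)
new = np.sqrt(np.cos(np.deg2rad(lat)).clip(0, 1))  # cosine clipped first -> 0.0
print(old, new)                                    # [nan] [0.]
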
bst_components.append(components) bst_scores.append(scores) # Concatenate the bootstrap results along a new dimension bst_expvar: DataArray = xr.concat(bst_expvar, dim="n") bst_total_variance: DataArray = xr.concat(bst_total_variance, dim="n") - bst_idx_modes_sorted: DataArray = xr.concat(bst_idx_modes_sorted, dim="n") bst_components: DataArray = xr.concat(bst_components, dim="n") bst_scores: DataArray = xr.concat(bst_scores, dim="n") @@ -108,14 +107,13 @@ def fit(self, model: EOF): coords_n = np.arange(1, n_bootstraps + 1) bst_expvar = bst_expvar.assign_coords(n=coords_n) bst_total_variance = bst_total_variance.assign_coords(n=coords_n) - bst_idx_modes_sorted = bst_idx_modes_sorted.assign_coords(n=coords_n) bst_components = bst_components.assign_coords(n=coords_n) bst_scores = bst_scores.assign_coords(n=coords_n) # Fix sign of individual components determined by correlation coefficients # for a given mode with all the individual bootstrap members # NOTE: we use scores as they have typically a lower dimensionality than components - model_scores = model.data.scores + model_scores = model.data["scores"] corr = ( (bst_scores * model_scores).mean("sample") / bst_scores.std("sample") @@ -125,13 +123,14 @@ def fit(self, model: EOF): bst_components = bst_components * signs bst_scores = bst_scores * signs - self.data.set_data( - input_data=self.model.data.input_data, - components=bst_components, - scores=bst_scores, - explained_variance=bst_expvar, - total_variance=bst_total_variance, - idx_modes_sorted=bst_idx_modes_sorted, + self.data.add( + name="input_data", data=model.data["input_data"], allow_compute=False ) + self.data.add(name="components", data=bst_components) + self.data.add(name="scores", data=bst_scores) + self.data.add(name="norms", data=model.data["norms"]) + self.data.add(name="explained_variance", data=bst_expvar) + self.data.add(name="total_variance", data=bst_total_variance) + # Assign the same attributes as the original model self.data.set_attrs(self.attrs)
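
Since the sign of each EOF mode is arbitrary, the bootstrapper above flips every bootstrap member so that its scores correlate positively with the scores of the fitted model before storing the results. A toy NumPy illustration of that sign-alignment idea (the variable names here are made up for the example):

import numpy as np

rng = np.random.default_rng(42)
reference = rng.standard_normal(100)                        # scores of the fitted model, one mode
member = -reference + 0.1 * rng.standard_normal(100)        # bootstrap member with a flipped sign

corr = np.corrcoef(reference, member)[0, 1]
aligned = np.sign(corr) * member                            # flip the member if it anti-correlates
print(corr < 0, np.corrcoef(reference, aligned)[0, 1] > 0)  # True True
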