From a8e458e21a5b30c323b93263ff6e7dd687cb391c Mon Sep 17 00:00:00 2001 From: Niclas Rieger Date: Tue, 31 Oct 2023 20:41:16 +0100 Subject: [PATCH] fix: reindexing transformed scores Resolved an issue where the coordinates of transformed scores were incorrectly reindexed using the fitted data's reference frame. The indexing now correctly utilizes the coordinates from the tansformed data itself, ensuring accurate alignment with the original dataset (closes #98). --- tests/models/test_eof.py | 29 ++++++++--- xeofs/models/_base_model.py | 2 +- xeofs/models/gwpca.py | 4 +- xeofs/preprocessing/concatenator.py | 4 ++ xeofs/preprocessing/dimension_renamer.py | 3 ++ xeofs/preprocessing/list_processor.py | 3 ++ xeofs/preprocessing/multi_index_converter.py | 41 +++++++++------ xeofs/preprocessing/preprocessor.py | 45 +++++++++++++---- xeofs/preprocessing/sanitizer.py | 15 ++++++ xeofs/preprocessing/scaler.py | 3 ++ xeofs/preprocessing/stacker.py | 52 +++++++++++--------- xeofs/preprocessing/transformer.py | 7 ++- 12 files changed, 149 insertions(+), 59 deletions(-) diff --git a/tests/models/test_eof.py b/tests/models/test_eof.py index 16d4ef2..37c49f6 100644 --- a/tests/models/test_eof.py +++ b/tests/models/test_eof.py @@ -363,16 +363,16 @@ def test_get_params(): ) def test_transform(dim, mock_data_array): """Test projecting new unseen data onto the components (EOFs/eigenvectors)""" + data = mock_data_array.isel({dim[0]: slice(0, 3)}) + new_data = mock_data_array.isel({dim[0]: slice(4, 5)}) # Create a xarray DataArray with random data - model = EOF(n_modes=5, solver="full") - model.fit(mock_data_array, dim) + model = EOF(n_modes=2, solver="full") + model.fit(data, dim) scores = model.scores() - # Create a new xarray DataArray with random data - new_data = mock_data_array - - projections = model.transform(new_data) + # Project data onto the components + projections = model.transform(data) # Check that the projection has the right dimensions assert projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore @@ -390,6 +390,23 @@ def test_transform(dim, mock_data_array): scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3 ) + # Project unseen data onto the components + new_projections = model.transform(new_data) + + # Check that the projection has the right dimensions + assert new_projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore + + # Check that the projection has the right data type + assert isinstance(new_projections, xr.DataArray), "Projection is not a DataArray" + + # Check that the projection has the right name + assert new_projections.name == "scores", "Projection has wrong name: {}".format( + new_projections.name + ) + + # Ensure that the new projections are not NaNs + assert np.all(new_projections.notnull().values), "New projections contain NaNs" + @pytest.mark.parametrize( "dim", diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index 5188cf0..e932d10 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -197,7 +197,7 @@ def transform(self, data: List[Data] | Data, normalized=True) -> DataArray: if normalized: data2D = data2D / self.data["norms"] data2D.name = "scores" - return self.preprocessor.inverse_transform_scores(data2D) + return self.preprocessor.inverse_transform_scores_unseen(data2D) @abstractmethod def _transform_algorithm(self, data: DataArray) -> DataArray: diff --git a/xeofs/models/gwpca.py b/xeofs/models/gwpca.py index 4d317c6..0b0f547 100644 --- a/xeofs/models/gwpca.py +++ b/xeofs/models/gwpca.py @@ -145,7 +145,7 @@ def _fit_algorithm(self, X: DataArray) -> Self: valid_y_names = VALID_CARTESIAN_Y_NAMES + VALID_LATITUDE_NAMES n_sample_dims = len(self.sample_dims) if n_sample_dims == 1: - indexes = self.preprocessor.preconverter.transformers[0].original_indexes + indexes = self.preprocessor.preconverter.transformers[0].coords_from_fit sample_dims = self.preprocessor.renamer.transformers[0].sample_dims_after xy = None for dim in sample_dims: @@ -158,7 +158,7 @@ def _fit_algorithm(self, X: DataArray) -> Self: if xy is None: raise ValueError("Cannot find sample coordinates.") elif n_sample_dims == 2: - indexes = self.preprocessor.postconverter.transformers[0].original_indexes + indexes = self.preprocessor.postconverter.transformers[0].coords_from_fit xy = np.asarray([*indexes[self.sample_name].values]) else: diff --git a/xeofs/preprocessing/concatenator.py b/xeofs/preprocessing/concatenator.py index 1eaf893..7110de2 100644 --- a/xeofs/preprocessing/concatenator.py +++ b/xeofs/preprocessing/concatenator.py @@ -116,3 +116,7 @@ def inverse_transform_components(self, X: DataArray) -> List[DataArray]: def inverse_transform_scores(self, X: DataArray) -> DataArray: """Reshape the 2D scores (sample x mode) back into its original shape.""" return X + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + """Reshape the 2D scores (sample x mode) back into its original shape.""" + return X diff --git a/xeofs/preprocessing/dimension_renamer.py b/xeofs/preprocessing/dimension_renamer.py index 7824173..a032034 100644 --- a/xeofs/preprocessing/dimension_renamer.py +++ b/xeofs/preprocessing/dimension_renamer.py @@ -59,3 +59,6 @@ def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: def inverse_transform_scores(self, X: DataArray) -> DataArray: return self._inverse_transform(X) + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return self._inverse_transform(X) diff --git a/xeofs/preprocessing/list_processor.py b/xeofs/preprocessing/list_processor.py index d7e910b..0dd12e9 100644 --- a/xeofs/preprocessing/list_processor.py +++ b/xeofs/preprocessing/list_processor.py @@ -103,3 +103,6 @@ def inverse_transform_components(self, X: List[Data]) -> List[Data]: def inverse_transform_scores(self, X: DataArray) -> DataArray: return self.transformers[0].inverse_transform_scores(X) + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return self.transformers[0].inverse_transform_scores_unseen(X) diff --git a/xeofs/preprocessing/multi_index_converter.py b/xeofs/preprocessing/multi_index_converter.py index 74c497a..6989a44 100644 --- a/xeofs/preprocessing/multi_index_converter.py +++ b/xeofs/preprocessing/multi_index_converter.py @@ -11,21 +11,22 @@ class MultiIndexConverter(Transformer): def __init__(self): super().__init__() - self.original_indexes = {} self.modified_dimensions = [] + self.coords_from_fit = {} + self.coords_from_transform = {} def fit( self, X: Data, sample_dims: Optional[Dims] = None, feature_dims: Optional[Dims] = None, - **kwargs + **kwargs, ) -> Self: - # Store original MultiIndexes and replace with simple index + # Store original MultiIndexes for dim in X.dims: index = X.indexes[dim] if isinstance(index, pd.MultiIndex): - self.original_indexes[dim] = X.coords[dim] + self.coords_from_fit[dim] = X.coords[dim] self.modified_dimensions.append(dim) return self @@ -35,37 +36,45 @@ def transform(self, X: DataVar) -> DataVar: # Replace MultiIndexes with simple index for dim in self.modified_dimensions: - size = X_transformed.coords[dim].size + # We need to store the indexes from "unseen" data + self.coords_from_transform[dim] = X_transformed.coords[dim] + + index = X_transformed.indexes[dim] X_transformed = X_transformed.drop_vars(dim) - X_transformed.coords[dim] = range(size) + X_transformed.coords[dim] = range(index.size) return X_transformed - def _inverse_transform(self, X: DataVarBound) -> DataVarBound: + def _inverse_transform(self, X: DataVarBound, reference: str) -> DataVarBound: X_inverse_transformed = X.copy(deep=True) + match reference: + case "fit": + reference_indexes = self.coords_from_fit + case "transform": + reference_indexes = self.coords_from_transform + # Restore original MultiIndexes - for dim, original_index in self.original_indexes.items(): + for dim, original_index in reference_indexes.items(): if dim in X_inverse_transformed.dims: X_inverse_transformed.coords[dim] = original_index # Set indexes to original MultiIndexes - indexes = [ - idx - for idx in self.original_indexes[dim].indexes.keys() - if idx != dim - ] + indexes = [idx for idx in original_index.indexes.keys() if idx != dim] X_inverse_transformed = X_inverse_transformed.set_index({dim: indexes}) return X_inverse_transformed def inverse_transform_data(self, X: DataVarBound) -> DataVarBound: - return self._inverse_transform(X) + return self._inverse_transform(X, reference="fit") def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: - return self._inverse_transform(X) + return self._inverse_transform(X, reference="fit") def inverse_transform_scores(self, X: DataArray) -> DataArray: - return self._inverse_transform(X) + return self._inverse_transform(X, reference="fit") + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return self._inverse_transform(X, reference="transform") # class DataListMultiIndexConverter(BaseEstimator, TransformerMixin): diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index 8da82bc..704d655 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -241,9 +241,36 @@ def inverse_transform_components(self, X: DataArray) -> List[Data] | Data: def inverse_transform_scores(self, X: DataArray) -> DataArray: """Inverse transform the scores. + This should be used for scores obtained from the fitted data. + Parameters: ------------- - data: xr.DataArray or list of xarray.DataArray + X: xr.DataArray + Input data. + + Returns + ------- + xr.DataArray + The inverse transformed scores. + + """ + X = self.concatenator.inverse_transform_scores(X) + X = self.sanitizer.inverse_transform_scores(X) + X = self.postconverter.inverse_transform_scores(X) + X = self.stacker.inverse_transform_scores(X) + X = self.preconverter.inverse_transform_scores(X) + X = self.renamer.inverse_transform_scores(X) + X = self.scaler.inverse_transform_scores(X) + return X + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + """Inverse transform the scores. + + This should be used for scores obtained from new data. + + Parameters: + ------------- + X: xr.DataArray Input data. Returns @@ -252,14 +279,14 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: The inverse transformed scores. """ - X_list = self.concatenator.inverse_transform_scores(X) - X_list = self.sanitizer.inverse_transform_scores(X_list) - X_list = self.postconverter.inverse_transform_scores(X_list) - X_list_ND = self.stacker.inverse_transform_scores(X_list) - X_list_ND = self.preconverter.inverse_transform_scores(X_list_ND) - X_list_ND = self.renamer.inverse_transform_scores(X_list_ND) - X_list_ND = self.scaler.inverse_transform_scores(X_list_ND) - return X_list_ND + X = self.concatenator.inverse_transform_scores_unseen(X) + X = self.sanitizer.inverse_transform_scores_unseen(X) + X = self.postconverter.inverse_transform_scores_unseen(X) + X = self.stacker.inverse_transform_scores_unseen(X) + X = self.preconverter.inverse_transform_scores_unseen(X) + X = self.renamer.inverse_transform_scores_unseen(X) + X = self.scaler.inverse_transform_scores_unseen(X) + return X def _process_output(self, X: List[Data]) -> List[Data] | Data: if self.return_list: diff --git a/xeofs/preprocessing/sanitizer.py b/xeofs/preprocessing/sanitizer.py index f9e05a3..c0a2181 100644 --- a/xeofs/preprocessing/sanitizer.py +++ b/xeofs/preprocessing/sanitizer.py @@ -74,6 +74,8 @@ def transform(self, X: DataArray) -> DataArray: # Check if input has the correct coordinates self._check_input_coords(X) + # Store sample coordinates for inverse transform + self.sample_coords_transform = X.coords[self.sample_name] # Remove NaN entries; only consider full-dimensional NaNs # We already know valid features from the fitted dataset X = X.isel({self.feature_name: self.is_valid_feature}) @@ -81,6 +83,8 @@ def transform(self, X: DataArray) -> DataArray: # have different samples is_valid_sample = ~X.isnull().all(self.feature_name).compute() X = X.isel({self.sample_name: is_valid_sample}) + # Store valid sample locations for inverse transform + self.is_valid_sample_transform = is_valid_sample return X @@ -110,3 +114,14 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: return X else: return X.reindex({self.sample_name: self.sample_coords.values}) + + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + # Reindex only if sample coordinates are different + coords_are_equal = X.coords[self.sample_name].identical( + self.sample_coords_transform + ) + + if coords_are_equal: + return X + else: + return X.reindex({self.sample_name: self.sample_coords_transform.values}) diff --git a/xeofs/preprocessing/scaler.py b/xeofs/preprocessing/scaler.py index c7eb920..be4fed1 100644 --- a/xeofs/preprocessing/scaler.py +++ b/xeofs/preprocessing/scaler.py @@ -171,6 +171,9 @@ def inverse_transform_components(self, X: DataVarBound) -> DataVarBound: def inverse_transform_scores(self, X: DataArray) -> DataArray: return X + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return X + # class DataListScaler(Scaler): # """Scale a list of xr.DataArray along sample dimensions. diff --git a/xeofs/preprocessing/stacker.py b/xeofs/preprocessing/stacker.py index 3f599b5..31ea22e 100644 --- a/xeofs/preprocessing/stacker.py +++ b/xeofs/preprocessing/stacker.py @@ -114,21 +114,6 @@ def _stack(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> DataArray: The reshaped 2d-data. """ - @abstractmethod - def _unstack(self, X: DataArray) -> Data: - """Unstack 2D DataArray to its original dimensions. - - Parameters - ---------- - data : DataArray - The data to be unstacked. - - Returns - ------- - data_unstacked : DataArray - The unstacked data. - """ - def _reorder_dims(self, X: DataVarBound) -> DataVarBound: """Reorder dimensions to original order; catch ('mode') dimensions via ellipsis""" order_input_dims = [ @@ -219,6 +204,7 @@ def fit_transform( ) -> DataArray: return self.fit(X, sample_dims, feature_dims).transform(X) + @abstractmethod def inverse_transform_data(self, X: DataArray) -> Data: """Reshape the 2D data (sample x feature) back into its original dimensions. @@ -233,10 +219,8 @@ def inverse_transform_data(self, X: DataArray) -> Data: The reshaped data. """ - Xnd = self._unstack(X) - Xnd = self._reorder_dims(Xnd) - return Xnd + @abstractmethod def inverse_transform_components(self, X: DataArray) -> Data: """Reshape the 2D components (sample x feature) back into its original dimensions. @@ -251,10 +235,8 @@ def inverse_transform_components(self, X: DataArray) -> Data: The reshaped data. """ - Xnd = self._unstack(X) - Xnd = self._reorder_dims(Xnd) - return Xnd + @abstractmethod def inverse_transform_scores(self, data: DataArray) -> DataArray: """Reshape the 2D scores (sample x feature) back into its original dimensions. @@ -269,9 +251,10 @@ def inverse_transform_scores(self, data: DataArray) -> DataArray: The reshaped data. """ - data = self._unstack(data) # type: ignore - data = self._reorder_dims(data) - return data + + @abstractmethod + def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray: + pass class DataArrayStacker(Stacker): @@ -372,6 +355,24 @@ def _unstack(self, data: DataArray) -> DataArray: return data + def inverse_transform_data(self, X: DataArray) -> Data: + Xnd = self._unstack(X) + Xnd = self._reorder_dims(Xnd) + return Xnd + + def inverse_transform_components(self, X: DataArray) -> Data: + Xnd = self._unstack(X) + Xnd = self._reorder_dims(Xnd) + return Xnd + + def inverse_transform_scores(self, data: DataArray) -> DataArray: + data = self._unstack(data) # type: ignore + data = self._reorder_dims(data) + return data + + def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray: + return self.inverse_transform_scores(data) + class DataSetStacker(Stacker): """Converts a Dataset of any dimensionality into a 2D structure.""" @@ -494,6 +495,9 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: X = self._unstack_scores(X) return X + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return self.inverse_transform_scores(X) + class StackerFactory: """Factory class for creating stackers.""" diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py index 26e33d5..5bd9f1a 100644 --- a/xeofs/preprocessing/transformer.py +++ b/xeofs/preprocessing/transformer.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from typing_extensions import Self from abc import abstractmethod @@ -7,7 +8,7 @@ from ..utils.data_types import Dims, DataVar, DataArray, DataSet, Data, DataVarBound -class Transformer(BaseEstimator, TransformerMixin): +class Transformer(BaseEstimator, TransformerMixin, ABC): """ Abstract base class to transform an xarray DataArray/Dataset. @@ -66,3 +67,7 @@ def inverse_transform_components(self, X: Data) -> Data: @abstractmethod def inverse_transform_scores(self, X: DataArray) -> DataArray: return X + + @abstractmethod + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + return X