Skip to content

Commit

Permalink
fix: reindexing transformed scores
Browse files Browse the repository at this point in the history
Resolved an issue where the coordinates of transformed scores were
incorrectly reindexed using the fitted data's reference frame. The indexing now correctly
utilizes the coordinates from the
transformed data itself, ensuring accurate
alignment with the original dataset (closes #98).
  • Loading branch information
nicrie committed Oct 31, 2023
1 parent 42d8e75 commit a8e458e
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 59 deletions.
29 changes: 23 additions & 6 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,16 @@ def test_get_params():
)
def test_transform(dim, mock_data_array):
"""Test projecting new unseen data onto the components (EOFs/eigenvectors)"""
data = mock_data_array.isel({dim[0]: slice(0, 3)})
new_data = mock_data_array.isel({dim[0]: slice(4, 5)})

# Create a xarray DataArray with random data
model = EOF(n_modes=5, solver="full")
model.fit(mock_data_array, dim)
model = EOF(n_modes=2, solver="full")
model.fit(data, dim)
scores = model.scores()

# Create a new xarray DataArray with random data
new_data = mock_data_array

projections = model.transform(new_data)
# Project data onto the components
projections = model.transform(data)

# Check that the projection has the right dimensions
assert projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore
Expand All @@ -390,6 +390,23 @@ def test_transform(dim, mock_data_array):
scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3
)

# Project unseen data onto the components
new_projections = model.transform(new_data)

# Check that the projection has the right dimensions
assert new_projections.dims == scores.dims, "Projection has wrong dimensions" # type: ignore

# Check that the projection has the right data type
assert isinstance(new_projections, xr.DataArray), "Projection is not a DataArray"

# Check that the projection has the right name
assert new_projections.name == "scores", "Projection has wrong name: {}".format(
new_projections.name
)

# Ensure that the new projections are not NaNs
assert np.all(new_projections.notnull().values), "New projections contain NaNs"


@pytest.mark.parametrize(
"dim",
Expand Down
2 changes: 1 addition & 1 deletion xeofs/models/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def transform(self, data: List[Data] | Data, normalized=True) -> DataArray:
if normalized:
data2D = data2D / self.data["norms"]
data2D.name = "scores"
return self.preprocessor.inverse_transform_scores(data2D)
return self.preprocessor.inverse_transform_scores_unseen(data2D)

@abstractmethod
def _transform_algorithm(self, data: DataArray) -> DataArray:
Expand Down
4 changes: 2 additions & 2 deletions xeofs/models/gwpca.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _fit_algorithm(self, X: DataArray) -> Self:
valid_y_names = VALID_CARTESIAN_Y_NAMES + VALID_LATITUDE_NAMES
n_sample_dims = len(self.sample_dims)
if n_sample_dims == 1:
indexes = self.preprocessor.preconverter.transformers[0].original_indexes
indexes = self.preprocessor.preconverter.transformers[0].coords_from_fit
sample_dims = self.preprocessor.renamer.transformers[0].sample_dims_after
xy = None
for dim in sample_dims:
Expand All @@ -158,7 +158,7 @@ def _fit_algorithm(self, X: DataArray) -> Self:
if xy is None:
raise ValueError("Cannot find sample coordinates.")
elif n_sample_dims == 2:
indexes = self.preprocessor.postconverter.transformers[0].original_indexes
indexes = self.preprocessor.postconverter.transformers[0].coords_from_fit
xy = np.asarray([*indexes[self.sample_name].values])

else:
Expand Down
4 changes: 4 additions & 0 deletions xeofs/preprocessing/concatenator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,7 @@ def inverse_transform_components(self, X: DataArray) -> List[DataArray]:
def inverse_transform_scores(self, X: DataArray) -> DataArray:
"""Reshape the 2D scores (sample x mode) back into its original shape."""
return X

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
"""Reshape the 2D scores (sample x mode) back into its original shape."""
return X
3 changes: 3 additions & 0 deletions xeofs/preprocessing/dimension_renamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ def inverse_transform_components(self, X: DataVarBound) -> DataVarBound:

def inverse_transform_scores(self, X: DataArray) -> DataArray:
return self._inverse_transform(X)

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
return self._inverse_transform(X)
3 changes: 3 additions & 0 deletions xeofs/preprocessing/list_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,6 @@ def inverse_transform_components(self, X: List[Data]) -> List[Data]:

def inverse_transform_scores(self, X: DataArray) -> DataArray:
return self.transformers[0].inverse_transform_scores(X)

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
return self.transformers[0].inverse_transform_scores_unseen(X)
41 changes: 25 additions & 16 deletions xeofs/preprocessing/multi_index_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@ class MultiIndexConverter(Transformer):

def __init__(self):
super().__init__()
self.original_indexes = {}
self.modified_dimensions = []
self.coords_from_fit = {}
self.coords_from_transform = {}

def fit(
self,
X: Data,
sample_dims: Optional[Dims] = None,
feature_dims: Optional[Dims] = None,
**kwargs
**kwargs,
) -> Self:
# Store original MultiIndexes and replace with simple index
# Store original MultiIndexes
for dim in X.dims:
index = X.indexes[dim]
if isinstance(index, pd.MultiIndex):
self.original_indexes[dim] = X.coords[dim]
self.coords_from_fit[dim] = X.coords[dim]
self.modified_dimensions.append(dim)

return self
Expand All @@ -35,37 +36,45 @@ def transform(self, X: DataVar) -> DataVar:

# Replace MultiIndexes with simple index
for dim in self.modified_dimensions:
size = X_transformed.coords[dim].size
# We need to store the indexes from "unseen" data
self.coords_from_transform[dim] = X_transformed.coords[dim]

index = X_transformed.indexes[dim]
X_transformed = X_transformed.drop_vars(dim)
X_transformed.coords[dim] = range(size)
X_transformed.coords[dim] = range(index.size)

return X_transformed

def _inverse_transform(self, X: DataVarBound) -> DataVarBound:
def _inverse_transform(self, X: DataVarBound, reference: str) -> DataVarBound:
X_inverse_transformed = X.copy(deep=True)

match reference:
case "fit":
reference_indexes = self.coords_from_fit
case "transform":
reference_indexes = self.coords_from_transform

# Restore original MultiIndexes
for dim, original_index in self.original_indexes.items():
for dim, original_index in reference_indexes.items():
if dim in X_inverse_transformed.dims:
X_inverse_transformed.coords[dim] = original_index
# Set indexes to original MultiIndexes
indexes = [
idx
for idx in self.original_indexes[dim].indexes.keys()
if idx != dim
]
indexes = [idx for idx in original_index.indexes.keys() if idx != dim]
X_inverse_transformed = X_inverse_transformed.set_index({dim: indexes})

return X_inverse_transformed

def inverse_transform_data(self, X: DataVarBound) -> DataVarBound:
return self._inverse_transform(X)
return self._inverse_transform(X, reference="fit")

def inverse_transform_components(self, X: DataVarBound) -> DataVarBound:
return self._inverse_transform(X)
return self._inverse_transform(X, reference="fit")

def inverse_transform_scores(self, X: DataArray) -> DataArray:
return self._inverse_transform(X)
return self._inverse_transform(X, reference="fit")

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
return self._inverse_transform(X, reference="transform")


# class DataListMultiIndexConverter(BaseEstimator, TransformerMixin):
Expand Down
45 changes: 36 additions & 9 deletions xeofs/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,36 @@ def inverse_transform_components(self, X: DataArray) -> List[Data] | Data:
def inverse_transform_scores(self, X: DataArray) -> DataArray:
"""Inverse transform the scores.
This should be used for scores obtained from the fitted data.
Parameters:
-------------
data: xr.DataArray or list of xarray.DataArray
X: xr.DataArray
Input data.
Returns
-------
xr.DataArray
The inverse transformed scores.
"""
X = self.concatenator.inverse_transform_scores(X)
X = self.sanitizer.inverse_transform_scores(X)
X = self.postconverter.inverse_transform_scores(X)
X = self.stacker.inverse_transform_scores(X)
X = self.preconverter.inverse_transform_scores(X)
X = self.renamer.inverse_transform_scores(X)
X = self.scaler.inverse_transform_scores(X)
return X

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
"""Inverse transform the scores.
This should be used for scores obtained from new data.
Parameters:
-------------
X: xr.DataArray
Input data.
Returns
Expand All @@ -252,14 +279,14 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray:
The inverse transformed scores.
"""
X_list = self.concatenator.inverse_transform_scores(X)
X_list = self.sanitizer.inverse_transform_scores(X_list)
X_list = self.postconverter.inverse_transform_scores(X_list)
X_list_ND = self.stacker.inverse_transform_scores(X_list)
X_list_ND = self.preconverter.inverse_transform_scores(X_list_ND)
X_list_ND = self.renamer.inverse_transform_scores(X_list_ND)
X_list_ND = self.scaler.inverse_transform_scores(X_list_ND)
return X_list_ND
X = self.concatenator.inverse_transform_scores_unseen(X)
X = self.sanitizer.inverse_transform_scores_unseen(X)
X = self.postconverter.inverse_transform_scores_unseen(X)
X = self.stacker.inverse_transform_scores_unseen(X)
X = self.preconverter.inverse_transform_scores_unseen(X)
X = self.renamer.inverse_transform_scores_unseen(X)
X = self.scaler.inverse_transform_scores_unseen(X)
return X

def _process_output(self, X: List[Data]) -> List[Data] | Data:
if self.return_list:
Expand Down
15 changes: 15 additions & 0 deletions xeofs/preprocessing/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ def transform(self, X: DataArray) -> DataArray:
# Check if input has the correct coordinates
self._check_input_coords(X)

# Store sample coordinates for inverse transform
self.sample_coords_transform = X.coords[self.sample_name]
# Remove NaN entries; only consider full-dimensional NaNs
# We already know valid features from the fitted dataset
X = X.isel({self.feature_name: self.is_valid_feature})
# However, we need to recheck for valid samples, as the new dataset may
# have different samples
is_valid_sample = ~X.isnull().all(self.feature_name).compute()
X = X.isel({self.sample_name: is_valid_sample})
# Store valid sample locations for inverse transform
self.is_valid_sample_transform = is_valid_sample

return X

Expand Down Expand Up @@ -110,3 +114,14 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray:
return X
else:
return X.reindex({self.sample_name: self.sample_coords.values})

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
# Reindex only if sample coordinates are different
coords_are_equal = X.coords[self.sample_name].identical(
self.sample_coords_transform
)

if coords_are_equal:
return X
else:
return X.reindex({self.sample_name: self.sample_coords_transform.values})
3 changes: 3 additions & 0 deletions xeofs/preprocessing/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def inverse_transform_components(self, X: DataVarBound) -> DataVarBound:
def inverse_transform_scores(self, X: DataArray) -> DataArray:
return X

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
return X


# class DataListScaler(Scaler):
# """Scale a list of xr.DataArray along sample dimensions.
Expand Down
52 changes: 28 additions & 24 deletions xeofs/preprocessing/stacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,6 @@ def _stack(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> DataArray:
The reshaped 2d-data.
"""

@abstractmethod
def _unstack(self, X: DataArray) -> Data:
"""Unstack 2D DataArray to its original dimensions.
Parameters
----------
data : DataArray
The data to be unstacked.
Returns
-------
data_unstacked : DataArray
The unstacked data.
"""

def _reorder_dims(self, X: DataVarBound) -> DataVarBound:
"""Reorder dimensions to original order; catch ('mode') dimensions via ellipsis"""
order_input_dims = [
Expand Down Expand Up @@ -219,6 +204,7 @@ def fit_transform(
) -> DataArray:
return self.fit(X, sample_dims, feature_dims).transform(X)

@abstractmethod
def inverse_transform_data(self, X: DataArray) -> Data:
"""Reshape the 2D data (sample x feature) back into its original dimensions.
Expand All @@ -233,10 +219,8 @@ def inverse_transform_data(self, X: DataArray) -> Data:
The reshaped data.
"""
Xnd = self._unstack(X)
Xnd = self._reorder_dims(Xnd)
return Xnd

@abstractmethod
def inverse_transform_components(self, X: DataArray) -> Data:
"""Reshape the 2D components (sample x feature) back into its original dimensions.
Expand All @@ -251,10 +235,8 @@ def inverse_transform_components(self, X: DataArray) -> Data:
The reshaped data.
"""
Xnd = self._unstack(X)
Xnd = self._reorder_dims(Xnd)
return Xnd

@abstractmethod
def inverse_transform_scores(self, data: DataArray) -> DataArray:
"""Reshape the 2D scores (sample x feature) back into its original dimensions.
Expand All @@ -269,9 +251,10 @@ def inverse_transform_scores(self, data: DataArray) -> DataArray:
The reshaped data.
"""
data = self._unstack(data) # type: ignore
data = self._reorder_dims(data)
return data

@abstractmethod
def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray:
pass


class DataArrayStacker(Stacker):
Expand Down Expand Up @@ -372,6 +355,24 @@ def _unstack(self, data: DataArray) -> DataArray:

return data

def inverse_transform_data(self, X: DataArray) -> Data:
Xnd = self._unstack(X)
Xnd = self._reorder_dims(Xnd)
return Xnd

def inverse_transform_components(self, X: DataArray) -> Data:
Xnd = self._unstack(X)
Xnd = self._reorder_dims(Xnd)
return Xnd

def inverse_transform_scores(self, data: DataArray) -> DataArray:
data = self._unstack(data) # type: ignore
data = self._reorder_dims(data)
return data

def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray:
return self.inverse_transform_scores(data)


class DataSetStacker(Stacker):
"""Converts a Dataset of any dimensionality into a 2D structure."""
Expand Down Expand Up @@ -494,6 +495,9 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray:
X = self._unstack_scores(X)
return X

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
return self.inverse_transform_scores(X)


class StackerFactory:
"""Factory class for creating stackers."""
Expand Down
Loading

0 comments on commit a8e458e

Please sign in to comment.