Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Started basis examples #253

Merged
merged 9 commits into from
Oct 31, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 245 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ def fit(self, X: FeatureMatrix, y=None):
-------
self :
The transformer object.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

# Example input
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X, y = np.random.normal(size=(100, 2)), np.random.uniform(size=100)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

# Define and fit tranformation basis
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X) # input must be a 2d array
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand All @@ -223,6 +236,28 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
-------
:
The data transformed by the basis functions.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

>>> # Example input
>>> X, y = np.random.normal(size=(10000, 2)), np.random.uniform(size=100)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_

>>> transformer_fitted = transformer.fit(X) # input must be a 2d array
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)

>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1]) # input must be a 2d array, (num_samples, 1)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
"""
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
Expand All @@ -248,6 +283,21 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
array-like
The data transformed by the basis functions, after fitting the basis
functions to the data.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

>>> # Example input
>>> X, y = np.random.normal(size=(100, 1)), np.random.uniform(size=100)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)

>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X) # input must be a 2d array, (num_samples, 1)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
"""
return self._basis.compute_features(*self._unpack_inputs(X))

Expand Down Expand Up @@ -705,6 +755,19 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
input samples with the basis functions. The output shape varies based on
the subclass and mode.

Examples
-------
>>> import numpy as np
>>> from nemos.basis import BSplineBasis

>>> # Generate data
>>> num_samples = 10000
>>> X = np.random.normal(size=(num_samples, )) # raw time series
>>> basis = BSplineBasis(10)
>>> features = basis.compute_features(X) # basis transformed time series
>>> features.shape
(10000, 10)

Notes
-----
Subclasses should implement how to handle the transformation specific to their
Expand Down Expand Up @@ -882,6 +945,23 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
This differs from the numpy.meshgrid default, which uses Cartesian indexing.
For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$.

Examples
--------
>>> # Evaluate and visualize 4 M-spline basis functions of order 3:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import MSplineBasis
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> for i in range(4):
... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}')
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> plt.title('M-Spline Basis Functions')
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
Text(0.5, 1.0, 'M-Spline Basis Functions')
>>> plt.xlabel('Domain')
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
Text(0.5, 0, 'Domain')
>>> plt.ylabel('Basis Function Value')
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
Text(0, 0.5, 'Basis Function Value')
>>> l = plt.legend()
"""
self._check_input_dimensionality(n_samples)

Expand Down Expand Up @@ -1071,7 +1151,29 @@ class AdditiveBasis(Basis):
n_basis_funcs : int
Number of basis functions.


Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X, y = np.random.normal(size=(30, 2)), np.random.poisson(size=30)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> # X.shape is (n_samples, n_inputs), where n_inputs is the number required by the basis
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = nmo.basis.AdditiveBasis(basis1=basis_1, basis2=basis_2)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> transformed_X = additive_basis.to_transformer().transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> print(transformed_X.shape)
(30, 25)

>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
>>> transformed_X = additive_basis_2.to_transformer().transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> print(transformed_X.shape)
(30, 125)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1183,6 +1285,28 @@ class MultiplicativeBasis(Basis):
n_basis_funcs : int
Number of basis functions.

Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(30, 3)), np.random.poisson(size=30)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = nmo.basis.MultiplicativeBasis(basis1=basis_1, basis2=basis_2)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> transformed_X = multiplicative_basis.to_transformer().transform(X[:, 0:2])
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> print(transformed_X.shape)
(30, 150)

>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
>>> transformed_X = multiplicative_basis_2.to_transformer().transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> print(transformed_X.shape)
(30, 15000)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1298,7 +1422,6 @@ class SplineBasis(Basis, abc.ABC):
----------
order : int
Spline order.

"""

def __init__(
Expand Down Expand Up @@ -1614,6 +1737,16 @@ class BSplineBasis(SplineBasis):
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import BSplineBasis

>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> bspline_transformer = bspline_basis.to_transformer()
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved

>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1693,6 +1826,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import BSplineBasis
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1728,6 +1869,19 @@ class CyclicBSplineBasis(SplineBasis):
Number of basis functions.
order : int
Order of the splines used in basis functions.

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import CyclicBSplineBasis
>>> X = np.random.normal(size=(1000, 1))

>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
>>> X_transformed = cyclic_basis.to_transformer().fit_transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X_transformed.shape
(1000, 5)
"""

def __init__(
Expand Down Expand Up @@ -1835,6 +1989,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import CyclicBSplineBasis
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1864,6 +2026,19 @@ class RaisedCosineBasisLinear(Basis):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLinear
>>> X = np.random.normal(size=(1000, 1))

>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
>>> X_transformed = cosine_basis.to_transformer().fit_transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X_transformed.shape
(1000, 5)

# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2003,6 +2178,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
basis_funcs :
Raised cosine basis functions, shape (n_samples, n_basis_funcs)

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import RaisedCosineBasisLinear
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -2057,6 +2239,19 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLog
>>> X = np.random.normal(size=(1000, 1))

>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
>>> X_transformed = cosine_basis.to_transformer().fit_transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X_transformed.shape
(1000, 5)

# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2210,6 +2405,21 @@ class OrthExponentialBasis(Basis):
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import OrthExponentialBasis
>>> X = np.random.normal(size=(1000, 1))
>>> n_basis_funcs = 5
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = ortho_basis(sample_points)
>>> X_transformed = ortho_basis.to_transformer().fit_transform(X)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X_transformed.shape
(1000, 5)
"""

def __init__(
Expand Down Expand Up @@ -2365,6 +2575,16 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Evaluated exponentially decaying basis functions, numerically
orthogonalized, shape (n_samples, n_basis_funcs)

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import OrthExponentialBasis
>>> n_basis_funcs = 5
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand All @@ -2387,6 +2607,17 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray:
-------
spline
M-spline basis function, shape (n_sample_points, ).

Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import mspline

>>> sample_points = linspace(0, 1, 100)
>>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline
>>> mspline_eval.shape
(100,)
"""
# Boundary conditions.
if (T[i + k] - T[i]) < 1e-6:
Expand Down Expand Up @@ -2453,6 +2684,18 @@ def bspline(
Notes
-----
The function uses splev function from scipy.interpolate library for the basis evaluation.

Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import bspline

>>> sample_points = linspace(0, 1, 100)
>>> knots = knots = BSplineBasis(10)._generate_knots(sample_points)
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> bspline_eval.shape
(100, 10)
"""
knots.sort()
nk = knots.shape[0]
Expand Down
Loading