Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Started basis examples #253

Merged
merged 9 commits into from
Oct 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 214 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ def fit(self, X: FeatureMatrix, y=None):
-------
self :
The transformer object.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

>>> # Example input
>>> X = np.random.normal(size=(100, 2))

>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X)
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand All @@ -223,6 +236,28 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
-------
:
The data transformed by the basis functions.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

>>> # Example input
>>> X = np.random.normal(size=(10000, 2))

>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_

>>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)

>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1])
"""
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
Expand All @@ -248,6 +283,21 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
array-like
The data transformed by the basis functions, after fitting the basis
functions to the data.

Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis

>>> # Example input
>>> X = np.random.normal(size=(100, 1))

>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)

>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X)
"""
return self._basis.compute_features(*self._unpack_inputs(X))

Expand Down Expand Up @@ -705,6 +755,19 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
input samples with the basis functions. The output shape varies based on
the subclass and mode.

Examples
-------
>>> import numpy as np
>>> from nemos.basis import BSplineBasis

>>> # Generate data
>>> num_samples = 10000
>>> X = np.random.normal(size=(num_samples, )) # raw time series
>>> basis = BSplineBasis(10)
>>> features = basis.compute_features(X) # basis transformed time series
>>> features.shape
(10000, 10)

Notes
-----
Subclasses should implement how to handle the transformation specific to their
Expand Down Expand Up @@ -882,6 +945,19 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
This differs from the numpy.meshgrid default, which uses Cartesian indexing.
For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$.

Examples
--------
>>> # Evaluate and visualize 4 M-spline basis functions of order 3:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import MSplineBasis
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> p = plt.plot(sample_points, basis_values)
>>> _ = plt.title('M-Spline Basis Functions')
>>> _ = plt.xlabel('Domain')
>>> _ = plt.ylabel('Basis Function Value')
>>> _ = plt.legend([f'Function {i+1}' for i in range(4)]);
"""
self._check_input_dimensionality(n_samples)

Expand Down Expand Up @@ -1071,7 +1147,22 @@ class AdditiveBasis(Basis):
n_basis_funcs : int
Number of basis functions.


Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> X = np.random.normal(size=(30, 2))

>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = basis_1 + basis_2

>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1183,6 +1274,22 @@ class MultiplicativeBasis(Basis):
n_basis_funcs : int
Number of basis functions.

Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 3))

>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = basis_1 * basis_2

>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1298,7 +1405,6 @@ class SplineBasis(Basis, abc.ABC):
----------
order : int
Spline order.

"""

def __init__(
Expand Down Expand Up @@ -1614,6 +1720,14 @@ class BSplineBasis(SplineBasis):
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import BSplineBasis

>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1693,6 +1807,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import BSplineBasis
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1728,6 +1850,16 @@ class CyclicBSplineBasis(SplineBasis):
Number of basis functions.
order : int
Order of the splines used in basis functions.

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import CyclicBSplineBasis
>>> X = np.random.normal(size=(1000, 1))

>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1835,6 +1967,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import CyclicBSplineBasis
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3)
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1864,6 +2004,16 @@ class RaisedCosineBasisLinear(Basis):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLinear
>>> X = np.random.normal(size=(1000, 1))

>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)

# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2003,6 +2153,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
basis_funcs :
Raised cosine basis functions, shape (n_samples, n_basis_funcs)

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import RaisedCosineBasisLinear
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -2057,6 +2214,16 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLog
>>> X = np.random.normal(size=(1000, 1))

>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)

# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2210,6 +2377,18 @@ class OrthExponentialBasis(Basis):
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`

Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import OrthExponentialBasis
>>> X = np.random.normal(size=(1000, 1))
>>> n_basis_funcs = 5
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = ortho_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -2365,6 +2544,16 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Evaluated exponentially decaying basis functions, numerically
orthogonalized, shape (n_samples, n_basis_funcs)

Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import OrthExponentialBasis
>>> n_basis_funcs = 5
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand All @@ -2387,6 +2576,17 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray:
-------
spline
M-spline basis function, shape (n_sample_points, ).

Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import mspline

>>> sample_points = linspace(0, 1, 100)
>>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline
>>> mspline_eval.shape
(100,)
"""
# Boundary conditions.
if (T[i + k] - T[i]) < 1e-6:
Expand Down Expand Up @@ -2453,6 +2653,18 @@ def bspline(
Notes
-----
The function uses splev function from scipy.interpolate library for the basis evaluation.

Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import bspline

>>> sample_points = linspace(0, 1, 100)
>>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1])
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
pranmod01 marked this conversation as resolved.
Show resolved Hide resolved
>>> bspline_eval.shape
(100, 10)
"""
knots.sort()
nk = knots.shape[0]
Expand Down
Loading