diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 1b0c9f12..92c3871b 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -204,6 +204,19 @@ def fit(self, X: FeatureMatrix, y=None): ------- self : The transformer object. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import MSplineBasis, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(100, 2)) + + >>> # Define and fit tranformation basis + >>> basis = MSplineBasis(10) + >>> transformer = TransformerBasis(basis) + >>> transformer_fitted = transformer.fit(X) """ self._basis._set_kernel(*self._unpack_inputs(X)) return self @@ -223,6 +236,28 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: ------- : The data transformed by the basis functions. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import MSplineBasis, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(10000, 2)) + + >>> # Define and fit tranformation basis + >>> basis = MSplineBasis(10, mode="conv", window_size=200) + >>> transformer = TransformerBasis(basis) + >>> # Before calling `fit` the convolution kernel is not set + >>> transformer.kernel_ + + >>> transformer_fitted = transformer.fit(X) + >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) + >>> transformer_fitted.kernel_.shape + (200, 10) + + >>> # Transform basis + >>> feature_transformed = transformer.transform(X[:, 0:1]) """ # transpose does not work with pynapple # can't use func(*X.T) to unwrap @@ -248,6 +283,21 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: array-like The data transformed by the basis functions, after fitting the basis functions to the data. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import MSplineBasis, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(100, 1)) + + >>> # Define tranformation basis + >>> basis = MSplineBasis(10) + >>> transformer = TransformerBasis(basis) + + >>> # Fit and transform basis + >>> feature_transformed = transformer.fit_transform(X) """ return self._basis.compute_features(*self._unpack_inputs(X)) @@ -705,6 +755,19 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: input samples with the basis functions. The output shape varies based on the subclass and mode. + Examples + ------- + >>> import numpy as np + >>> from nemos.basis import BSplineBasis + + >>> # Generate data + >>> num_samples = 10000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = BSplineBasis(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (10000, 10) + Notes ----- Subclasses should implement how to handle the transformation specific to their @@ -882,6 +945,19 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: This differs from the numpy.meshgrid default, which uses Cartesian indexing. For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$. + Examples + -------- + >>> # Evaluate and visualize 4 M-spline basis functions of order 3: + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import MSplineBasis + >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) + >>> p = plt.plot(sample_points, basis_values) + >>> _ = plt.title('M-Spline Basis Functions') + >>> _ = plt.xlabel('Domain') + >>> _ = plt.ylabel('Basis Function Value') + >>> _ = plt.legend([f'Function {i+1}' for i in range(4)]); """ self._check_input_dimensionality(n_samples) @@ -1071,7 +1147,22 @@ class AdditiveBasis(Basis): n_basis_funcs : int Number of basis functions. - + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + >>> X = np.random.normal(size=(30, 2)) + + >>> # define two basis objects and add them + >>> basis_1 = nmo.basis.BSplineBasis(10) + >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> additive_basis = basis_1 + basis_2 + + >>> # can add another basis to the AdditiveBasis object + >>> X = np.random.normal(size=(30, 3)) + >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> additive_basis_2 = additive_basis + basis_3 """ def __init__(self, basis1: Basis, basis2: Basis) -> None: @@ -1183,6 +1274,22 @@ class MultiplicativeBasis(Basis): n_basis_funcs : int Number of basis functions. + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + >>> X = np.random.normal(size=(30, 3)) + + >>> # define two basis and multiply + >>> basis_1 = nmo.basis.BSplineBasis(10) + >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> multiplicative_basis = basis_1 * basis_2 + + >>> # Can multiply or add another basis to the AdditiveBasis object + >>> # This will cause the number of output features of the result basis to grow accordingly + >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> multiplicative_basis_2 = multiplicative_basis * basis_3 """ def __init__(self, basis1: Basis, basis2: Basis) -> None: @@ -1298,7 +1405,6 @@ class SplineBasis(Basis, abc.ABC): ---------- order : int Spline order. - """ def __init__( @@ -1614,6 +1720,14 @@ class BSplineBasis(SplineBasis): [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import BSplineBasis + + >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = bspline_basis(sample_points) """ def __init__( @@ -1693,6 +1807,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import BSplineBasis + >>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -1728,6 +1850,16 @@ class CyclicBSplineBasis(SplineBasis): Number of basis functions. order : int Order of the splines used in basis functions. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import CyclicBSplineBasis + >>> X = np.random.normal(size=(1000, 1)) + + >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cyclic_basis(sample_points) """ def __init__( @@ -1835,6 +1967,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: ----- The evaluation is performed by looping over each element and using `splev` from SciPy to compute the basis values. + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import CyclicBSplineBasis + >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -1864,6 +2004,16 @@ class RaisedCosineBasisLinear(Basis): Only used in "conv" mode. Additional keyword arguments that are passed to `nemos.convolve.create_convolutional_predictor` + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import RaisedCosineBasisLinear + >>> X = np.random.normal(size=(1000, 1)) + + >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cosine_basis(sample_points) + # References ------------ [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., @@ -2003,6 +2153,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: basis_funcs : Raised cosine basis functions, shape (n_samples, n_basis_funcs) + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import RaisedCosineBasisLinear + >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -2057,6 +2214,16 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): Only used in "conv" mode. Additional keyword arguments that are passed to `nemos.convolve.create_convolutional_predictor` + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import RaisedCosineBasisLog + >>> X = np.random.normal(size=(1000, 1)) + + >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cosine_basis(sample_points) + # References ------------ [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., @@ -2210,6 +2377,18 @@ class OrthExponentialBasis(Basis): **kwargs : Only used in "conv" mode. Additional keyword arguments that are passed to `nemos.convolve.create_convolutional_predictor` + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import OrthExponentialBasis + >>> X = np.random.normal(size=(1000, 1)) + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = ortho_basis(sample_points) """ def __init__( @@ -2365,6 +2544,16 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Evaluated exponentially decaying basis functions, numerically orthogonalized, shape (n_samples, n_basis_funcs) + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import OrthExponentialBasis + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -2387,6 +2576,17 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray: ------- spline M-spline basis function, shape (n_sample_points, ). + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis import mspline + + >>> sample_points = linspace(0, 1, 100) + >>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline + >>> mspline_eval.shape + (100,) """ # Boundary conditions. if (T[i + k] - T[i]) < 1e-6: @@ -2453,6 +2653,18 @@ def bspline( Notes ----- The function uses splev function from scipy.interpolate library for the basis evaluation. + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis import bspline + + >>> sample_points = linspace(0, 1, 100) + >>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1]) + >>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline + >>> bspline_eval.shape + (100, 10) """ knots.sort() nk = knots.shape[0]