Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization #252

Merged
merged 11 commits into from
Oct 25, 2024
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ profile = "black"
# Configure pytest
[tool.pytest.ini_options]
testpaths = ["tests"] # Specify the directory where test files are located
addopts = "-n auto"

[tool.coverage.run]
omit = [
Expand Down
20 changes: 4 additions & 16 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import jax.numpy as jnp
import jaxopt
from numpy.typing import ArrayLike
from scipy.optimize import root

from . import observation_models as obs
from . import tree_utils, validation
from .base_regressor import BaseRegressor
from .exceptions import NotFittedError
from .initialize_regressor import initialize_intercept_matching_mean_rate
from .pytrees import FeaturePytree
from .regularizer import GroupLasso, Lasso, Regularizer, Ridge
from .type_casting import jnp_asarray_if, support_pynapple
Expand Down Expand Up @@ -534,21 +534,9 @@ def _initialize_parameters(
else:
data = X

# find numerically zeros of the link
def func(x):
return self.observation_model.inverse_link_function(x) - y.mean(
axis=0, keepdims=False
)

# scipy root finding, much more stable than gradient descent
func_root = root(func, y.mean(axis=0, keepdims=False), method="hybr")
if not jnp.allclose(func_root.fun, 0, atol=10**-4):
raise ValueError(
"Could not set the initial intercept as the inverse of the firing rate for "
"the provided link function. "
"Please, provide initial parameters instead!"
)
initial_intercept = jnp.atleast_1d(func_root.x)
initial_intercept = initialize_intercept_matching_mean_rate(
self.observation_model.inverse_link_function, y
)

# Initialize parameters
init_params = (
Expand Down
104 changes: 104 additions & 0 deletions src/nemos/initialize_regressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Callable

import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike
from scipy.optimize import root_scalar

# dictionary of known inverse link functions.
INVERSE_FUNCS = {
jnp.exp: jnp.log,
jax.nn.softplus: lambda x: jnp.log(jnp.exp(x) - 1.0),
}


def scalar_root_find_elementwise(
func: Callable, args: ArrayLike, x0: ArrayLike
) -> jnp.ndarray:
"""
Find roots of a scalar function.

This can be used as an attempt to find a numerical inverse of an unknown link function of a GLM; typically,
this numerical inverse, is used to set the initial intercept to match the mean firing rate of the model.

Parameters
----------
func:
A callable, which typically will be `inv_link_func(x) - jnp.mean(spikes)`.
args:
List of additional arguments passed to the function.
x0:
Initial values for the root-finding algorithm.

Returns
-------
:
An array containing the roots of each f(x) = func(x, args[k]), for k in 1,..., len(args).

Raises
------
ValueError:
If any of the optimization is not successful.
"""
opts = [root_scalar(func, arg, x0=x, method="secant") for arg, x in zip(args, x0)]

if not all(jnp.abs(func(opt.root, args[i])) < 10**-4 for i, opt in enumerate(opts)):
raise ValueError(
"Could not set the initial intercept as the inverse of the firing rate for "
"the provided link function. "
"Please, provide initial parameters instead!"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, we can't provide any info as to why this failed, right? I think this message is a bit opaque for users (for example), but maybe we are treating users who use a non-standard link function as advanced.

I do think we could have initialize_intercept_matching_mean_rate catch this ValueError and raise a more specific error, saying that we were unable to set the initial parameters to match the mean firing rate.

)

return jnp.array([opt.root for opt in opts])


def initialize_intercept_matching_mean_rate(
inverse_link_function: Callable, y: jnp.ndarray
) -> jnp.ndarray:
"""
Compute the initial intercept term for a regression models.

This method compute an initial intercept term for a regression models such that the baseline activity
matches the mean activity of each neuron, assuming that the model coefficients are initialized to zero.


Parameters
----------
inverse_link_function:
The inverse link function of the model, linking the mean to the linear combination of the covariates in
a GLM.
y:
The neural activity, shape either (num_sample,) for single variable regressors as `GLM`
or (n_sample, n_neurons) for multi-variable regressors, such as `PopulaitonGLM`.

Returns
-------
:
The initial intercept term, shape (n_neurons,).

"""

# return inverse if analytical solution is available
analytical_inv = INVERSE_FUNCS.get(inverse_link_function, None)

means = jnp.atleast_1d(jnp.mean(y, axis=0))
if analytical_inv:
out = analytical_inv(means)
if jnp.any(jnp.isnan(out)):
raise ValueError(
"Failed to initialize the model intercept as the inverse of the firing rate for "
"the provided link function. The mean firing rate has some non-positive values."
)
return out

def func(x, mean_x):
return inverse_link_function(x) - mean_x

try:
out = scalar_root_find_elementwise(func, means, means)
except ValueError:
raise ValueError(
"Failed to initialize the model intercept as the inverse of the firing rate for the"
" provided link function. Please, provide initial parameters instead!"
)
return out
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,25 @@ def loss_tree(params, XX, yy):
return jnp.power(yy - pred, 2).sum() + norm

return X_tree, y, coef_tree, ridge_tree, loss_tree


@pytest.fixture()
def example_X_y_high_firing_rates():
"""Example that used failed with NeMoS original initialization."""
np.random.seed(123)

n_features = 18
n_neurons = 60
n_samples = 500

# random design array. Shape (n_time_points, n_features).
X = 0.5 * np.random.normal(size=(n_samples, n_features))

# log-rates & weights
b_true = np.random.uniform(size=(n_neurons,)) * 3 # baseline rates
w_true = np.random.uniform(size=(n_features, n_neurons)) * 0.1 # real weights:

# generate counts (spikes will be (n_samples, n_features)
rate = jnp.exp(jnp.dot(X, w_true) + b_true)
spikes = np.random.poisson(rate)
return X, spikes
Loading
Loading