-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initialization #252
Merged
Merged
Initialization #252
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
45bdc61
added optim routine in a module
BalzaniEdoardo 5cd9295
use initialization function in GLMs.
BalzaniEdoardo 430eec6
Merge branch 'development' into initialization
BalzaniEdoardo 5ca9b49
added tests
BalzaniEdoardo 8a08d3a
fixed tests
BalzaniEdoardo 3dac13a
linters
BalzaniEdoardo 2b2d73a
flake8
BalzaniEdoardo e37883e
Update src/nemos/initialize_regressor.py
BalzaniEdoardo ddef196
Apply suggestions from code review
sjvenditto b4a2148
linted and captuered root finding warn
BalzaniEdoardo 66ab701
merged development
BalzaniEdoardo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from typing import Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from numpy.typing import ArrayLike | ||
from scipy.optimize import root_scalar | ||
|
||
# dictionary of known inverse link functions. | ||
INVERSE_FUNCS = { | ||
jnp.exp: jnp.log, | ||
jax.nn.softplus: lambda x: jnp.log(jnp.exp(x) - 1.0), | ||
} | ||
|
||
|
||
def scalar_root_find_elementwise( | ||
func: Callable, args: ArrayLike, x0: ArrayLike | ||
) -> jnp.ndarray: | ||
""" | ||
Find roots of a scalar function. | ||
|
||
This can be used as an attempt to find a numerical inverse of an unknown link function of a GLM; typically, | ||
this numerical inverse, is used to set the initial intercept to match the mean firing rate of the model. | ||
|
||
Parameters | ||
---------- | ||
func: | ||
A callable, which typically will be `inv_link_func(x) - jnp.mean(spikes)`. | ||
args: | ||
List of additional arguments passed to the function. | ||
x0: | ||
Initial values for the root-finding algorithm. | ||
|
||
Returns | ||
------- | ||
: | ||
An array containing the roots of each f(x) = func(x, args[k]), for k in 1,..., len(args). | ||
|
||
Raises | ||
------ | ||
ValueError: | ||
If any of the optimization is not successful. | ||
""" | ||
opts = [root_scalar(func, arg, x0=x, method="secant") for arg, x in zip(args, x0)] | ||
|
||
if not all(jnp.abs(func(opt.root, args[i])) < 10**-4 for i, opt in enumerate(opts)): | ||
raise ValueError( | ||
"Could not set the initial intercept as the inverse of the firing rate for " | ||
"the provided link function. " | ||
"Please, provide initial parameters instead!" | ||
) | ||
|
||
return jnp.array([opt.root for opt in opts]) | ||
|
||
|
||
def initialize_intercept_matching_mean_rate( | ||
inverse_link_function: Callable, y: jnp.ndarray | ||
) -> jnp.ndarray: | ||
""" | ||
Compute the initial intercept term for a regression models. | ||
|
||
This method compute an initial intercept term for a regression models such that the baseline activity | ||
matches the mean activity of each neuron, assuming that the model coefficients are initialized to zero. | ||
|
||
|
||
Parameters | ||
---------- | ||
inverse_link_function: | ||
The inverse link function of the model, linking the mean to the linear combination of the covariates in | ||
a GLM. | ||
y: | ||
The neural activity, shape either (num_sample,) for single variable regressors as `GLM` | ||
or (n_sample, n_neurons) for multi-variable regressors, such as `PopulaitonGLM`. | ||
|
||
Returns | ||
------- | ||
: | ||
The initial intercept term, shape (n_neurons,). | ||
|
||
""" | ||
|
||
# return inverse if analytical solution is available | ||
analytical_inv = INVERSE_FUNCS.get(inverse_link_function, None) | ||
|
||
means = jnp.atleast_1d(jnp.mean(y, axis=0)) | ||
if analytical_inv: | ||
out = analytical_inv(means) | ||
if jnp.any(jnp.isnan(out)): | ||
raise ValueError( | ||
"Failed to initialize the model intercept as the inverse of the firing rate for " | ||
"the provided link function. The mean firing rate has some non-positive values." | ||
) | ||
return out | ||
|
||
def func(x, mean_x): | ||
return inverse_link_function(x) - mean_x | ||
|
||
try: | ||
out = scalar_root_find_elementwise(func, means, means) | ||
except ValueError: | ||
raise ValueError( | ||
"Failed to initialize the model intercept as the inverse of the firing rate for the" | ||
" provided link function. Please, provide initial parameters instead!" | ||
) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, we can't provide any info as to why this failed, right? I think this message is a bit opaque for users (for example), but maybe we are treating users who use a non-standard link function as advanced.
I do think we could have
initialize_intercept_matching_mean_rate
catch this ValueError and raise a more specific error, saying that we were unable to set the initial parameters to match the mean firing rate.