Add documentation about model identifiability #259

Open
BalzaniEdoardo opened this issue Nov 1, 2024 · 2 comments

@BalzaniEdoardo
Collaborator

  • Background note on model identifiability, including:

    • a simple example of firing rate x trial type
    • a simple example with a B-spline basis
    • what happens when the model is fit from different initial conditions
    • an explanation that ridge and lasso regularization do not have this problem (the vanilla smoother penalty does, so we cannot say that every form of regularization solves it)
  • Tutorial on Basis and identifiability
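
A minimal sketch of what the firing rate x trial type example could look like. It is not taken from the NeMoS docs: the data are synthetic, all variable names are hypothetical, and it uses a linear-Gaussian model with a closed-form ridge penalty on every coefficient (intercept included) as a simplified stand-in for a GLM. The point is only that a one-hot trial-type encoding plus an intercept is rank-deficient, so different coefficient vectors give identical predictions, while a ridge penalty makes the solution unique.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 200 trials, cycling through 3 trial types.
n_trials, n_types = 200, 3
trial_type = np.arange(n_trials) % n_types

# One-hot encode the trial type and prepend an intercept column.
one_hot = np.eye(n_types)[trial_type]
design = np.hstack([np.ones((n_trials, 1)), one_hot])

n_params = design.shape[1]
rank = np.linalg.matrix_rank(design)
print(f"Number of parameters: {n_params}, matrix rank: {rank}")

# The one-hot columns sum to the intercept column, so this vector
# lies in the null space of the design matrix.
null_vec = np.array([1.0, -1.0, -1.0, -1.0])
assert np.allclose(design @ null_vec, 0.0)

# Simulated responses: one mean rate per trial type plus noise.
rate_per_type = np.array([2.0, 5.0, 3.0])
y = rate_per_type[trial_type] + 0.1 * rng.standard_normal(n_trials)

# Un-regularized fit: lstsq returns one solution; shifting it along
# the null direction gives another one with the same predictions.
coef_a, *_ = np.linalg.lstsq(design, y, rcond=None)
coef_b = coef_a + 3.0 * null_vec
same_predictions = np.allclose(design @ coef_a, design @ coef_b)

# Ridge fit: adding alpha * I makes the normal equations full rank,
# so the penalized solution is unique.
alpha = 1e-3
gram = design.T @ design + alpha * np.eye(n_params)
coef_ridge = np.linalg.solve(gram, design.T @ y)
ridge_is_unique = np.linalg.matrix_rank(gram) == n_params

print(f"Same predictions from different coefficients: {same_predictions}")
print(f"Ridge normal equations are full rank: {ridge_is_unique}")
```

Dropping one trial-type column (reference coding) or removing the intercept would be the usual un-regularized fixes for this toy case.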

@billbrod
Member

billbrod commented Nov 5, 2024

There was an example of a rank-deficient matrix in plot_06_calcium_imaging.py, but it was removed in #247. Pasting the code here in case it's useful:

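# %%
# Assumes `X` (the design matrix) and `heading_basis` are already defined,
# as in the tutorial this snippet was taken from.
import jax
import numpy as np
import nemos as nmo

# %%
# To ensure the computation accounts for the model's intercept term,
# we add a constant column of ones before calculating the rank.
# Below is a utility function for adding the intercept column.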

def add_intercept(X):
    """Add an intercept term to the design matrix.

    Convert the matrix to float64, drop NaNs, and add an intercept term.
    """
    # convert to float64 for rank computation precision
    X = np.asarray(X, dtype=np.float64)
    # drop nans
    X = X[nmo.tree_utils.get_valid_multitree(X)]
    return np.hstack([np.ones((X.shape[0], 1)), X])

print(f"Number of features: {X.shape[1] + 1}")  # num coefficients + intercept
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")

# %%
# By setting the intercept weight to -1, the heading-basis weights to 1,
# and all remaining weights to 0,

w = np.ones((X.shape[1] + 1))
w[0] = -1
w[1 + heading_basis.n_basis_funcs:] = 0


# %%
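# We have that the largest absolute entry of [1, X] @ w is numerically zero,
# i.e. w lies in the null space of the design matrix: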

np.max(np.abs(np.dot(add_intercept(X), w)))


# %%
# This implies that infinitely many parameter vectors result in the same firing rate,
# or equivalently that an un-regularized GLM has infinitely many equivalent solutions.

# define some random coefficients
coef = np.random.randn(X.shape[1] + 1)

# the firing rate is softplus([1, X] * coef)
# adding w to the coefficients does not change the output rate.
firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef))
firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w))

# check that the rates match
np.allclose(firing_rate, firing_rate_2)
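
For the documentation page, the same idea could be shown without the tutorial's data. The sketch below is only an illustration under stated assumptions: `hat_basis` and `fit_gradient_descent` are hypothetical helpers, piecewise-linear "hat" functions stand in for a B-spline basis (both sum to one at every sample), and a least-squares loss replaces the GLM likelihood to keep the example NumPy-only. It fits the same rank-deficient model from two initial conditions and compares coefficients and predictions.

```python
import numpy as np

rng = np.random.default_rng(1)


def hat_basis(x, n_basis):
    """Evaluate n_basis linear B-splines ("hat" functions) on [0, 1]."""
    knots = np.linspace(0.0, 1.0, n_basis)
    # Column k interpolates the k-th unit vector, so the columns sum to one.
    return np.column_stack(
        [np.interp(x, knots, np.eye(n_basis)[k]) for k in range(n_basis)]
    )


def fit_gradient_descent(design, y, init, n_steps=5000):
    """Minimize the mean squared error by plain gradient descent."""
    gram = design.T @ design / len(y)
    # A step of 1 / (largest eigenvalue) guarantees convergence here.
    step = 1.0 / np.linalg.eigvalsh(gram).max()
    coef = init.copy()
    for _ in range(n_steps):
        grad = design.T @ (design @ coef - y) / len(y)
        coef = coef - step * grad
    return coef


# Synthetic data: a smooth tuning curve plus noise.
n_samples, n_basis = 500, 5
x = rng.uniform(0.0, 1.0, size=n_samples)
y = np.sin(2 * np.pi * x) + 0.1 * rng.standard_normal(n_samples)

# Intercept + basis: the basis columns sum to the intercept column.
design = np.hstack([np.ones((n_samples, 1)), hat_basis(x, n_basis)])
rank = np.linalg.matrix_rank(design)
print(f"Number of parameters: {design.shape[1]}, matrix rank: {rank}")

# Two initial conditions that differ only in the intercept.
init_a = np.zeros(n_basis + 1)
init_b = np.zeros(n_basis + 1)
init_b[0] = 3.0

coef_a = fit_gradient_descent(design, y, init_a)
coef_b = fit_gradient_descent(design, y, init_b)

# The fits disagree on the coefficients but agree on the predictions.
coef_gap = np.max(np.abs(coef_a - coef_b))
pred_gap = np.max(np.abs(design @ coef_a - design @ coef_b))
print(f"Max coefficient difference: {coef_gap:.3f}")
print(f"Max prediction difference: {pred_gap:.2e}")
```

Gradient descent never moves the coefficients along the null direction, so whatever component of the initial condition lies there survives in the final estimate. That is why the two runs end at different coefficients with matching predictions.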

@billbrod
Member

billbrod commented Nov 5, 2024

When we add this, remember to link to it from the ## Design matrix section of tutorials/plot_06_calcium_imaging.py
