
Output of hierarchical model unnecessarily contains fit for every single trial of parameters with regression formula. #497

Closed
SaschaFroelich opened this issue Jul 18, 2024 · 5 comments

SaschaFroelich commented Jul 18, 2024

When fitting a hierarchical model like the one from the tutorial:

import hssm
import jax

# dataset_reg_v_hier is the simulated hierarchical dataset from the HSSM tutorial
model_reg_v_angle_hier = hssm.HSSM(
    data=dataset_reg_v_hier,
    model="angle",
    include=[
        {
            "name": "v",
            "prior": {
                "Intercept": {
                    "name": "Uniform",
                    "lower": -3.0,
                    "upper": 3.0,
                    "initval": 0.0,
                },
                "x": {"name": "Uniform", "lower": -1.0, "upper": 1.0, "initval": 0.0},
                "y": {"name": "Uniform", "lower": -1.0, "upper": 1.0, "initval": 0.0},
            },
            "formula": "v ~ 1 + (1|subject) + x + y",
            "link": "identity",
        }
    ],
)

jax.config.update("jax_enable_x64", False)
out = model_reg_v_angle_hier.sample(
    sampler="nuts_numpyro", chains=2, cores=1, draws=30, tune=30
)

out will contain one row for the fitted v of every trial of every subject. This quickly grows to an enormous size when fitting the model to actual experimental data. For instance, I have a total of 33k trials across 60 participants and would like to specify a regression formula for v and for z. The resulting output is several GB in size, and simply printing the summary of the inference object to check whether the chains converged takes forever.

This also happens when I change the regression formula to something like v ~ 1 + (1|subject), so that v does not change from trial to trial within each subject.

Is there a way to suppress the output of the fitting results for each individual trial and only get output at the group level or the subject level?
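
For illustration, this is roughly how the trial-wise entries show up in the returned InferenceData (the exact variable and dimension names may differ between HSSM versions):

# The posterior group carries a dimension with one entry per trial,
# which is what makes the object so large
print(out.posterior.sizes)
print(list(out.posterior.data_vars))  # the trial-wise deterministic (e.g. "v") is listed here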

Furthermore, in the tutorial, az.plot_forest(model_reg_v_angle_hier.traces) plots the values of v for the individual subjects. But when I run the tutorial code as-is, it tries to plot one row of v per trial, across all subjects (>3k trials).

I use HSSM version 0.2.1 on Ubuntu 22.04 with Python 3.10.12.

digicosmos86 (Collaborator) commented:

Hi,

Can you try specifying include_mean=False when you are calling .sample()? This will not include the means in the InferenceData object. If the arviz plots are not clean enough, you can also specify var_names when you are calling az.plot_forest to filter those variables out.
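
A sketch of what that suggestion would look like, assuming .sample() accepts include_mean as described above; the coefficient names passed to var_names (e.g. "v_Intercept", "v_x", "v_y") are illustrative and may differ in your model:

import arviz as az

out = model_reg_v_angle_hier.sample(
    sampler="nuts_numpyro", chains=2, cores=1, draws=30, tune=30,
    include_mean=False,  # suggested above: do not store the trial-wise means
)

# Limit the forest plot to the regression coefficients of interest
az.plot_forest(out, var_names=["v_Intercept", "v_x", "v_y"])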

SaschaFroelich (Author) commented Jul 23, 2024

> Hi,
>
> Can you try specifying include_mean=False when you are calling .sample()? This will not include the means in the InferenceData object. If the arviz plots are not clean enough, you can also specify var_names when you are calling az.plot_forest to filter those variables out.

Hi,

thanks for your reply! Unfortunately, include_mean=False does not make any difference. What is it supposed to do? Yes, I can use var_names for the forest plots, but before I do that I would like to check whether my chains converged, which I do by checking the r_hat values with az.summary(out) (unless there is a better way). That takes extremely long (~12 minutes on my machine with 16 cores and 32 GB of memory). I simply don't think it is necessary to store the result for every single trial (amounting to an InferenceData object of 2 GB in my case), especially since the results are the same for most trials (I only distinguish between 3 trial types). Wouldn't it make sense to include a keyword argument that disables storing the fit result for every individual trial?
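
As a workaround for the slow convergence check, az.summary can be restricted to a subset of variables so that r_hat is only computed for the group- and subject-level parameters; the variable names below are illustrative:

import arviz as az

# With filter_vars="like", the given strings are matched as substrings of variable names,
# so the trial-wise deterministics are skipped entirely
summary = az.summary(
    out,
    var_names=["v_Intercept", "v_x", "v_y", "1|subject"],
    filter_vars="like",
)
print(summary["r_hat"])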

digicosmos86 (Collaborator) commented:

Yes, we are working on that. In the meantime, can you check whether you can add a var_names argument to model.sample() with only the variables that you want to include? According to the PyMC documentation, you can override the default behavior, which is to include all free variables and deterministics.
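
A minimal sketch of that suggestion, assuming model.sample() forwards var_names to the underlying sampler; the listed parameter names are illustrative:

out = model_reg_v_angle_hier.sample(
    sampler="nuts_numpyro", chains=2, cores=1, draws=30, tune=30,
    # keep only these variables in the trace instead of all free and deterministic variables
    var_names=["v_Intercept", "v_x", "v_y", "1|subject"],
)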

frankmj (Collaborator) commented Jul 23, 2024 via email

digicosmos86 (Collaborator) commented:

It seems that there is no way to exclude variables from the InferenceData, though, which can blow up when the model is complex. I have opened an issue on Bambi: bambinos/bambi#828
