Clarify Inputs API for regular, stabilized, and non-stabilized PT #291

Open
sefffal opened this issue Oct 21, 2024 · 3 comments

Comments

@sefffal

sefffal commented Oct 21, 2024

Based on the discussion in #290 (comment), I would suggest adjusting the API a bit.

I had previously misunderstood how to control whether regular PT, stabilized PT, or non-stabilized PT is used.

Here is the current API usage for Inputs:
Regular PT: n_chains=8, n_chains_variational=0, variational=nothing
Stabilized Variational PT: n_chains=8, n_chains_variational=8, variational=GaussianReference....
Non-stabilized variational PT: n_chains=8, n_chains_variational=0, variational=GaussianReference....
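
To make the table above concrete, here is a minimal sketch of the mapping it describes. `pt_variant` and `DummyReference` are hypothetical names used only for illustration and are not part of Pigeons; the function encodes just the three rows listed above and labels every other combination as unspecified.

```julia
# Hypothetical helper, not part of Pigeons: names the PT variant implied by
# the three keyword values, following the three rows listed above.
function pt_variant(; n_chains::Int, n_chains_variational::Int, variational)
    has_reference = variational !== nothing
    if n_chains > 0 && n_chains_variational == 0
        # Only the presence of a reference separates these two cases.
        return has_reference ? :non_stabilized_variational : :regular
    elseif n_chains > 0 && n_chains_variational > 0 && has_reference
        return :stabilized_variational
    else
        return :unspecified  # combinations the list above does not cover
    end
end

# Stand-in for any non-`nothing` reference, such as a GaussianReference.
struct DummyReference end

pt_variant(n_chains = 8, n_chains_variational = 0, variational = nothing)
pt_variant(n_chains = 8, n_chains_variational = 8, variational = DummyReference())
pt_variant(n_chains = 8, n_chains_variational = 0, variational = DummyReference())
```

Written this way, the first and third calls differ only in the `variational` argument, which is how it is possible to end up in non-stabilized variational PT while intending regular PT.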

Here is how I thought it worked:
Regular PT: n_chains=8, n_chains_variational=0, variational=anything (i.e., the variational argument was ignored if n_chains_variational=0)
Stabilized Variational PT: n_chains=8, n_chains_variational=8, variational=GaussianReference....
Non-stabilized variational PT: n_chains=0, n_chains_variational=8, variational=GaussianReference....

Another factor that led to my confusion is that in non-stabilized variational PT, the sampler output table shows the column Λ instead of Λ_var. IMO this last point could be considered a bug?

In fact, I think I have at times been using non-stabilized variational PT when I thought I was using regular PT (with no variational chains).

Thanks all!

@sefffal
Author

sefffal commented Oct 23, 2024

I ran into another API edge case that I think could be clarified, or made into an error:

Inputs(
   n_chains=N, # where N>0 
   n_chains_variational=M, # where M > 0
   variational=nothing, # or left default
)

This sets up two fixed legs but names one of them variational and the other fixed. It also prints out both Gamma and Gamma_var. I think that's a little misleading, since it's effectively just running two copies of standard PT in parallel.
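
If this case were made into an error, the check could look roughly like the sketch below. `check_variational_inputs` is a hypothetical standalone function, not existing Pigeons code, and it assumes that any positive `n_chains_variational` without a reference should be rejected.

```julia
# Hypothetical check, not part of Pigeons: reject a variational leg that has
# no reference distribution to tune.
function check_variational_inputs(; n_chains::Int,
                                    n_chains_variational::Int,
                                    variational = nothing)
    if n_chains_variational > 0 && variational === nothing
        throw(ArgumentError(
            "n_chains_variational=$n_chains_variational requires a variational " *
            "reference, but variational=nothing was given"))
    end
    return nothing
end

# Passes silently: no variational leg was requested.
check_variational_inputs(n_chains = 4, n_chains_variational = 0)
```

Calling it with `n_chains = 4, n_chains_variational = 4` and no `variational` would throw an `ArgumentError` instead of silently running two fixed legs.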

@trevorcampbell
Collaborator

trevorcampbell commented Oct 23, 2024

Yeah, I also ran into difficulties understanding the variational/fixed API when I first used Pigeons. At a minimum, the fact that the n_chains argument changes meaning depending on other arguments is tricky.

@miguelbiron @alexandrebouchard could we not just rename variational to reference, and then allow n_chains and reference to take in either an integer / Distribution, or a list of integers / list of distributions? This makes it clear that each leg of PT is associated with a reference and number of chains. If all the references are "nothing" or some other fixed distribution, it won't be tuned, and for each distribution in the list that is tunable, we tune it.

If we don't want to allow arbitrary star topologies right now, we can limit the list length to 2. Or we can say "integer/distribution or 2-tuple of integers/distributions"
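
A rough sketch of how that could be normalized internally, limited to at most two legs. Everything here (`Leg`, `as_tuple`, `make_legs`, and the `:gaussian` placeholder) is hypothetical and only illustrates the idea; it is not the Pigeons API, and it assumes `n_chains` and `reference` must have the same length.

```julia
# Hypothetical types and names sketching the proposal; not the Pigeons API.
struct Leg
    n_chains::Int
    reference   # `nothing` for a fixed reference, otherwise a distribution-like object
end

# Wrap a scalar into a 1-tuple so scalars and tuples share one code path.
as_tuple(x::Tuple) = x
as_tuple(x) = (x,)

function make_legs(; n_chains, reference = nothing)
    ns = as_tuple(n_chains)
    refs = as_tuple(reference)
    length(ns) <= 2 || throw(ArgumentError("at most 2 legs are supported"))
    length(ns) == length(refs) ||
        throw(ArgumentError("n_chains and reference must have the same length"))
    # Each leg pairs one chain count with one reference.
    return [Leg(n, r) for (n, r) in zip(ns, refs)]
end

# One fixed leg with 8 chains.
make_legs(n_chains = 8)

# A fixed leg plus a leg with a (placeholder) tunable reference.
make_legs(n_chains = (8, 8), reference = (nothing, :gaussian))
```

With this shape, each leg carries its own chain count and reference, so `n_chains` no longer changes meaning depending on the other arguments.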

Thoughts?

@nikola-sur
Collaborator

I agree that it's a super confusing API. I like Trevor's idea of cleaning it up by allowing an integer or a 2-tuple to be supplied for the number of chains.
