Explore running StackObservationModels in parallel #254
For within-chain parallelisation the impression is that currently the concept is that this is down to the user, using whatever tools the Julia language offers. Are there inspiring examples from other PPLs? At the moment it seems quite case-by-case what the best form of within-chain parallelism is, so I can see why it is left to the user. One consideration here is that it is easier to parallelise the accumulation of log-posterior density than it is to parallelise the full model evaluation (see the sketch below). In terms of nice packages in wider Julia, FLoops.jl is one option (used in an example further down this thread).
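To make the log-density-accumulation point concrete, here is a minimal sketch (not from the thread; the model and names are illustrative): the pointwise log-likelihoods are computed across threads and the sum is added back via `Turing.@addlogprob!`. The caveat raised below applies: this pattern is only safe with certain AD backends (e.g. ForwardDiff).

```julia
using Turing  # provides @model and @addlogprob!; Distributions is re-exported

@model function threaded_loglik(y)
    μ ~ Normal(0.0, 1.0)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    # Parallelise only the accumulation of the log-likelihood, not the
    # model logic itself: each thread fills independent slots of `lls`.
    lls = Vector{typeof(μ)}(undef, length(y))
    Threads.@threads for i in eachindex(y)
        lls[i] = logpdf(Normal(μ, σ), y[i])
    end
    Turing.@addlogprob! sum(lls)
end
```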
So yes, but the only live example is restricted to having observations on the left-hand side. This would have very limited utility for us, especially if it only supports ForwardDiff. As the original comment points out, we have some very clear use cases where multithreading would be useful, namely dispatching on submodels in a for loop (a sketch of the pattern follows below). This would closely mirror how reduce_sum is used in Stan, if limited to only submodels that don't produce output used by other parts of the model (i.e. naively parallel submodels).
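For concreteness, a sketch of the pattern being described (model names are illustrative, not the EpiAware API; this is the serial version, since whether the loop body can safely run across threads is exactly the open question). It uses the `@submodel` interface from DynamicPPL, with a prefix so each stacked component gets its own parameter namespace:

```julia
using Turing, DynamicPPL  # @submodel is defined in DynamicPPL

# A single observation component; hypothetical stand-in for an obs model.
@model function obs_component(y)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    y .~ Normal(0.0, σ)
end

# Dispatching on submodels in a for loop. Each component is independent
# (no output feeds another part of the model), so in principle the loop
# body could run in parallel, mirroring Stan's reduce_sum over slices.
@model function stacked_obs(ys)
    for (i, y) in enumerate(ys)
        # prefix keeps each component's parameters distinct
        @submodel prefix="obs_$i" obs_component(y)
    end
end
```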
I'm struggling to see support for this point, given there are many obvious cases where multi-threading is what makes it possible to run very large models in any reasonable time, and given the current relatively poor performance that matters even more here. I see your point about an inference mode, but I think it is slightly off topic. For examples: Stan's reduce_sum does a lot of heavy lifting for you.
The Stan documentation's `reduce_sum` example is a good illustration; e.g. they convert:

```stan
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}
```

into

```stan
functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
```
Whereas in the normal Turing -> FLoops example (albeit for MvNormal)

```julia
# assumes `using Turing` (Distributions is re-exported)
@model function test_model(n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    y ~ MvNormal(μ * ones(n), σ * ones(n))
end
```

goes to

```julia
# assumes `using Turing, FLoops`
@model function test_model_floops(y, n::Integer)
    μ ~ Normal(1.0, 1.0)
    σ ~ truncated(Normal(0.0, 1.0), 0, Inf)
    lls = Vector{eltype(μ)}(undef, n)
    let v = [μ, σ]
        # eachindex rather than axes: axes(y) returns a tuple of ranges,
        # not the indices themselves
        @floop for i in eachindex(y)
            lls[i] = logpdf(Normal(v[1], v[2]), y[i])
        end
    end
    Turing.@addlogprob! sum(lls)
end
```

This seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?
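For context, a hypothetical invocation of the FLoops version above: with `JULIA_NUM_THREADS > 1` the `@floop` body runs across threads inside every log-density evaluation. The data here is simulated for illustration, and ForwardDiff (Turing's default AD) is assumed, per the caveat about limited backend support.

```julia
using Turing, Random

Random.seed!(1)
y = rand(Normal(1.0, 0.5), 10_000)  # simulated data, illustrative only

# Each NUTS log-density evaluation now parallelises the pointwise logpdf loop.
chain = sample(test_model_floops(y, length(y)), NUTS(), 500)
```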
@dylanhmorris, do you have any cool examples here?
This example is extremely limited vs what you can do in Stan (e.g. https://github.com/epinowcast/epinowcast/blob/886b45cb4bc5f338fa53d22e83d25335c55b1a4a/inst/stan/epinowcast.stan#L400, which runs all of the complicated observation model in parallel, i.e. effectively dispatching over submodels). I don't think our current approach works naturally with rephrasing everything as reduce_sum-style partial sums.
@SamuelBrand1 and I had a f2f, with the conclusion being that in the first instance we are aiming to check whether Turing.jl already supports running submodels in parallel.
It is possible that Turing.jl already supports running submodels in parallel, at least with some backends. It seems possible that if they don't now, this functionality may be added if there is user demand.

StackObservationModels is a clear target for us as it is naively parallel, and anything that is being stacked is likely to take sufficient compute to make this worthwhile. It is also likely that we will use this pattern again (i.e. for multiple renewal processes), and so solving this here would solve it more widely in the tooling.

See https://discourse.julialang.org/t/within-chain-parallelization-with-turing-jl/103402/4 for some discussion around current within-chain support (TLDR: limited to calls where the LHS is fixed, i.e. observations).
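To illustrate why stacking is "naively parallel" (a sketch with made-up distributions, not the EpiAware API): each stacked observation model contributes an independent log-likelihood term, so the terms can be computed concurrently and summed.

```julia
using Distributions

# Independent log-likelihood per stacked component (hypothetical stand-ins
# for the stacked observation models).
loglik(dist, y) = sum(logpdf.(dist, y))

dists = [Normal(0, 1), Poisson(3.0), Exponential(2.0)]
ys    = [randn(1_000), rand(Poisson(3.0), 1_000), rand(Exponential(2.0), 1_000)]

# Spawn one task per component; the total is just the sum of the fetches.
tasks = [Threads.@spawn loglik(d, y) for (d, y) in zip(dists, ys)]
total_ll = sum(fetch.(tasks))
```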