Explore running StackObservationModels in parallel #254

Open
seabbs opened this issue Jun 4, 2024 · 8 comments

Comments

@seabbs
Collaborator

seabbs commented Jun 4, 2024

It is possible that Turing.jl already supports running submodels in parallel, at least with some backends. If it doesn't yet, it seems plausible this functionality would be added if there is user demand.

StackObservationModels is a clear target for us, as it is embarrassingly parallel and anything being stacked is likely to take enough compute to make this worthwhile. We are also likely to use this pattern again (e.g. for multiple renewal processes), so solving it here would solve it more widely in the tooling.

See https://discourse.julialang.org/t/within-chain-parallelization-with-turing-jl/103402/4 for some discussion of current within-chain support (TLDR: limited to calls where the LHS is fixed, i.e. observations).

@SamuelBrand1
Collaborator

For within-chain parallelisation, my impression is that the current stance is that this is down to the user, using whatever tools the Julia language provides.

Are there inspiring examples from other PPLs? At the moment, the best form of within-chain parallelism seems quite case-by-case, so I can see why Turing doesn't treat it as something to offer directly.

One consideration here is that it is easier to parallelise the accumulation of log-posterior density than it is to parallelise the full ~ action in Turing. In my opinion, this is an argument for having an "inference mode" in Turing, as discussed in TuringLang/DynamicPPL.jl#510.

In terms of nice packages in wider Julia, FLoops.jl has good functionality here, and you can, at least in forward mode, get it to run with Turing as per this Pluto nb. Note that to get this to work I had to use @addlogprob!....

@seabbs
Collaborator Author

seabbs commented Jun 4, 2024

> For within-chain parallelisation, my impression is that the current stance is that this is down to the user, using whatever tools the Julia language provides.

So yes, but the only live example is restricted to having observations on the left-hand side. This would have very limited utility for us, especially if it only supports ForwardDiff. As the original comment points out, we have some very clear use cases where multithreading would be useful, namely dispatching on submodels in a for loop. This would closely mirror how reduce_sum is used in Stan, if limited to submodels that don't produce output used by other parts of the model (i.e. StackObservationModels would qualify).
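As a minimal sketch of that idea, using only Base.Threads (all names here are illustrative, not EpiAware or Turing API): each independent observation submodel contributes a log-density, the contributions are computed in parallel, and only the reduced sum would be handed to Turing.@addlogprob!.

```julia
using Base.Threads

# Stand-in for one observation submodel's log-likelihood contribution
# (here just a hand-written Normal logpdf; hypothetical name).
submodel_loglik(y, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((y - μ) / σ)^2

# Compute each independent contribution on its own thread, then reduce.
# The single summed value is what would be passed to Turing.@addlogprob!.
function parallel_loglik(ys, μ, σ)
    partials = Vector{Float64}(undef, length(ys))
    @threads for i in eachindex(ys)
        partials[i] = submodel_loglik(ys[i], μ, σ)
    end
    return sum(partials)
end
```

This mirrors the reduce_sum pattern: only the summed log-density crosses back into the sampler, so the per-submodel work is conditionally independent.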

> At the moment, the best form of within-chain parallelism seems quite case-by-case, so I can see why Turing doesn't treat it as something to offer directly.

I'm struggling to see support for this point, given there are many obvious cases where multithreading is what makes it possible to run very large models in any reasonable time. Given the currently relatively poor performance of Turing and the relative ease of supporting the Julia parallel ecosystem, this seems higher priority here than in other PPLs (to offset slowness).

I see your point about an inference mode, but I think it is slightly off topic here.

Examples

- reduce_sum from Stan (https://mc-stan.org/docs/stan-users-guide/parallelization.html#reduce-sum) does a lot of the heavy lifting for you.

- numpyro has approaches for this via JAX, some of which are automated and some of which are not. No example pinned down though. It has the explicit plate context, which seems similar to reduce_sum in Stan, but I have seen nothing that suggests it does in fact work in parallel.

@SamuelBrand1
Collaborator

Is the reduce_sum approach the one that's easy to implement directly? In what you linked, Stan is offering the user the chance to rewrite their code so that conditionally independent log posterior density sums can be parallelised, whilst losing the syntactic sugar of the Stan language... I don't see what's different to using FLoops.jl here?

@SamuelBrand1
Collaborator

e.g they convert:

data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}

into

functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}

@SamuelBrand1
Collaborator

SamuelBrand1 commented Jun 4, 2024

Whereas in the normal Turing -> FLoops.jl example (albeit for MvNormal)

@model function test_model(n::Integer)
	μ ~ Normal(1.0, 1.0)
	σ ~ truncated(Normal(0., 1.), 0, Inf)
	y ~ MvNormal(μ * ones(n), σ * ones(n))
end

goes to

@model function test_model_floops(y, n::Integer)
	μ ~ Normal(1.0, 1.0)
	σ ~ truncated(Normal(0., 1.), 0, Inf)
	
	lls = Vector{eltype(μ)}(undef, n)
	let v = [μ, σ]
		@floop for i in eachindex(y)
			lls[i] = logpdf(Normal(v[1], v[2]), y[i])
		end
	end
	Turing.@addlogprob! sum(lls)
end

Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?

@SamuelBrand1
Collaborator

> numpyro has approaches for this via JAX, some of which are automated and some of which are not. No example pinned down though.

@dylanhmorris , do you have any cool examples here?

@seabbs
Collaborator Author

seabbs commented Jun 4, 2024

> Which seems to unlock the same kind of performance improvement, with the same kind of downsides, as the Stan example?

This example is extremely limited vs what you can do in Stan (e.g. https://github.com/epinowcast/epinowcast/blob/886b45cb4bc5f338fa53d22e83d25335c55b1a4a/inst/stan/epinowcast.stan#L400, which runs all of the complicated observation model in parallel, i.e. effectively dispatching over submodels).

I don't think our current approach works naturally with rephrasing everything as obs ~ lots of stuff, hence the suggestion that we want to target being able to do this over submodels.

@seabbs
Collaborator Author

seabbs commented Jun 4, 2024

@SamuelBrand1 and I had a f2f with the conclusion being that in the first instance we are aiming to check if @submodel can be dispatched in parallel.
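As a hypothetical starting point for that check (illustrative only — plain Julia tasks standing in for @submodel, not actual Turing/DynamicPPL semantics), one could first confirm that independently dispatched "submodels" whose only output is a log-density reduce cleanly:

```julia
using Base.Threads

# Each independent submodel is represented here by a closure returning its
# log-density contribution. Threads.@spawn dispatches them in parallel and
# the reduced sum is what would be fed to Turing.@addlogprob!.
function dispatch_submodels(submodels)
    tasks = [Threads.@spawn m() for m in submodels]
    return sum(fetch.(tasks))
end

# e.g. two stacked observation models contributing fixed log-densities
dispatch_submodels([() -> -1.5, () -> -2.5])  # -4.0
```

The open question is whether @submodel can be made to behave like these closures, i.e. whether its side effects on the sampling context are restricted enough to allow this kind of dispatch.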
