Test Mooncake AD for EpiAware models #454

Open
SamuelBrand1 opened this issue Sep 10, 2024 · 19 comments

@SamuelBrand1
Collaborator

This could be working now: https://github.com/TuringLang/Turing.jl/pull/2289/files#review-changes-modal

Tapir AD looks like a real advance on ReverseDiff, so this would be worth trying: https://github.com/compintell/Tapir.jl

@seabbs
Collaborator

seabbs commented Sep 10, 2024

We should look at adding it to our benchmarks.

@SamuelBrand1
Collaborator Author

One thing to look for is whether differentiating `accumulate` calls gets a boost. An `rrule` exists for `accumulate` here, but it's unclear to me whether ReverseDiff picks it up.
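As a rough illustration of what a reverse-mode rule for `accumulate` has to provide, here is a Base-only sketch for the special case `accumulate(+, x)`, i.e. a cumulative sum. The function name `cumsum_with_pullback` is hypothetical, and this is not the ChainRules.jl implementation; it only mirrors the (value, pullback) shape of an `rrule`.

```julia
# Hypothetical sketch: a hand-written reverse rule for y = accumulate(+, x).
# Mirrors the (value, pullback) shape of a ChainRules-style rrule, but uses only Base.
function cumsum_with_pullback(x::AbstractVector{<:Real})
    y = accumulate(+, x)  # y[i] = x[1] + ... + x[i]
    # Each x[j] contributes to every y[i] with i >= j, so the cotangent for x[j]
    # is the sum of dy[i] over i >= j: a cumulative sum run backwards.
    pullback(dy) = reverse(accumulate(+, reverse(dy)))
    return y, pullback
end

x = [1.0, 2.0, 3.0, 4.0]
y, back = cumsum_with_pullback(x)

# For L(x) = sum(accumulate(+, x)), dL/dy is a vector of ones.
dx = back(ones(length(y)))
println(y)   # [1.0, 3.0, 6.0, 10.0]
println(dx)  # [4.0, 3.0, 2.0, 1.0]
```

The pullback is itself a single reverse scan over `dy`, so its cost is linear in `length(x)`.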

@yebai

yebai commented Sep 12, 2024

cc @willtebbutt who can help.

@willtebbutt

Thanks for tagging me in this, @yebai. To know for sure whether Tapir.jl will be of use I'd need to know a bit more about exactly which problems you want to differentiate, but here is a quick demo involving `accumulate`:

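```julia
using Pkg
Pkg.activate(; temp=true)
Pkg.add(["BenchmarkTools", "ReverseDiff", "Tapir"])
using BenchmarkTools, ReverseDiff, Tapir

# Toy objective: a scan (running sum) followed by a reduction.
f(x) = sum(identity, accumulate(+, x))
x = randn(1_000_000);

# Primal (forward-only) timing, for reference.
@benchmark f($x)

# ReverseDiff: record and compile the tape once, then reuse it.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x));
gradient_storage = zero(x);
@benchmark ReverseDiff.gradient!($gradient_storage, $tape, $x)

# Tapir: build the reverse-mode rule once, then reuse it.
rule = Tapir.build_rrule(f, x)
@benchmark Tapir.value_and_gradient!!($rule, f, $x)
```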

yields

```
julia> @benchmark f($x)
BenchmarkTools.Trial: 1320 samples with 1 evaluation.
 Range (min … max):  1.747 ms … 280.507 ms  ┊ GC (min … max):  0.00% … 99.13%
 Time  (median):     2.330 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.780 ms ±   9.231 ms  ┊ GC (mean ± σ):  34.48% ± 19.17%

  ▄█      ▁                                                    
  ███▁▁▁▃██▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▁▁▃▆▃▄▅▄▆▃▄▄▆▅▅▃▁▅▅ ▇
  1.75 ms      Histogram: log(frequency) by time      27.5 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 2.

julia> @benchmark ReverseDiff.gradient!($gradient_storage, $tape, $x)
BenchmarkTools.Trial: 38 samples with 1 evaluation.
 Range (min … max):  127.823 ms … 140.864 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     133.395 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   133.754 ms ±   3.520 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁   ▁ █  ██▁▁▁▁    █  ██ ▁█  ▁▁▁  ▁  ▁▁ ▁█ ▁  ▁   ▁█▁ ▁     ▁  
  █▁▁▁█▁█▁▁██████▁▁▁▁█▁▁██▁██▁▁███▁▁█▁▁██▁██▁█▁▁█▁▁▁███▁█▁▁▁▁▁█ ▁
  128 ms           Histogram: frequency by time          141 ms <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark Tapir.value_and_gradient!!($rule, f, $x)
BenchmarkTools.Trial: 106 samples with 1 evaluation.
 Range (min … max):  37.834 ms … 589.776 ms  ┊ GC (min … max):  0.00% … 91.95%
 Time  (median):     39.679 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   48.074 ms ±  56.167 ms  ┊ GC (mean ± σ):  16.70% ± 14.45%

  █▆                                                            
  ███▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▆ ▄
  37.8 ms       Histogram: log(frequency) by time       155 ms <

 Memory estimate: 22.91 MiB, allocs estimate: 272.
```

(Annoyingly, you won't be able to run this example for a couple of hours: I broke something in the way Tapir.jl interacts with BenchmarkTools.jl, and the fix should reach the General registry within the next couple of hours. I had to `dev` Tapir.jl and check out the appropriate branch to run this.)

There's a decent speedup over ReverseDiff.jl in this case (a median of about 40 ms for Tapir.jl versus about 133 ms for ReverseDiff.jl). I'd be interested to know if you've got any other examples you're keen to try out!

@SamuelBrand1
Collaborator Author

SamuelBrand1 commented Sep 12, 2024

Hey @willtebbutt,

Thanks for coming over to show this!

The broad outline of our interest here is that we expose constructors for various ways of defining discrete-time epidemiological models. Any time-stepping is (generally) done with a scanning function that uses `Base.accumulate` under the hood to propagate a state forward in time, dependent on some other process (think of the time-varying reproduction number).

When doing inference, anything that speeds up the gradient calls here will be very useful.
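A minimal sketch of the kind of scan described above, assuming a toy discrete renewal process in which incidence at time t is R_t times the generation-interval-weighted sum of recent incidence. The generation-interval values, `renewal_step`, `GEN_INT` and the initial window are hypothetical illustrations, not EpiAware's API.

```julia
# Hypothetical generation-interval pmf over lags 1, 2, 3 (illustrative values only).
const GEN_INT = [0.2, 0.5, 0.3]

# State: the last length(GEN_INT) incidence values, most recent first.
# One step computes I_t = R_t * sum_s g_s * I_{t-s}, then shifts the window.
function renewal_step(state::Vector{Float64}, Rt::Float64)
    It = Rt * sum(GEN_INT .* state)
    return vcat(It, state[1:end-1])
end

Rt = [1.5, 1.2, 1.0, 0.9, 0.8]  # time-varying reproduction number (illustrative)
I0 = [10.0, 10.0, 10.0]         # initial incidence window, most recent first

# accumulate threads the state through time; each element is the window after that step.
states = accumulate(renewal_step, Rt; init = I0)
incidence = first.(states)      # new infections at each time step
println(incidence)
```

Because the whole time series is produced by one `accumulate` call, this is the kind of call whose gradient cost the benchmarks in this thread are probing.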

@willtebbutt

Sounds good.

I'm keen to help out, so please do ping me if I can be of use.

@SamuelBrand1
Collaborator Author

So long as you have a moderate-to-high tolerance for stupid questions, I'll take you up on that!

@wsmoses

wsmoses commented Sep 12, 2024

Also could be fun to try Enzyme.jl at the same time.

In Turing code it generally gives an extra order-of-magnitude speedup over Tapir, and it is increasingly being adopted by big Julia packages as the new default AD.

@yebai

yebai commented Sep 12, 2024

> it generally sees an extra order of magnitude over Tapir

In my experience, the performance difference between Tapir and Enzyme seems relatively small for Turing models with non-trivial computation. @willtebbutt did an excellent job capitalising on the recent improvements in Julia's compiler API.

@wsmoses

wsmoses commented Sep 12, 2024

sounds like another reason to run more benchmarks then :)

@SamuelBrand1
Collaborator Author

> Also could be fun to try Enzyme.jl at the same time.
>
> In Turing code it generally sees an extra order of magnitude over Tapir and also is increasingly getting adopted by big Julia packages as the new default AD.

Somewhere on my hard drive I've got a first-pass script for a simple renewal epi model aimed at working with Enzyme (based on the code in the Box model), but my day-to-day has been a bit intense.

@seabbs
Collaborator

seabbs commented Dec 13, 2024

I looked at adding this to our benchmarks in #540 but see failures everywhere (I assume my integration is at fault).

@SamuelBrand1
Collaborator Author

It would be nice to see a before-and-after comparison for Mooncake, given that improvements to its Turing performance are being committed (cf. TuringLang/Turing.jl#2418).

@seabbs changed the title from "Test Tapir AD for EpiAware models" to "Test Mooncake AD for EpiAware models" on Dec 16, 2024
@seabbs
Collaborator

seabbs commented Dec 16, 2024

Sorry @SamuelBrand1 what do you mean?

@wsmoses

wsmoses commented Dec 16, 2024

Hm, mind opening an issue for whatever the errors are (would be happy to help you resolve them).

We've seen some fairly significant speedups of Enzyme over Mooncake for probprog, for example here: https://nsiccha.github.io/StanBlocks.jl/performance.html#runtime-overview (though I know Will has been working to fix some of Mooncake's more egregious slowdowns there).

@SamuelBrand1
Collaborator Author

> Sorry @SamuelBrand1 what do you mean?

I meant that @willtebbutt has made some improvements to using Mooncake with Turing that we don't have yet, so I thought it might be nice to track the change as we transition to newer versions of DynamicPPL (DPPL) and Turing, but it's pretty low priority.

@seabbs
Collaborator

seabbs commented Dec 16, 2024

> thought it might be nice to track the change as we transition to newer version of DPPL and Turing... but its pretty low priority.

Currently it's completely broken in the benchmarks (as is Enzyme; see #540), which I expect is user error, so it's hard to track. I agree with the idea, both for tracking and for filing upstream bug reports, but for that we need it not to fail silently, which is what my great implementation does right now.

@seabbs
Collaborator

seabbs commented Dec 16, 2024

@wsmoses thanks! I think our current benchmarks and case studies are quite entangled with the approach we are taking. My hope was that we could pull some reprexes of the issues out of them and flag them upstream, but we're not quite there yet.

@willtebbutt

> I meant that @willtebbutt has made some improvements to using Mooncake with Turing which we don't have yet, so I thought it might be nice to track the change as we transition to newer version of DPPL and Turing... but its pretty low priority.

I think these should all be available on the latest versions of DPPL and Mooncake now. I would be surprised if you were to see a big change in performance for the example in this issue though, because the main performance bug was in the interfacing between Turing + DifferentiationInterface + Mooncake, rather than anything internal in Mooncake (I just re-ran locally and saw largely the same results).
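To compare backends on the same function through one interface, something like the following sketch could be a starting point. It assumes DifferentiationInterface.jl, Mooncake.jl and ReverseDiff.jl are installed, and that DifferentiationInterface re-exports the ADTypes.jl backend constructors `AutoReverseDiff` and `AutoMooncake`; the unprepared `gradient(f, backend, x)` call is used to avoid depending on the preparation API, which has changed across versions.

```julia
# Assumes DifferentiationInterface.jl, Mooncake.jl and ReverseDiff.jl are installed.
using DifferentiationInterface   # re-exports ADTypes backend types such as AutoMooncake
import Mooncake, ReverseDiff     # loading these activates DI's package extensions

# Same toy objective as the accumulate demo earlier in the thread.
f(x) = sum(identity, accumulate(+, x))
x = randn(1_000)

backends = (
    reverse_diff = AutoReverseDiff(; compile = true),
    mooncake = AutoMooncake(; config = nothing),
)

# Compute the gradient with each backend through the same interface.
grads = map(b -> gradient(f, b, x), backends)

# d/dx[j] of sum(cumsum(x)) is n - j + 1, so every backend should agree on this.
expected = [Float64(length(x) - j + 1) for j in eachindex(x)]
for (name, g) in pairs(grads)
    println(name, ": matches analytic gradient = ", g ≈ expected)
end
```

Timing each entry with BenchmarkTools' `@benchmark` (interpolating `$x`) would then give a like-for-like comparison that does not go through Turing at all.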

> (though I know Will has been working to fix some of mooncakes more egregious slowdowns there)

Yeah, these should now be fixed. I've asked the StanBlocks maintainer to re-run the benchmarks when the opportunity arises.


5 participants