Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change scan so that compiler specialisation will be active #83

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ using Distributions,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, spread_draws
export create_discrete_pmf, spread_draws, scan

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel
seabbs marked this conversation as resolved.
Show resolved Hide resolved

# Exported Turing model constructors
export make_epi_inference_model
Expand Down
33 changes: 21 additions & 12 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@

"""
scan(f, init, xs)
scan(f::F, init, xs) where {F <: AbstractEpiModel}

Apply `f` to each element of `xs` and accumulate the results.

`f` must be a [callable](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects)
on a sub-type of `AbstractEpiModel`.

### Design note
`scan` is being restricted to `AbstractEpiModel` sub-types to ensure:
1. That compiler specialization is [activated](https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing)
2. Also avoids potential compiler [overhead](https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues)
from specialisation on `f<: Function`.


Apply a function `f` to each element of `xs` along with an accumulator hidden state with intial
seabbs marked this conversation as resolved.
Show resolved Hide resolved
value `init`. The function `f` takes the current accumulator value and the current element of `xs` as
arguments, and returns a new accumulator value and a result value. The function `scan` returns a tuple
`(ys, carry)`, where `ys` is an array containing the result values and `carry` is the final accumulator
value. This is similar to the JAX function `jax.lax.scan`.

# Arguments
- `f`: A function that takes an accumulator value and an element of `xs` as arguments and returns a new
hidden state.
- `init`: The initial accumulator value.
- `f`: A callable/functor that takes two arguments, `carry` and `x`, and returns a new
`carry` and a result `y`.
- `init`: The initial value for the `carry` variable.
- `xs`: An iterable collection of elements.

# Returns
- `ys`: An array containing the result values of applying `f` to each element of `xs`.
- `carry`: The final accumulator value.
- `ys`: An array containing the results of applying `f` to each element of `xs`.
- `carry`: The final value of the `carry` variable after processing all elements of `xs`.

"""
function scan(f, init, xs::Vector{T}) where {T <: Union{Integer, AbstractFloat}}
function scan(f::F, init, xs) where {F <: AbstractEpiModel}
carry = init
ys = similar(xs)
for (i, x) in enumerate(xs)
Expand Down
28 changes: 26 additions & 2 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
xs = [1, 2, 3, 4, 5]
expected_ys = [1, 3, 6, 10, 15]
expected_carry = 15
ys, carry = EpiAware.scan(add, 0, xs)

# Check that a generic function CAN'T be used
@test_throws MethodError EpiAware.scan(add, 0, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelAdd <: AbstractEpiModel
end
function (epimodel::TestEpiModelAdd)(a, b)
return a + b, a + b
end

ys, carry = EpiAware.scan(TestEpiModelAdd(), 0, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand All @@ -22,7 +34,19 @@ end
expected_ys = [1, 2, 6, 24, 120]
expected_carry = 120

ys, carry = EpiAware.scan(multiply, 1, xs)
# Check that a generic function CAN'T be used
@test_throws MethodError ys, carry=EpiAware.scan(multiply, 1, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelMult <: AbstractEpiModel
end

function (epimodel::TestEpiModelMult)(a, b)
return a * b, a * b
end

ys, carry = EpiAware.scan(TestEpiModelMult(), 1, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand Down
Loading