Skip to content

Commit

Permalink
Merge pull request #83 from CDCgov/82-activate-specialisation-for-fun…
Browse files Browse the repository at this point in the history
…ctions-with-function-input

Change `scan` so that compiler specialisation will be active
  • Loading branch information
seabbs authored Feb 28, 2024
2 parents 8b9edf9 + d676612 commit d9c9222
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ using Distributions,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, spread_draws
export create_discrete_pmf, spread_draws, scan

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel

# Exported Turing model constructors
export make_epi_inference_model
Expand Down
33 changes: 21 additions & 12 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@

"""
scan(f, init, xs)
scan(f::F, init, xs) where {F <: AbstractEpiModel}
Apply `f` to each element of `xs` and accumulate the results.
`f` must be a [callable](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects)
on a sub-type of `AbstractEpiModel`.
### Design note
`scan` is being restricted to `AbstractEpiModel` sub-types to ensure:
1. That compiler specialization is [activated](https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing)
2. Also avoids potential compiler [overhead](https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues)
from specialisation on `f<: Function`.
Apply a function `f` to each element of `xs` along with an accumulator hidden state with intial
value `init`. The function `f` takes the current accumulator value and the current element of `xs` as
arguments, and returns a new accumulator value and a result value. The function `scan` returns a tuple
`(ys, carry)`, where `ys` is an array containing the result values and `carry` is the final accumulator
value. This is similar to the JAX function `jax.lax.scan`.
# Arguments
- `f`: A function that takes an accumulator value and an element of `xs` as arguments and returns a new
hidden state.
- `init`: The initial accumulator value.
- `f`: A callable/functor that takes two arguments, `carry` and `x`, and returns a new
`carry` and a result `y`.
- `init`: The initial value for the `carry` variable.
- `xs`: An iterable collection of elements.
# Returns
- `ys`: An array containing the result values of applying `f` to each element of `xs`.
- `carry`: The final accumulator value.
- `ys`: An array containing the results of applying `f` to each element of `xs`.
- `carry`: The final value of the `carry` variable after processing all elements of `xs`.
"""
function scan(f, init, xs::Vector{T}) where {T <: Union{Integer, AbstractFloat}}
function scan(f::F, init, xs) where {F <: AbstractEpiModel}
carry = init
ys = similar(xs)
for (i, x) in enumerate(xs)
Expand Down
28 changes: 26 additions & 2 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
xs = [1, 2, 3, 4, 5]
expected_ys = [1, 3, 6, 10, 15]
expected_carry = 15
ys, carry = EpiAware.scan(add, 0, xs)

# Check that a generic function CAN'T be used
@test_throws MethodError EpiAware.scan(add, 0, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelAdd <: AbstractEpiModel
end
function (epimodel::TestEpiModelAdd)(a, b)
return a + b, a + b
end

ys, carry = EpiAware.scan(TestEpiModelAdd(), 0, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand All @@ -22,7 +34,19 @@ end
expected_ys = [1, 2, 6, 24, 120]
expected_carry = 120

ys, carry = EpiAware.scan(multiply, 1, xs)
# Check that a generic function CAN'T be used
@test_throws MethodError ys, carry=EpiAware.scan(multiply, 1, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelMult <: AbstractEpiModel
end

function (epimodel::TestEpiModelMult)(a, b)
return a * b, a * b
end

ys, carry = EpiAware.scan(TestEpiModelMult(), 1, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand Down

0 comments on commit d9c9222

Please sign in to comment.