Support for submodels (#233)
Part of the motivation for #221 and #222 was to let us add submodels/model-nesting.

Well, now we can. 

Special thanks to @devmotion, who reviewed those PRs (several times), improved them significantly, made additional PRs, and suggested the current implementation of `@submodel` ❤️

EDIT: We fixed the performance :) This now has zero runtime overhead! See the comments section.
EDIT 2: Thanks to @devmotion, we can now also deal with dynamically specified prefixes!

- [Motivating example: AR1-prior](#org46a90a5)
- [Demos](#org7e05701)
- [Can it ever fail?](#org75acb71)
- [Benchmarks](#orga99bcf4)
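
Before diving in, here's the gist of the two forms of the macro, shown in detail in the demos below. This is just a minimal sketch; `inner` and `outer` are made-up names for illustration:

```julia
using Turing, DynamicPPL

@model function inner(x)
    x ~ Normal()
    return x
end

@model function outer(x, y)
    # Without a prefix, the submodel's variable keeps its name `x`.
    @submodel inner(x)
    # With a prefix, the variable is recorded as `sub.x`,
    # and the submodel's return value can be used directly.
    x2 = @submodel sub inner(y)
    z ~ Normal(x2, 1.0)
end

keys(VarInfo(outer(missing, missing)))  # expect: x, sub.x, z
```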


<a id="org46a90a5"></a>

# Motivating example: AR1-prior

```julia
using Turing
using DynamicPPL
```

```julia
# We could have made a model that samples all `num_obs` AR1 series simultaneously,
# but for the sake of showing off dynamic prefixes we'll use a vector implementation.
# (A matrix implementation would also be quite a bit faster, but oh well.)
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
    η ~ MvNormal(num_steps, 1.0)
    δ = sqrt(1 - α^2)

    x = TV(undef, num_steps)
    x[1] = η[1]
    @inbounds for t = 2:num_steps
        x[t] = @. α * x[t - 1] + δ * η[t]
    end

    return @. μ + σ * x
end

# Generate some observations
σ_obs = 0.1
num_obs = 5
num_steps = 10

ar1 = AR1(num_steps, 0.5, 1.0, 1.0)
ys = mapreduce(hcat, 1:num_obs) do i
    ar1() + σ_obs * randn(num_steps)
end
```

    10×5 Matrix{Float64}:
      2.30189    0.301618  1.73268   -0.65096    1.46835
      2.11187   -1.34878   2.3728     1.02125    3.28422
     -0.249064   0.769488  1.34044    3.22175    2.52196
     -0.25863   -0.216914  0.528954   3.04756    3.8234
      0.372122   0.473511  0.708068   0.76197    0.202003
      0.41487    0.759435  1.80162    0.790204   0.12331
      1.32585    0.567929  2.74316    1.0874     2.82701
      1.84307    1.16138   1.36382    0.735388   1.07423
      3.20139    0.75177   1.57236    0.865401  -0.315341
      1.22479    1.35688   2.8239     0.597959   0.587955

```julia
@model function demo(y)
    α ~ Uniform()
    μ ~ Normal()
    σ ~ truncated(Normal(), 0, Inf)

    num_steps = size(y, 1)
    num_obs = size(y, 2)
    @inbounds for i = 1:num_obs
        x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
        y[:, i] ~ MvNormal(x, 0.1)
    end
end;

m = demo(ys);
vi = VarInfo(m);
```

```julia
keys(vi)
```

    8-element Vector{VarName{sym, Tuple{}} where sym}:
     α
     μ
     σ
     ar1_1.η
     ar1_2.η
     ar1_3.η
     ar1_4.η
     ar1_5.η

```julia
vi[@varname α]
```

    0.9383208224122919

```julia
chain = sample(m, NUTS(1_000, 0.8), 3_000);
```

    ┌ Info: Found initial step size
    │   ϵ = 0.025
    └ @ Turing.Inference /home/tor/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:188
    Sampling: 100%|█████████████████████████████████████████| Time: 0:04:00

```julia
chain[1001:end, [:α, :μ, :σ], :]
```

    Chains MCMC chain (2000×3×1 Array{Float64, 3}):
    
    Iterations        = 1001:3000
    Thinning interval = 1
    Chains            = 1
    Samples per chain = 2000
    parameters        = α, μ, σ
    internals         = 
    
    Summary Statistics
      parameters      mean       std   naive_se      mcse        ess      rhat 
          Symbol   Float64   Float64    Float64   Float64    Float64   Float64 
    
               α    0.5474    0.1334     0.0030    0.0073   159.6969    0.9995
               μ    1.0039    0.2733     0.0061    0.0168   169.9106    1.0134
               σ    1.1294    0.1807     0.0040    0.0106   166.8670    0.9998
    
    Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
          Symbol   Float64   Float64   Float64   Float64   Float64 
    
               α    0.2684    0.4625    0.5534    0.6445    0.7861
               μ    0.4248    0.8227    1.0241    1.2011    1.4801
               σ    0.8781    1.0018    1.0989    1.2239    1.5472

Yay! We recovered the true parameters :tada:

```julia
@benchmark $m($vi)
```

    BenchmarkTools.Trial: 
      memory estimate:  12.05 KiB
      allocs estimate:  123
      --------------
      minimum time:     15.091 μs (0.00% GC)
      median time:      17.861 μs (0.00% GC)
      mean time:        19.582 μs (5.23% GC)
      maximum time:     10.293 ms (99.46% GC)
      --------------
      samples:          10000
      evals/sample:     1


<a id="org7e05701"></a>

# Demos

```julia
using DynamicPPL, Distributions
```

    ┌ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8]
    └ @ Base loading.jl:1317

```julia
@model function demo1(x)
    x ~ Normal()
end;
@model function demo2(x, y)
    @submodel demo1(x)
    y ~ Uniform()
end false;
m2 = demo2(missing, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

    2-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y

```julia
println(vi2[VarName(Symbol("x"))])
println(vi2[VarName(Symbol("y"))])
```

    0.3069117531180063
    0.7325324947386318

We can also `observe` without issues:

```julia
@model function demo2(x, y)
    @submodel demo1(x)
    y ~ Normal(x)
end false;
m2 = demo2(1000.0, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

    1-element Vector{VarName{:y, Tuple{}}}:
     y

```julia
vi2[@varname y]
```

    1000.3905079427211

```julia
DynamicPPL.getlogp(vi2)
```

    -500001.9141252931

But what if the models have the same variable names?! (We'll get to that in [Can it ever fail?](#org75acb71) below.)

"Sure, this is cool and all, but can we even use the values from the nested models in the parent model?"

```julia
@model function demo_return(x)
    x ~ Normal()
    return x
end;

@model function demo_useval(x, y)
    x1 = @submodel sub1 demo_return(x)
    x2 = @submodel sub2 demo_return(y)

    z ~ Normal(x1 + x2 + 100, 1.0)
end false;
vi = VarInfo(demo_useval(missing, missing));
keys(vi)
```

    3-element Vector{VarName{sym, Tuple{}} where sym}:
     sub1.x
     sub2.x
     z

```julia
vi[@varname z]
```

    101.09066854862154

And just to prove a point:

```julia
@model function nested(x, y)
    @submodel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.2.3.x
     1.2.3.y
     1.2.y
     1.y
     y

```julia
vi[VarName(Symbol("1.2.y"))]
```

    1000.5609156083766

```julia
DynamicPPL.getlogp(vi)
```

    -4.620040828101227


<a id="org75acb71"></a>

# Can it ever fail?

Yes: if the user doesn't provide a prefix, variable names from different submodels can collide:

```julia
@model function nested(x, y)
    @submodel nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

    2-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y

```julia
# Innermost value is recorded (i.e. the first one reached)
vi[@varname y]
```

    -100.16836599596732

And it messes up the logp computation:

```julia
DynamicPPL.getlogp(vi)
```

    -Inf

But I could imagine there's a way for us to fix this, or at least warn the user when this happens.
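
For instance, something along these lines could work. To be clear, `warn_on_collision` is a hypothetical helper, not anything in DynamicPPL, and where exactly it would hook into `assume` is an open question:

```julia
# Hypothetical sketch: before an `assume`, warn if the (possibly
# un-prefixed) variable name is already present in the VarInfo.
# Neither this function nor the hook point exists in DynamicPPL.
function warn_on_collision(vi, vn)
    if haskey(vi, vn)
        @warn "`$vn` already exists in the VarInfo; did you forget a prefix in `@submodel`?"
    end
    return vn
end
```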


<a id="orga99bcf4"></a>

# Benchmarks

At this point you're probably wondering: "but does it have any runtime overhead?". For shallow nestings, nah; and as per the EDIT above, after the performance fixes even deep nestings show essentially none (an earlier version paid a tiny bit, likely because we're calling the "constructor" for the model):

```julia
using BenchmarkTools

@model function base(x, y)
    x ~ Uniform()
    y ~ Uniform()
    y1 ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end

m1 = base(missing, missing);
vi1 = VarInfo(m1);
```

```julia
@model function nested_shallow(x, y)
    @submodel 1 nested1_shallow(x, y)
    y ~ Uniform()
end false;
@model function nested1_shallow(x, y)
    x ~ Uniform()
    y ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end false;

m2 = nested_shallow(missing, missing);
vi2 = VarInfo(m2);
```

```julia
@model function nested(x, y)
    @submodel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end

m3 = nested(missing, missing);
vi3 = VarInfo(m3);
```

```julia
@model function nested_noprefix(x, y)
    @submodel nested_noprefix1(x, y)
    y ~ Uniform()
end false;
@model function nested_noprefix1(x, y)
    @submodel nested_noprefix2(x, y)
    y1 ~ Uniform()
end false;
@model function nested_noprefix2(x, y)
    z = @submodel nested_noprefix3(x, y)
    y2 ~ Normal(z, 1.0)
end false;
@model function nested_noprefix3(x, y)
    x ~ Uniform()
    y3 ~ Normal(-100.0, 1.0)

    return x + 1000
end

m4 = nested_noprefix(missing, missing);
vi4 = VarInfo(m4);
```

```julia
keys(vi1)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y
     y1
     y12
     y123

```julia
keys(vi2)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.x
     1.y
     1.y12
     1.y123
     y

```julia
keys(vi3)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.2.3.x
     1.2.3.y
     1.2.y
     1.y
     y

```julia
keys(vi4)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y3
     y2
     y1
     y

```julia
@benchmark $m1($vi1)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.714 μs (0.00% GC)
      median time:      1.747 μs (0.00% GC)
      mean time:        1.835 μs (0.00% GC)
      maximum time:     6.894 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m2($vi2)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.759 μs (0.00% GC)
      median time:      1.778 μs (0.00% GC)
      mean time:        1.819 μs (0.00% GC)
      maximum time:     5.563 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m3($vi3)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.718 μs (0.00% GC)
      median time:      1.746 μs (0.00% GC)
      mean time:        1.787 μs (0.00% GC)
      maximum time:     5.758 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m4($vi4)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.672 μs (0.00% GC)
      median time:      1.696 μs (0.00% GC)
      mean time:        1.756 μs (0.00% GC)
      maximum time:     4.882 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

Notice that the timings and allocation counts are now essentially identical across all four models, i.e. the nesting adds no runtime overhead. (An earlier iteration of this PR did show increased allocations for the deeply nested model: the Julia compiler seemed to struggle to infer the return types of Turing models, which the lowered code also suggested. As per the EDIT above, that has since been fixed.)
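
If you want to double-check the inference story yourself, here's a quick sketch (not part of this PR) using `Test.@inferred`:

```julia
using Test

# If the return type of the deeply nested model evaluation cannot be
# inferred, `@inferred` throws; extra allocations would be one symptom.
@inferred m3(vi3)

# For a line-by-line view of where inference degrades:
# @code_warntype m3(vi3)
```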
torfjelde committed May 18, 2021
1 parent f7531ba commit 0f7548d
Showing 7 changed files with 164 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.10.19"
+version = "0.10.20"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
6 changes: 5 additions & 1 deletion src/DynamicPPL.jl
@@ -79,6 +79,7 @@ export AbstractVarInfo,
     LikelihoodContext,
     PriorContext,
     MiniBatchContext,
+    PrefixContext,
     assume,
     dot_assume,
     observe,
@@ -96,7 +97,9 @@ export AbstractVarInfo,
     logjoint,
     pointwise_loglikelihoods,
     # Convenience macros
-    @addlogprob!
+    @addlogprob!,
+    @submodel
+

 # Reexport
 using Distributions: loglikelihood
@@ -124,5 +127,6 @@ include("compiler.jl")
 include("prob_macro.jl")
 include("compat/ad.jl")
 include("loglikelihoods.jl")
+include("submodel_macro.jl")

 end # module
4 changes: 2 additions & 2 deletions src/compiler.jl
@@ -54,7 +54,7 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
 #################

 """
-    @model(expr[, warn = true])
+    @model(expr[, warn = false])
 Macro to specify a probabilistic model.
@@ -73,7 +73,7 @@ end
 To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
 """
-macro model(expr, warn=true)
+macro model(expr, warn=false)
     # include `LineNumberNode` with information about the call site in the
     # generated function for easier debugging and interpretation of error messages
     esc(model(__module__, __source__, expr, warn))
6 changes: 6 additions & 0 deletions src/context_implementations.jl
@@ -39,6 +39,9 @@ end
 function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
     return tilde(rng, ctx.ctx, sampler, right, left, inds, vi)
 end
+function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi)
+    return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi)
+end

 """
     tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
@@ -75,6 +78,9 @@ end
 function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
     return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
 end
+function tilde(ctx::PrefixContext, sampler, right, left, vi)
+    return tilde(ctx.ctx, sampler, right, left, vi)
+end

 """
     tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
26 changes: 26 additions & 0 deletions src/contexts.jl
@@ -52,3 +52,29 @@ end
 function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints)
     return MiniBatchContext(ctx, npoints/batch_size)
 end
+
+
+struct PrefixContext{Prefix, C} <: AbstractContext
+    ctx::C
+end
+PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} = PrefixContext{Prefix, typeof(ctx)}(ctx)
+
+const PREFIX_SEPARATOR = Symbol(".")
+
+function PrefixContext{PrefixInner}(
+    ctx::PrefixContext{PrefixOuter}
+) where {PrefixInner, PrefixOuter}
+    if @generated
+        :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}(ctx.ctx))
+    else
+        PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx)
+    end
+end
+
+function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix, Sym}
+    if @generated
+        return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing))
+    else
+        VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
+    end
+end
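
To make the composition concrete, here's a small usage sketch (not part of the diff) of how nested `PrefixContext`s collapse and how `prefix` rewrites a `VarName`; note that `prefix` itself is not exported, hence the qualified name:

```julia
using DynamicPPL

ctx  = PrefixContext{:a}(DefaultContext())
ctx2 = PrefixContext{:b}(ctx)   # collapses to PrefixContext{Symbol("a.b")}

vn = VarName(:x)
DynamicPPL.prefix(ctx2, vn)     # VarName for `a.b.x`
```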
23 changes: 23 additions & 0 deletions src/submodel_macro.jl
@@ -0,0 +1,23 @@
+macro submodel(expr)
+    return quote
+        _evaluate(
+            $(esc(:__rng__)),
+            $(esc(expr)),
+            $(esc(:__varinfo__)),
+            $(esc(:__sampler__)),
+            $(esc(:__context__))
+        )
+    end
+end
+
+macro submodel(prefix, expr)
+    return quote
+        _evaluate(
+            $(esc(:__rng__)),
+            $(esc(expr)),
+            $(esc(:__varinfo__)),
+            $(esc(:__sampler__)),
+            PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))
+        )
+    end
+end
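
So inside an `@model` body, `x1 = @submodel sub1 demo_return(x)` expands to (roughly) the following, where `__rng__`, `__varinfo__`, `__sampler__`, and `__context__` are the internal variables set up by the `@model` macro:

```julia
# Rough sketch of the expansion; the real macro output goes through `esc`.
x1 = _evaluate(
    __rng__,
    demo_return(x),
    __varinfo__,
    __sampler__,
    PrefixContext{:sub1}(__context__),
)
```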
101 changes: 101 additions & 0 deletions test/compiler.jl
@@ -314,6 +314,107 @@ end
     @test demo2()() == 42
 end

+@testset "submodel" begin
+    # No prefix, 1 level.
+    @model function demo1(x)
+        x ~ Normal()
+    end;
+    @model function demo2(x, y)
+        @submodel demo1(x)
+        y ~ Uniform()
+    end;
+    # No observation.
+    m = demo2(missing, missing);
+    vi = VarInfo(m);
+    ks = keys(vi)
+    @test VarName(:x) ∈ ks
+    @test VarName(:y) ∈ ks
+
+    # Observation in top-level.
+    m = demo2(missing, 1.0);
+    vi = VarInfo(m);
+    ks = keys(vi)
+    @test VarName(:x) ∈ ks
+    @test VarName(:y) ∉ ks
+
+    # Observation in nested model.
+    m = demo2(1000.0, missing);
+    vi = VarInfo(m);
+    ks = keys(vi)
+    @test VarName(:x) ∉ ks
+    @test VarName(:y) ∈ ks
+
+    # Observe all.
+    m = demo2(1000.0, 0.5);
+    vi = VarInfo(m);
+    ks = keys(vi)
+    @test isempty(ks)
+
+    # Check values make sense.
+    @model function demo2(x, y)
+        @submodel demo1(x)
+        y ~ Normal(x)
+    end;
+    m = demo2(1000.0, missing);
+    # Mean of `y` should be close to 1000.
+    @test abs(mean([VarInfo(m)[VarName(:y)] for i = 1:10]) - 1000) ≤ 10;
+
+    # Prefixed submodels and usage of submodel return values.
+    @model function demo_return(x)
+        x ~ Normal()
+        return x
+    end;
+
+    @model function demo_useval(x, y)
+        x1 = @submodel sub1 demo_return(x)
+        x2 = @submodel sub2 demo_return(y)
+
+        z ~ Normal(x1 + x2 + 100, 1.0)
+    end;
+    m = demo_useval(missing, missing)
+    vi = VarInfo(m);
+    ks = keys(vi)
+    @test VarName(Symbol("sub1.x")) ∈ ks
+    @test VarName(Symbol("sub2.x")) ∈ ks
+    @test VarName(:z) ∈ ks
+    @test abs(mean([VarInfo(m)[VarName(:z)] for i = 1:10]) - 100) ≤ 10
+
+    # AR1 model. Dynamic prefixing.
+    @model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
+        η ~ MvNormal(num_steps, 1.0)
+        δ = sqrt(1 - α^2)
+
+        x = TV(undef, num_steps)
+        x[1] = η[1]
+        @inbounds for t = 2:num_steps
+            x[t] = @. α * x[t - 1] + δ * η[t]
+        end
+
+        return @. μ + σ * x
+    end
+
+    @model function demo(y)
+        α ~ Uniform()
+        μ ~ Normal()
+        σ ~ truncated(Normal(), 0, Inf)
+
+        num_steps = length(y[1])
+        num_obs = length(y)
+        @inbounds for i = 1:num_obs
+            x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
+            y[i] ~ MvNormal(x, 0.1)
+        end
+    end;
+
+    ys = [randn(10), randn(10)];
+    m = demo(ys);
+    vi = VarInfo(m);
+
+    for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")]
+        @test VarName(k) ∈ keys(vi)
+    end
+end
+
 @testset "check_tilde_rhs" begin
     @test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())

2 comments on commit 0f7548d

@torfjelde
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/36976

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.20 -m "<description of version>" 0f7548df6caf6fa10d802c84e05b8c161b0d9cea
git push origin v0.10.20
