diff --git a/Project.toml b/Project.toml index 4f4a5ecaa..4221d64fe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.1" +version = "0.24.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 67c2f3fcb..fe2c0b3e5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -107,12 +107,20 @@ Set the log of the joint probability of the observed data and parameters sampled function setlogp!! end """ - acclogp!!(vi::AbstractVarInfo, logp) + acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!! end +function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(NodeTrait(context), context, vi, logp) +end +function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(vi, logp) +end +function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(childcontext(context), vi, logp) +end """ resetlogp!!(vi::AbstractVarInfo) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1f0641007..494fb0e47 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -115,7 +115,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp!!(vi, logp) + return value, acclogp!!(context, vi, logp) end # observe @@ -181,7 +181,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp!!(vi, logp) + return left, acclogp!!(context, vi, logp) end function assume(rng, spl::Sampler, dist) @@ -383,7 +383,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp!!(vi, logp), vi + return value, acclogp!!(context, vi, logp), vi end # `dot_assume` @@ -634,7 +634,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp!!(vi, logp) + return left, acclogp!!(context, vi, logp) end # Falls back to non-sampler definition. diff --git a/src/utils.jl b/src/utils.jl index ca068b1dc..ae79a7792 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -70,7 +70,9 @@ true """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) + ) end end