From 34be85c6371cf4d0fe9f9db5c74a3fa7527452f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 20 Nov 2023 23:59:30 +0000 Subject: [PATCH] Also pass in `context` as an argument to `acclogp!!` (#563) * also pass in `context` as an argument to `acclogp!!` * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * bump patch version --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- src/abstract_varinfo.jl | 12 ++++++++++-- src/context_implementations.jl | 8 ++++---- src/utils.jl | 4 +++- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 4f4a5ecaa..4221d64fe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.1" +version = "0.24.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 67c2f3fcb..fe2c0b3e5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -107,12 +107,20 @@ Set the log of the joint probability of the observed data and parameters sampled function setlogp!! end """ - acclogp!!(vi::AbstractVarInfo, logp) + acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!! end +function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(NodeTrait(context), context, vi, logp) +end +function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(vi, logp) +end +function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(childcontext(context), vi, logp) +end """ resetlogp!!(vi::AbstractVarInfo) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1f0641007..494fb0e47 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -115,7 +115,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp!!(vi, logp) + return value, acclogp!!(context, vi, logp) end # observe @@ -181,7 +181,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp!!(vi, logp) + return left, acclogp!!(context, vi, logp) end function assume(rng, spl::Sampler, dist) @@ -383,7 +383,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp!!(vi, logp), vi + return value, acclogp!!(context, vi, logp), vi end # `dot_assume` @@ -634,7 +634,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp!!(vi, logp) + return left, acclogp!!(context, vi, logp) end # Falls back to non-sampler definition. diff --git a/src/utils.jl b/src/utils.jl index ca068b1dc..ae79a7792 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -70,7 +70,9 @@ true """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) + ) end end