diff --git a/Project.toml b/Project.toml index a81112d74..78acb2566 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.27.2" +version = "0.28" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 007dfef11..8935edc12 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C} varinfo::V "model used for evaluation" model::M - "context used for evaluation" + "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C end @@ -66,15 +66,20 @@ end function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=model.context, + context::Union{Nothing,AbstractContext}=nothing, ) return LogDensityFunction(varinfo, model, context) end +# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. +function getcontext(f::LogDensityFunction) + return f.context === nothing ? leafcontext(f.model.context) : f.context +end + # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(f.context) -hassampler(f::LogDensityFunction) = hassampler(f.context) +getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) +hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) _get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) _get_indexer(ctx::SamplingContext) = ctx.sampler @@ -86,12 +91,13 @@ _get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() Return the parameters of the wrapped varinfo as a vector. """ -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)] +getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) - vi_new = unflatten(f.varinfo, f.context, θ) - return getlogp(last(evaluate!!(f.model, vi_new, f.context))) + context = getcontext(f) + vi_new = unflatten(f.varinfo, context, θ) + return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) return LogDensityProblems.LogDensityOrder{0}()