
FixedContext and ConditionedContext don't use the same varnames as tilde-pipeline #702

Open
penelopeysm opened this issue Oct 29, 2024 · 15 comments


@penelopeysm
Member

using DynamicPPL
using Turing

@model function f()
    m = Vector{Float64}(undef, 1)
    m .~ Normal()
    s ~ Normal()
    return (; s=s, m=m)
end
model = f()
chain = sample(model, Prior(), 2)

DynamicPPL.generated_quantities(model, chain)

The call to generated_quantities always returns the values of s in the chain, as expected, but each time generated_quantities is run, a new vector m is generated.

julia> DynamicPPL.generated_quantities(model, chain)
2×1 Matrix{@NamedTuple{s::Float64, m::Vector{Float64}}}:
 (s = 1.632189074395471, m = [0.8057272304035911])
 (s = -0.8755799731449666, m = [-0.6282480660060888])

julia> DynamicPPL.generated_quantities(model, chain)
2×1 Matrix{@NamedTuple{s::Float64, m::Vector{Float64}}}:
 (s = 1.632189074395471, m = [-0.9533966900128452])
 (s = -0.8755799731449666, m = [-1.081695147997058])

julia> DynamicPPL.generated_quantities(model, chain)
2×1 Matrix{@NamedTuple{s::Float64, m::Vector{Float64}}}:
 (s = 1.632189074395471, m = [-0.9870848372495151])
 (s = -0.8755799731449666, m = [1.6715732215128585])
@penelopeysm
Member Author

penelopeysm commented Oct 29, 2024

@penelopeysm
Member Author

penelopeysm commented Oct 29, 2024

It seems that this error arises because here

fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))

we have that

Dict(varname_pairs) = Dict{VarName, Float64}(m[1] => 0.02111050094889288, s => -1.8086057102168398)

As can be seen, the FixedContext here contains the varname m[1].

However, when the fixed model is evaluated, it calls isfixed with the varname of the left-hand side of the tilde, i.e. m:

if $(DynamicPPL.isfixed(left, vn))
    $left = $(DynamicPPL.getfixed_nested)(__context__, $vn)

and thus hasfixed(context, vn) returns false, since vn is m rather than m[1].

hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn)
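A minimal plain-Julia sketch of the mismatch, with no DynamicPPL involved. ToyVarName and toy_hasfixed are hypothetical stand-ins for VarName and hasfixed, and the lookup is simplified to an exact-key haskey. That simplification is enough to reproduce the situation here, where the stored key is m[1] but the query is m.

```julia
# Hypothetical stand-in for DynamicPPL's VarName (NOT the real type):
# a symbol plus an optional integer index, so `m` is ToyVarName(:m)
# and `m[1]` is ToyVarName(:m, 1).
struct ToyVarName
    sym::Symbol
    index::Union{Nothing,Int}
end
ToyVarName(sym::Symbol) = ToyVarName(sym, nothing)

# Hypothetical stand-in for `hasfixed`, simplified to an exact-key lookup.
toy_hasfixed(values::AbstractDict, vn::ToyVarName) = haskey(values, vn)

# The fixed values come from the chain, so they are keyed element-wise.
fixed_values = Dict(
    ToyVarName(:m, 1) => 0.02111050094889288,
    ToyVarName(:s) => -1.8086057102168398,
)

# The generated code for `m .~ Normal()` asks about the whole-array varname `m`.
println(toy_hasfixed(fixed_values, ToyVarName(:s)))     # true
println(toy_hasfixed(fixed_values, ToyVarName(:m)))     # false: only m[1] is stored
println(toy_hasfixed(fixed_values, ToyVarName(:m, 1)))  # true
```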


It seems that the call to DynamicPPL.nested_setindex_maybe! was intended to address this issue:

vn_parents = Iterators.map(vns) do vn
    # The call nested_setindex_maybe! is used to handle cases where vn is not
    # the variable name used in the model, but rather subsumed by one. Except
    # for the subsumption part, this could be
    # vn => getindex_varname(chain, sample_idx, vn, chain_idx)
    # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
    DynamicPPL.nested_setindex_maybe!(
        varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
    )
end
Unfortunately, it doesn't :( In nested_setindex_maybe!, there's this call to getmetadata(vi, vn). Here vn is m[1], because it's been extracted from the chain.

DynamicPPL.jl/src/varinfo.jl

Lines 1608 to 1616 in 18af48a

function nested_setindex_maybe!(
    vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
) where {names,sym}
    return if sym in names
        _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
    else
        nothing
    end
end

and so the inner function _nested_setindex_maybe! receives metadata that only contains m[1] rather than m itself. This means that it returns from the first if-clause (since m[1] is among the keys) and doesn't actually extract the parent varname.

DynamicPPL.jl/src/varinfo.jl

Lines 1617 to 1627 in 18af48a

function _nested_setindex_maybe!(
    vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName
)
    # If `vn` is in `vns`, then we can just use the standard `setindex!`.
    vns = Base.keys(md)
    if vn in vns
        setindex!(vi, val, vn)
        return vn
    end
    # Otherwise, we need to check if either of the `vns` subsumes `vn`.

@mhauru
Member

mhauru commented Oct 30, 2024

I remember finding the nested_setindex_maybe! thing very hard to understand when I was writing/editing that code. I guess this is that complexity coming back to bite us.

I would consider this a bug in FixedContext:

julia> using Distributions

julia> using DynamicPPL

julia> @model function f()
           m = Vector{Float64}(undef, 1)
           m .~ Normal()
           s ~ Normal()
           return (; s=s, m=m)
       end
f (generic function with 2 methods)

julia> model = f()
Model{typeof(f), (), (), (), Tuple{}, Tuple{}, DefaultContext}(f, NamedTuple(), NamedTuple(), DefaultContext())

julia> fix(model, (@varname(m[1]) => 0.1))()
(s = 0.07793110036823493, m = [0.8732131942573598])

@penelopeysm
Member Author

Maybe m needs to be resampled first, and then for any variables vn in the FixedContext, if m subsumes vn, then the value of vn needs to be set to its fixed value.

@penelopeysm
Member Author

penelopeysm commented Oct 30, 2024

ConditionContext has the same issue:

julia> condition(model, (@varname(s) => 1, @varname(m[1]) => 0.1))()
(s = 1, m = [0.199933954170371])

julia> condition(model, (@varname(s) => 1, @varname(m[1]) => 0.1))()
(s = 1, m = [1.5566785471447055])

@torfjelde
Member

Ooooo nice catch 👍 We might need a call to unwrap_right_left_vns before the is-fixed check

@penelopeysm
Copy link
Member Author

@torfjelde, is there a reason why we use isfixed() rather than overloading assume etc. for FixedContext?

@torfjelde
Member

My comment was a bit out-of-sync as I was sending it while traveling.

I remember finding the nested_setindex_maybe! thing very hard to understand when I was writing/editing that code. I guess this is that complexity coming back to bite us.

AFAIK this isn't an issue with nested_setindex_maybe!, but rather with the fact that we're using different vns in the tilde-pipeline than for condition and fix.

is there a reason why we use isfixed() rather than overloading assume etc. for FixedContext?

It's two-fold:

  1. The assume part of the tilde-pipeline has the semantic notion that what we're working with is a random variable. This means that a context overloading, say, tilde_assume expects to be working with something that we consider a random variable. This could result in strange behavior if we're actually working with a fixed variable.
  2. Performance. Unfortunately, in some cases, going through the tilde-pipeline can degrade performance. Handling fixing at the top level of a @model avoids this overhead and should therefore improve performance.

@torfjelde
Member

Ooooo nice catch 👍 We might need a call to unwrap_right_left_vns before the is-fixed check

However, this will have performance implications, unfortunately 😕

Looking at the following lines

return quote
    $vn = $(DynamicPPL.resolve_varnames)(
        $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
    )
    $isassumption = $(DynamicPPL.isassumption(left, vn))
    if $(DynamicPPL.isfixed(left, vn))
        $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
    elseif $isassumption
        $(generate_dot_tilde_assume(left, right, vn))
    else
        # If `vn` is not in `argnames`, we need to make sure that the variable is defined.
        if !$(DynamicPPL.inargnames)($vn, __model__)
            $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
        end
        $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
            __context__,
            $(DynamicPPL.check_tilde_rhs)($right),
            $(maybe_view(left)),
            $vn,
            __varinfo__,
        )
        $value
    end
end

the is_fixed and is_conditioned checks work with vn. In contrast, the tilde-assume pipeline works with the "unwrapped" varnames: for assume statements, the code above calls generate_dot_tilde_assume

function generate_dot_tilde_assume(left, right, vn)
    # We don't need to use `Setfield.@set` here since
    # `.=` is always going to be inplace + needs `left` to
    # be something that supports `.=`.
    @gensym value
    return quote
        $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)(
            __context__,
            $(DynamicPPL.unwrap_right_left_vns)(
                $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
            )...,
            __varinfo__,
        )
        $left .= $value
        $value
    end
end

which, as we see, calls unwrap_right_left_vns (which in turn produces the corresponding array of varnames).

Now, to fix the aforementioned issue, we could make the is_fixed and is_conditioned checks act on the result of unwrap_right_left_vns instead of vn, which would indeed fix this issue.

However, this would potentially result in these checks becoming run-time instead of compile-time 😬

@penelopeysm penelopeysm changed the title generated_quantities resamples variables which are dot_assumed FixedContext and ConditionedContext don't use the same varnames as tilde-pipeline Oct 31, 2024
@penelopeysm
Member Author

Semi-related but there's also this behaviour:

julia> using Distributions, DynamicPPL, Test

julia> @model function f()
           m = Vector{Float64}(undef, 2)
           m .~ Normal()
           s ~ Normal()
           return (; s=s, m=m)
       end
f (generic function with 2 methods)

julia> fix(model, (@varname(m) => [0.1]))()
(s = -0.2513270918863368, m = [0.1, 0.1])

IMO this should error.

@penelopeysm
Member Author

penelopeysm commented Nov 1, 2024

@torfjelde I have been looking at this a little bit, and I don't entirely understand what you mean by unwrap_right_left_vns. Let's take a case like this where part of the array is fixed and part of it isn't (I realise this is a bit diabolical, but I suppose we should attempt to handle the general case):

@model function f()
    m = Vector{Float64}(undef, 2)
    m .~ Normal()
    s ~ Normal()
    return (; s=s, m=m)
end

# The following should return (s = (something), m = [0.1, (something else)])
fix(f(), (@varname(m[1]) => 0.1))()

In such a case, contextual_isfixed(__context__, @varname(m[1])) needs to return true, which we could do with something like this:

_, _, $vns = $(DynamicPPL.unwrap_right_left_vns)($right, $left, $vn)
for $v in $vns
    if $(DynamicPPL.isfixed)($left, $v)
        # we don't actually want $left here, we want only the bit of $left that corresponds to $v
        # (but I don't know how to do that)
        $left = $(DynamicPPL.getfixed_nested)(__context__, $v)
    else
        # normal tilde pipeline, which is where `m[2]` should go to
    end
end

I believe this would miss the opposite case where we actually fix the whole array, though:

fix(f(), (@varname(m) => [0.1, 0.2]))()

because now isfixed is being run against the 'split up' varnames m[1] and m[2] rather than m, which is in the FixedContext. Maybe we could put in a call to subsumes somewhere here?
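One way to phrase the subsumes idea: when iterating over the split-up element varnames, ask whether any fixed key subsumes the element (a key m subsumes m[2]) instead of whether a key equals it. Below is a plain-Julia sketch under loudly stated assumptions: ToyVarName, toy_subsumes and toy_getfixed_element are hypothetical, not DynamicPPL API, and varnames are reduced to a symbol plus at most one integer index. It covers both cases from this comment.

```julia
# Hypothetical toy varname (NOT DynamicPPL's VarName): symbol + optional index.
struct ToyVarName
    sym::Symbol
    index::Union{Nothing,Int}
end
ToyVarName(sym::Symbol) = ToyVarName(sym, nothing)

# `parent` subsumes `child` if both refer to the same symbol and `parent` is
# either the whole variable or the very same element.
toy_subsumes(parent::ToyVarName, child::ToyVarName) =
    parent.sym == child.sym && (parent.index === nothing || parent.index == child.index)

# Look up the fixed value for one element varname such as `m[2]`.
# Returns `nothing` when no fixed key covers it, i.e. the element would
# go down the normal tilde pipeline instead.
function toy_getfixed_element(fixed::AbstractDict, elem::ToyVarName)
    for (key, val) in fixed
        toy_subsumes(key, elem) || continue
        # A whole-array key stores the whole array, so pick out this element.
        return key.index === nothing ? val[elem.index] : val
    end
    return nothing
end

elements = [ToyVarName(:m, 1), ToyVarName(:m, 2)]

# Case 1: only m[1] is fixed, so m[1] gives 0.1 and m[2] gives nothing.
partly_fixed = Dict{ToyVarName,Any}(ToyVarName(:m, 1) => 0.1)
println([toy_getfixed_element(partly_fixed, e) for e in elements])

# Case 2: the whole of m is fixed, so the elements give 0.1 and 0.2.
fully_fixed = Dict{ToyVarName,Any}(ToyVarName(:m) => [0.1, 0.2])
println([toy_getfixed_element(fully_fixed, e) for e in elements])
```

Note that the subsumption check runs in the direction "fixed key subsumes element", which is exactly what lets the whole-array key m match the split-up m[1] and m[2].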

It seems to me that we can use unwrap_right_left_vns in the later generate_dot_tilde_assume function, because if we reach that point we already know that none of the constituent vns are fixed, so all of them should go to the tilde pipeline. But at this earlier stage, we need to unwrap and then getfixed some of them and assume the others, which I don't quite know how to handle nicely.

Maybe I'm approaching this the wrong way? This is the first time I'm digging into the compiler, after all. Let me know 😄

I would have thought that the simplest solution would be to overload the tilde pipeline for FixedContext and ConditionContext. That way we can still send all the vns down the tilde pipeline but they will come back with their appropriate values. I didn't yet try to do this, so haven't quite realised what weird behaviour I will find if I attempted to do that (or indeed if it's possible :))

@torfjelde
Member

IMO this should error.

I don't actually mind this not erroring since it's just following the .= behavior that is standard in Julia 🤷
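For reference, this is the plain-Julia .= behaviour being followed: a length-1 right-hand side is broadcast across every element of the destination, while lengths that genuinely cannot be broadcast together throw a DimensionMismatch. No DynamicPPL is involved here.

```julia
m = Vector{Float64}(undef, 2)

# A length-1 vector broadcasts to fill the whole destination.
m .= [0.1]
println(m)  # [0.1, 0.1]

# Lengths that cannot be broadcast together are rejected instead.
mismatch_rejected = try
    m .= [0.1, 0.2, 0.3]
    false
catch err
    err isa DimensionMismatch
end
println(mismatch_rejected)  # true
```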

Let's take a case like this where part of the array is fixed and part of it isn't (I realise this is a bit diabolical, but I suppose we should attempt to handle the general case)

Though I agree it would be nice to support this in general, it'll be something of a hassle to support, and I'm somewhat skeptical it's worth the maintenance burden.

which we could do with something like this

If we do that, then there's no point in having the dot_tilde_* pipeline at all, since we're just iterating over the varnames, right? This removes the possibility of performing vectorized operations in the dot_tilde_* pipeline, e.g. the vectorized logpdf computation that currently happens in dot_assume. It will also be non-trivial to implement, because we might have statements such as m .~ [Normal()], which is also completely valid under .= semantics in your above example.

I would have thought that the simplest solution would be to overload the tilde pipeline for FixedContext and ConditionContext. That way we can still send all the vns down the tilde pipeline but they will come back with their appropriate values. I didn't yet try to do this, so haven't quite realised what weird behaviour I will find if I attempted to do that (or indeed if it's possible :))

As I mentioned above, this is technically possible and might be the way to go. However, it does mean that we no longer have a "nice" distinction that tilde_assume is only for things that are considered to be random variables. Other contexts affecting tilde_assume might rely on that assumption, and thus operate on either the value or the varinfo in a way that is no longer consistent with the variable being "fixed".

@torfjelde
Member

Also worth pointing out: this issue was introduced in #555, specifically by the change in ext/DynamicPPLMCMCChainsExt.jl where we moved to making use of fix. This is probably something I should have caught in my review, as I was aware that fix didn't support the case brought up in this issue 😕

@torfjelde
Member

One thing we could do is to overload the tilde-pipeline for FixedContext to act as a fallback for the case when the initial is_fixed check fails 🤷 This way we can still get the performance benefits when possible, while specifically trying to resolve issues with the dot-tilde stuff if the first check fails.

I don't like the inconsistency that something like this introduces, but it might be worth it.
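The two-stage idea could look roughly like this plain-Julia sketch. All names (ToyVarName, toy_subsumes, toy_dot_tilde!) are hypothetical, varnames are reduced to a symbol plus an optional integer index, and draw stands in for the ordinary tilde-pipeline; an actual version would presumably be a FixedContext overload of the dot_tilde_* functions, not a free function. The exact-key check comes first so the common case stays cheap, and the per-element fallback only runs when it misses.

```julia
# Hypothetical toy varname (NOT DynamicPPL's VarName): symbol + optional index.
struct ToyVarName
    sym::Symbol
    index::Union{Nothing,Int}
end
ToyVarName(sym::Symbol) = ToyVarName(sym, nothing)

toy_subsumes(parent::ToyVarName, child::ToyVarName) =
    parent.sym == child.sym && (parent.index === nothing || parent.index == child.index)

# Handle `left .~ dist` for the whole-array varname `vn`.
# `draw(n)` plays the role of the ordinary (non-fixed) tilde pipeline.
function toy_dot_tilde!(left::AbstractVector, vn::ToyVarName, fixed::AbstractDict, draw)
    # Fast path: the whole left-hand side is fixed under exactly this varname.
    if haskey(fixed, vn)
        left .= fixed[vn]
        return left
    end
    # Fallback: run the normal pipeline, then patch in any fixed elements
    # whose varnames are subsumed by `vn` (e.g. m[1] under m).
    left .= draw(length(left))
    for (key, val) in fixed
        if key.index !== nothing && toy_subsumes(vn, key)
            left[key.index] = val
        end
    end
    return left
end

# Only m[1] is fixed: the fast path misses and the fallback sets m[1] = 0.1,
# while m[2] is a fresh standard-normal draw. The fixed `s` is ignored.
fixed = Dict{ToyVarName,Any}(ToyVarName(:m, 1) => 0.1, ToyVarName(:s) => 1.0)
m = toy_dot_tilde!(Vector{Float64}(undef, 2), ToyVarName(:m), fixed, randn)
```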

@torfjelde
Member

#710 inspects all of this a bit further with a possible way to address this issue.
