diff --git a/Project.toml b/Project.toml index 96ac7dbe0..0cf96de8f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.25.4" +version = "0.26.0" [deps] diff --git a/src/compiler.jl b/src/compiler.jl index 898acad10..55adc534c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -602,10 +602,42 @@ hasmissing(::Type{>:Missing}) = true hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA) hasmissing(::Type{Union{}}) = false # issue #368 +""" + TypeWrap{T} + +A wrapper type used internally to make expressions such as `::Type{TV}` in the model arguments +not ending up as a `DataType`. +""" +struct TypeWrap{T} end + +function arg_type_is_type(e) + return Meta.isexpr(e, :curly) && length(e.args) > 1 && e.args[1] === :Type +end + function splitarg_to_expr((arg_name, arg_type, is_splat, default)) return is_splat ? :($arg_name...) : arg_name end +""" + transform_args(args) + +Return transformed `args` used in both the model constructor and evaluator. + +Specifically, this replaces expressions of the form `::Type{TV}=Vector{Float64}` +with `::TypeWrap{TV}=TypeWrap{Vector{Float64}}()` to avoid introducing `DataType`. +""" +function transform_args(args) + splitargs = map(args) do arg + arg_name, arg_type, is_splat, default = MacroTools.splitarg(arg) + return if arg_type_is_type(arg_type) + arg_name, :($TypeWrap{$(arg_type.args[2])}), is_splat, :($TypeWrap{$default}()) + else + arg_name, arg_type, is_splat, default + end + end + return map(Base.splat(MacroTools.combinearg), splitargs) +end + function namedtuple_from_splitargs(splitargs) names = map(splitargs) do (arg_name, arg_type, is_splat, default) is_splat ? Symbol("#splat#$(arg_name)") : arg_name @@ -623,8 +655,12 @@ is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") Builds the output expression. """ function build_output(modeldef, linenumbernode) - args = modeldef[:args] - kwargs = modeldef[:kwargs] + args = transform_args(modeldef[:args]) + kwargs = transform_args(modeldef[:kwargs]) + + # Need to update `args` and `kwargs` since we might have added `TypeWrap` to the types. + modeldef[:args] = args + modeldef[:kwargs] = kwargs ## Build the anonymous evaluator from the user-provided model definition. evaluatordef = copy(modeldef) @@ -713,9 +749,13 @@ function matchingvalue(sampler, vi, value) return value end end +# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`. function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) return get_matching_type(sampler, vi, value) end +function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(sampler, vi, T)}() +end function matchingvalue(context::AbstractContext, vi, value) return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) @@ -731,7 +771,7 @@ function matchingvalue(context::SamplingContext, vi, value) end """ - get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} + get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T} Get the specialized version of type `T` for sampler `spl`. diff --git a/test/compiler.jl b/test/compiler.jl index 9fa36b5ff..f1f06eabe 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -717,4 +717,16 @@ module Issue537 end @test haskey(values, @varname(y)) end end + + @testset "signature parsing + TypeWrap" begin + @model function demo_typewrap( + a, b=1, ::Type{T1}=Float64; c, d=2, t::Type{T2}=Int + ) where {T1,T2} + return (; a, b, c, d, t) + end + + model = demo_typewrap(1; c=2) + res = model() + @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) + end end diff --git a/test/model.jl b/test/model.jl index b9e62827d..ddf327110 100644 --- a/test/model.jl +++ b/test/model.jl @@ -350,9 +350,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @testset "Type stability of models" begin models_to_test = [ - # FIXME: Fix issues with type-stability in `DEMO_MODELS`. - # DynamicPPL.TestUtils.DEMO_MODELS..., - DynamicPPL.TestUtils.demo_lkjchol(2), + DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) diff --git a/test/turing/compiler.jl b/test/turing/compiler.jl index 31bd5bbbe..5c46ab777 100644 --- a/test/turing/compiler.jl +++ b/test/turing/compiler.jl @@ -95,8 +95,12 @@ @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) - @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10) - @test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10) + @test_throws ErrorException chain = sample( + gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10 + ) + @test_throws ErrorException chain = sample( + gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10 + ) end @testset "new interface" begin obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] @@ -310,10 +314,12 @@ end t_loop = @elapsed res = sample(vdemo1(), alg, 250) - t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250) + t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250) vdemo1kw(; T) = vdemo1(T) - t_loop = @elapsed res = sample(vdemo1kw(; T=Float64), alg, 250) + t_loop = @elapsed res = sample( + vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 + ) @model function vdemo2(::Type{T}=Float64) where {T<:Real} x = Vector{T}(undef, N) @@ -321,10 +327,12 @@ end t_vec = @elapsed res = sample(vdemo2(), alg, 250) - t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250) + t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250) vdemo2kw(; T) = vdemo2(T) - t_vec = @elapsed res = sample(vdemo2kw(; T=Float64), alg, 250) + t_vec = @elapsed res = sample( + vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 + ) @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector} x = TV(undef, N) @@ -332,9 +340,9 @@ end sample(vdemo3(), alg, 250) - sample(vdemo3(Vector{Float64}), alg, 250) + sample(vdemo3(DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) vdemo3kw(; T) = vdemo3(T) - sample(vdemo3kw(; T=Vector{Float64}), alg, 250) + sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) end end