Skip to content

Commit

Permalink
Rename internal variables (#225)
Browse files Browse the repository at this point in the history
This PR contains only a cosmetic change and makes the internal variable names consistent with the convention used by other packages, such as Julia base (`__module__` and `__source__`) and Zygote (`__context__`).

It is not strictly necessary to deprecate the current variable names since they are not exported but it seemed reasonable since probably at least `_varinfo` is known and used.

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
devmotion and devmotion committed Apr 10, 2021
1 parent d2678d5 commit 1f2f160
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.13"
version = "0.10.14"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
43 changes: 29 additions & 14 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
"Distributions."

const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

"""
isassumption(expr)
Expand All @@ -24,7 +25,7 @@ function isassumption(expr::Union{Symbol, Expr})
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, _model) || $(DynamicPPL.inmissings)($vn, _model)
if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__)
true
else
# Evaluate the LHS
Expand Down Expand Up @@ -167,10 +168,20 @@ generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, war

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym found
@warn "you are using the internal variable `$(sym)`"
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
Expand Down Expand Up @@ -228,18 +239,20 @@ function generate_tilde(left, right)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
__rng__, __context__, __sampler__, $tmpright, $vn, $inds, __varinfo__
)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
end
end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
$(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
end
end

Expand All @@ -263,18 +276,20 @@ function generate_dot_tilde(left, right)
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
__rng__, __context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
)
end
end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
$(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
end
end

Expand All @@ -298,11 +313,11 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(_rng::$(Random.AbstractRNG)),
:(_model::$(DynamicPPL.Model)),
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
:(_sampler::$(DynamicPPL.AbstractSampler)),
:(_context::$(DynamicPPL.AbstractContext)),
:(__rng__::$(Random.AbstractRNG)),
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__sampler__::$(DynamicPPL.AbstractSampler)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability.
"""
macro addlogprob!(ex)
return quote
acclogp!($(esc(:(_varinfo))), $(esc(ex)))
acclogp!($(esc(:(__varinfo__))), $(esc(ex)))
end
end

Expand Down
34 changes: 17 additions & 17 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ end
# Test use of internal names
@model function testmodel_missing3(x)
x[1] ~ Bernoulli(0.5)
global varinfo_ = _varinfo
global sampler_ = _sampler
global model_ = _model
global context_ = _context
global rng_ = _rng
global lp = getlogp(_varinfo)
global varinfo_ = __varinfo__
global sampler_ = __sampler__
global model_ = __model__
global context_ = __context__
global rng_ = __rng__
global lp = getlogp(__varinfo__)
return x
end
model = testmodel_missing3([1.0])
Expand All @@ -192,12 +192,12 @@ end
# disable warnings
@model function testmodel_missing4(x)
x[1] ~ Bernoulli(0.5)
global varinfo_ = _varinfo
global sampler_ = _sampler
global model_ = _model
global context_ = _context
global rng_ = _rng
global lp = getlogp(_varinfo)
global varinfo_ = __varinfo__
global sampler_ = __sampler__
global model_ = __model__
global context_ = __context__
global rng_ = __rng__
global lp = getlogp(__varinfo__)
return x
end false
lpold = lp
Expand Down Expand Up @@ -236,7 +236,7 @@ end
function makemodel(p)
@model function testmodel(x)
x[1] ~ Bernoulli(p)
global lp = getlogp(_varinfo)
global lp = getlogp(__varinfo__)
return x
end
return testmodel
Expand Down Expand Up @@ -295,11 +295,11 @@ end

@testset "macros within model" begin
# Macro expansion
@model function demo()
@model function demo1()
@mymodel1(y ~ Uniform())
end

@test haskey(VarInfo(demo()), @varname(x))
@test haskey(VarInfo(demo1()), @varname(x))

# Interpolation
# Will fail if:
Expand All @@ -308,9 +308,9 @@ end
# 2. `@mymodel` is expanded before entire `@model` has been
# expanded => errors since `MyModelStruct` is not a distribution,
# and hence `tilde_observe` errors.
@model function demo()
@model function demo2()
$(@mymodel2(y ~ Uniform()))
end
@test demo()() == 42
@test demo2()() == 42
end
end
4 changes: 2 additions & 2 deletions test/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
x = rand(10_000)

@model function wthreads(x)
global vi_ = _varinfo
global vi_ = __varinfo__
x[1] ~ Normal(0, 1)
Threads.@threads for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
Expand Down Expand Up @@ -70,7 +70,7 @@
SampleFromPrior(), DefaultContext())

@model function wothreads(x)
global vi_ = _varinfo
global vi_ = __varinfo__
x[1] ~ Normal(0, 1)
for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
Expand Down
4 changes: 2 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@testset "utils.jl" begin
@testset "addlogprob!" begin
@model function testmodel()
global lp_before = getlogp(_varinfo)
global lp_before = getlogp(__varinfo__)
@addlogprob!(42)
global lp_after = getlogp(_varinfo)
global lp_after = getlogp(__varinfo__)
end

model = testmodel()
Expand Down

2 comments on commit 1f2f160

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33999

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.14 -m "<description of version>" 1f2f16055abe8a86b0f2f45d277026a13433cd6e
git push origin v0.10.14

Please sign in to comment.