Skip to content

Commit

Permalink
Add closed-form posterior for gamma-poisson observation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Sep 6, 2022
1 parent e0336b6 commit e579992
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 4 deletions.
88 changes: 87 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,99 @@
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV
from aesara.tensor.random.basic import BinomialRV, PoissonRV
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var

from aemcmc.rewriting import sampler_finder_db


def gamma_poisson_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a binomial observation model.
.. math::
\frac{
Y \sim \operatorname{Poisson}\left(\lambda\right), \quad
\lambda \sim \operatorname{Gamma}\left(\alpha, \beta\right)
}{
\left(\lambda|Y=y\right) \sim \operatorname{Gamma}\left(\alpha+y, \beta+1\right)
}
Parameters
----------
observed_val
The observed value.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.
"""
# Gamma-poisson observation model
alpha_lv, beta_lv = var(), var()
z_rng_lv = var()
z_size_lv = var()
z_type_idx_lv = var()
z_et = etuple(
etuplize(at.random.gamma), z_rng_lv, z_size_lv, z_type_idx_lv, alpha_lv, beta_lv
)
Y_et = etuple(etuplize(at.random.poisson), var(), var(), var(), z_et)

# Posterior distribution for p
new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
new_beta_et = etuple(etuplize(at.add), beta_lv, 1)
z_posterior_et = etuple(
etuplize(at.random.gamma),
new_alpha_et,
new_beta_et,
rng=z_rng_lv,
size=z_size_lv,
dtype=z_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, z_posterior_et),
)


@node_rewriter([PoissonRV])
def local_gamma_poisson_posterior(fgraph, node):

sampler_mappings = getattr(fgraph, "sampler_mappings", None)

rv_var = node.outputs[1]
key = ("local_gamma_poisson_posterior", rv_var)

if sampler_mappings is None or key in sampler_mappings.rvs_seen:
return None # pragma: no cover

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, gamma_poisson_conjugateo(rv_var, rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

gamma_rv = rv_et[-1].evaled_obj
gamma_posterior = eval_if_etuple(res)

sampler_mappings.rvs_to_samplers.setdefault(gamma_rv, []).append(
("local_gamma_poisson_posterior", gamma_posterior, None)
)
sampler_mappings.rvs_seen.add(key)

return rv_var.owner.outputs


def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a binomial observation model.
Expand Down Expand Up @@ -101,6 +186,7 @@ def local_beta_binomial_posterior(fgraph, node):
conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")

sampler_finder_db.register(
"conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
Expand Down
22 changes: 20 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.ifelse import IfElse
from aesara.tensor.random import RandomStream
from aesara.tensor.random.basic import BetaRV
from aesara.tensor.random.basic import BetaRV, GammaRV
from scipy.linalg import toeplitz

from aemcmc.basic import construct_sampler
from aemcmc.rewriting import SubsumingElemwise


def test_closed_form_posterior():
def test_closed_form_posterior_beta_binomial():
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
Expand All @@ -31,6 +31,24 @@ def test_closed_form_posterior():
assert isinstance(p_posterior_step.owner.op, BetaRV)


def test_closed_form_posterior_gamma_poisson():
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
l_rv = srng.gamma(alpha_tt, beta_tt, name="p")

Y_rv = srng.poisson(l_rv, name="Y")

y_vv = Y_rv.clone()
y_vv.name = "y"

sample_steps, updates, initial_values = construct_sampler({Y_rv: y_vv}, srng)

p_posterior_step = sample_steps[l_rv]
assert isinstance(p_posterior_step.owner.op, GammaRV)


def test_no_samplers():
srng = RandomStream(0)

Expand Down
45 changes: 44 additions & 1 deletion tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,50 @@
from kanren import run
from unification import var

from aemcmc.conjugates import beta_binomial_conjugateo
from aemcmc.conjugates import beta_binomial_conjugateo, gamma_poisson_conjugateo


def test_gamma_poisson_conjugate_contract():
"""Produce the closed-form posterior for the poisson observation model with
a gamma prior.
"""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
z_rv = srng.gamma(alpha_tt, beta_tt)

Y_rv = srng.poisson(z_rv)
y_vv = Y_rv.clone()
y_vv.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, gamma_poisson_conjugateo(y_vv, Y_rv, q_lv))
posterior = eval_if_etuple(posterior_expr)
aesara.dprint(posterior)

assert isinstance(posterior.owner.op, type(at.random.gamma))


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_gamma_poisson_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
Y_rv = srng.gamma(alpha_tt + y_vv, beta_tt + 1)

e_lv = var()
(expanded_expr,) = run(1, e_lv, gamma_poisson_conjugateo(e_lv, y_vv, Y_rv))
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.gamma))


def test_beta_binomial_conjugate_contract():
Expand Down

0 comments on commit e579992

Please sign in to comment.