Require independent dimensions in multivariate size arguments to RandomVariable
brandonwillard committed Feb 4, 2022
1 parent da86d35 commit 2450186
Showing 5 changed files with 118 additions and 47 deletions.
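In short: multivariate RandomVariables no longer append a parameter's independent (batch) dimensions to `size` implicitly; `size` must now spell those dimensions out itself, and a mismatch raises a `ValueError`. A minimal NumPy sketch of the new convention (the helper name below is illustrative, not Aesara's API):

import numpy as np

def check_multivariate_size(param, size, ndim_supp=1):
    # Batch dims are everything except the trailing `ndim_supp` support dims.
    batch_shape = param.shape[: param.ndim - ndim_supp]
    # `size` must end with the batch shape when the parameter is batched.
    if 0 < len(batch_shape) <= len(size) and tuple(size)[-len(batch_shape) :] != batch_shape:
        raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")

mean = np.zeros((2, 3))                      # two independent 3-d means
check_multivariate_size(mean, size=(10, 2))  # OK: size ends with the batch shape (2,)
# check_multivariate_size(mean, size=(10,)) # raises ValueError under the new rules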
56 changes: 42 additions & 14 deletions aesara/tensor/random/basic.py
@@ -321,14 +321,21 @@ def rng_fn(cls, rng, mean, cov, size):
 
         if mean.ndim > 1 or cov.ndim > 2:
             # Neither SciPy nor NumPy implement parameter broadcasting for
-            # multivariate normals (or many other multivariate distributions),
-            # so we have implement a quick and dirty one here
+            # multivariate normals (or any other multivariate distributions),
+            # so we need to implement that here
             mean, cov = broadcast_params([mean, cov], cls.ndims_params)
             size = tuple(size or ())
 
             if size:
-                mean = np.broadcast_to(mean, size + mean.shape)
-                cov = np.broadcast_to(cov, size + cov.shape)
+                if (
+                    0 < mean.ndim - 1 <= len(size)
+                    and size[-mean.ndim + 1 :] != mean.shape[:-1]
+                ):
+                    raise ValueError(
+                        "shape mismatch: objects cannot be broadcast to a single shape"
+                    )
+                mean = np.broadcast_to(mean, size + mean.shape[-1:])
+                cov = np.broadcast_to(cov, size + cov.shape[-2:])
 
             res = np.empty(mean.shape)
             for idx in np.ndindex(mean.shape[:-1]):
@@ -352,16 +359,33 @@ class DirichletRV(RandomVariable):
 
     @classmethod
     def rng_fn(cls, rng, alphas, size):
-        if size is None:
-            size = ()
-        samples_shape = tuple(np.atleast_1d(size)) + alphas.shape
-        samples = np.empty(samples_shape)
-        alphas_bcast = np.broadcast_to(alphas, samples_shape)
-
-        for index in np.ndindex(*samples_shape[:-1]):
-            samples[index] = rng.dirichlet(alphas_bcast[index])
-
-        return samples
+        if alphas.ndim > 1:
+            if size is None:
+                size = ()
+
+            size = tuple(np.atleast_1d(size))
+
+            if size:
+                if (
+                    0 < alphas.ndim - 1 <= len(size)
+                    and size[-alphas.ndim + 1 :] != alphas.shape[:-1]
+                ):
+                    raise ValueError(
+                        "shape mismatch: objects cannot be broadcast to a single shape"
+                    )
+                samples_shape = size + alphas.shape[-1:]
+            else:
+                samples_shape = alphas.shape
+
+            samples = np.empty(samples_shape)
+            alphas_bcast = np.broadcast_to(alphas, samples_shape)
+
+            for index in np.ndindex(*samples_shape[:-1]):
+                samples[index] = rng.dirichlet(alphas_bcast[index])
+
+            return samples
+        else:
+            return rng.dirichlet(alphas, size=size)
 
 
 dirichlet = DirichletRV()
@@ -579,8 +603,12 @@ def rng_fn(cls, rng, n, p, size):
         size = tuple(size or ())
 
         if size:
-            n = np.broadcast_to(n, size + n.shape)
-            p = np.broadcast_to(p, size + p.shape)
+            if 0 < p.ndim - 1 <= len(size) and size[-p.ndim + 1 :] != p.shape[:-1]:
+                raise ValueError(
+                    "shape mismatch: objects cannot be broadcast to a single shape"
+                )
+            n = np.broadcast_to(n, size)
+            p = np.broadcast_to(p, size + p.shape[-1:])
 
         res = np.empty(p.shape, dtype=cls.dtype)
         for idx in np.ndindex(p.shape[:-1]):
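All three `rng_fn` implementations above follow the same pattern: validate `size` against the parameters' independent dimensions, broadcast each parameter to `size` plus only its own support dimensions, then sample cell by cell. A rough NumPy-only sketch of that flow for the multivariate-normal case (shapes are illustrative):

import numpy as np

rng = np.random.default_rng()
mean = np.zeros((2, 3))                                  # batch of two 3-d means
cov = np.broadcast_to(np.eye(3), (2, 3, 3))              # matching covariances
size = (5, 2)                                            # must end with the batch shape (2,)

mean_b = np.broadcast_to(mean, size + mean.shape[-1:])   # -> (5, 2, 3)
cov_b = np.broadcast_to(cov, size + cov.shape[-2:])      # -> (5, 2, 3, 3)

res = np.empty(mean_b.shape)
for idx in np.ndindex(mean_b.shape[:-1]):
    res[idx] = rng.multivariate_normal(mean_b[idx], cov_b[idx])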
29 changes: 14 additions & 15 deletions aesara/tensor/random/op.py
@@ -190,15 +190,14 @@ def _infer_shape(
 
         size_len = get_vector_length(size)
 
-        if self.ndim_supp == 0 and size_len > 0:
-            # In this case, we have a univariate distribution with a non-empty
-            # `size` parameter, which means that the `size` parameter
-            # completely determines the shape of the random variable. More
-            # importantly, the `size` parameter may be the only correct source
-            # of information for the output shape, in that we would be misled
-            # by the `dist_params` if we tried to infer the relevant parts of
-            # the output shape from those.
-            return size
+        if size_len > 0:
+            if self.ndim_supp == 0:
+                return size
+            else:
+                supp_shape = self._shape_from_params(
+                    dist_params, param_shapes=param_shapes
+                )
+                return tuple(size) + tuple(supp_shape)
 
         # Broadcast the parameters
         param_shapes = params_broadcast_shapes(
@@ -307,19 +306,19 @@ def make_node(self, rng, size, dtype, *dist_params):
             Existing Aesara `Generator` or `RandomState` object to be used. Creates a
             new one, if `None`.
         size: int or Sequence
-            Numpy-like size of the output (i.e. replications).
+            NumPy-like size parameter.
         dtype: str
             The dtype of the sampled output. If the value ``"floatX"`` is
-            given, then ``dtype`` is set to ``aesara.config.floatX``. This
-            value is only used when `self.dtype` isn't set.
+            given, then `dtype` is set to ``aesara.config.floatX``. This value is
+            only used when ``self.dtype`` isn't set.
         dist_params: list
             Distribution parameters.
 
         Results
         -------
-        out: `Apply`
-            A node with inputs `(rng, size, dtype) + dist_args` and outputs
-            `(rng_var, out_var)`.
+        out: Apply
+            A node with inputs ``(rng, size, dtype) + dist_args`` and outputs
+            ``(rng_var, out_var)``.
         """
         size = normalize_size_param(size)
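The `_infer_shape` change mirrors the new convention: with a non-empty `size`, a multivariate variable's static shape is `size` plus its support shape, rather than `size` alone. A condensed sketch of just that branch (simplified from the method above; the standalone function is illustrative):

def infer_shape(size, ndim_supp, supp_shape):
    # Sketch of the new branch only: a non-empty `size` now gives the full
    # independent shape, and multivariate ops append their support shape.
    if len(size) > 0:
        if ndim_supp == 0:
            return tuple(size)
        return tuple(size) + tuple(supp_shape)
    raise NotImplementedError("parameter-based inference elided in this sketch")

# A Dirichlet (ndim_supp == 1) with support shape (3,) and size (10, 2):
assert infer_shape((10, 2), 1, (3,)) == (10, 2, 3)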
14 changes: 12 additions & 2 deletions aesara/tensor/random/opt.py
@@ -83,9 +83,19 @@ def local_rv_size_lift(fgraph, node):
     if get_vector_length(size) > 0:
         dist_params = [
             broadcast_to(
-                p, (tuple(size) + tuple(p.shape)) if node.op.ndim_supp > 0 else size
+                p,
+                (
+                    tuple(size)
+                    + (
+                        tuple(p.shape)[-node.op.ndims_params[i] :]
+                        if node.op.ndims_params[i] > 0
+                        else ()
+                    )
+                )
+                if node.op.ndim_supp > 0
+                else size,
             )
-            for p in dist_params
+            for i, p in enumerate(dist_params)
         ]
     else:
         return
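The size-lift rewrite gets the matching fix: when `size` is lifted into the parameters, each parameter is broadcast to `size` plus only its own `ndims_params[i]` trailing (core) dimensions, not its full shape. A NumPy sketch of the resulting shapes (names and shapes illustrative):

import numpy as np

size = (3, 2)
ndims_params = (1, 2)                                   # e.g. multivariate normal: mean, cov
mean = np.zeros((2, 4))
cov = np.broadcast_to(np.eye(4), (2, 4, 4))

lifted = [
    np.broadcast_to(p, tuple(size) + (p.shape[-nd:] if nd > 0 else ()))
    for p, nd in zip((mean, cov), ndims_params)
]
print(lifted[0].shape, lifted[1].shape)                 # (3, 2, 4) (3, 2, 4, 4)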
51 changes: 38 additions & 13 deletions tests/tensor/random/test_basic.py
@@ -534,7 +534,7 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
                     np.eye(3, dtype=config.floatX) * 10.0,
                 ]
             ),
-            [2, 3],
+            [2, 3, 2],
         ),
         (
             np.array([[0, 1, 2], [4, 5, 6]], dtype=config.floatX),
@@ -551,12 +551,12 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
                     np.eye(3, dtype=config.floatX) * 10.0,
                 ]
             ),
-            [2, 3],
+            [2, 3, 2, 2],
         ),
         (
             np.array([[0], [10], [100]], dtype=config.floatX),
             np.eye(1, dtype=config.floatX) * 1e-6,
-            [2, 3],
+            [2, 3, 3],
         ),
     ],
 )
@@ -567,6 +567,11 @@ def test_mvnormal_samples(mu, cov, size):
 def test_mvnormal_default_args():
     rv_numpy_tester(multivariate_normal, test_fn=mvnormal_test_fn)
 
+    with pytest.raises(ValueError, match="shape mismatch.*"):
+        multivariate_normal.rng_fn(
+            None, np.zeros((1, 2)), np.ones((1, 2, 2)), size=(4,)
+        )
+
 
 @config.change_flags(compute_test_value="raise")
 def test_mvnormal_ShapeFeature():
@@ -596,11 +601,10 @@ def test_mvnormal_ShapeFeature():
     cov = at.as_tensor(test_covar).type()
     cov.tag.test_value = test_covar
 
-    d_rv = multivariate_normal(mean, cov, size=[2, 3])
+    d_rv = multivariate_normal(mean, cov, size=[2, 3, 2])
 
     fg = FunctionGraph(
-        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
-        [d_rv],
+        outputs=[d_rv],
         clone=False,
         features=[ShapeFeature()],
     )
@@ -617,10 +621,13 @@
     "alphas, size",
     [
         (np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX), None),
-        (np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX), 10),
         (
             np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX),
-            (10, 2),
+            (10, 3),
         ),
+        (
+            np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX),
+            (10, 2, 3),
+        ),
     ],
 )
@@ -633,6 +640,15 @@ def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None):
     rv_numpy_tester(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
 
 
+def test_dirichlet_rng():
+    alphas = np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
+
+    with pytest.raises(ValueError, match="shape mismatch.*"):
+        # The independent dimension's shape is missing from size (i.e. should
+        # be `(10, 2, 3)`)
+        dirichlet.rng_fn(None, alphas, size=(10, 2))
+
+
 M_at = iscalar("M")
 M_at.tag.test_value = 3
 
@@ -644,8 +660,8 @@ def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None):
         (at.ones((M_at,)), (M_at + 1,)),
         (at.ones((M_at,)), (2, M_at)),
         (at.ones((M_at, M_at + 1)), ()),
-        (at.ones((M_at, M_at + 1)), (M_at + 2,)),
-        (at.ones((M_at, M_at + 1)), (2, M_at + 2, M_at + 3)),
+        (at.ones((M_at, M_at + 1)), (M_at + 2, M_at)),
+        (at.ones((M_at, M_at + 1)), (2, M_at + 2, M_at + 3, M_at)),
     ],
 )
 def test_dirichlet_infer_shape(M, size):
@@ -684,8 +700,7 @@ def test_dirichlet_ShapeFeature():
     d_rv = dirichlet(at.ones((M_at, N_at)), name="Gamma")
 
     fg = FunctionGraph(
-        [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
-        [d_rv],
+        outputs=[d_rv],
         clone=False,
         features=[ShapeFeature()],
     )
@@ -1092,7 +1107,7 @@ def test_betabinom_samples(M, a, p, size):
         (
             np.array([10, 20], dtype=np.int64),
             np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX),
-            (3,),
+            (3, 2),
             lambda *args, **kwargs: np.stack([np.array([[10, 0], [0, 20]])] * 3),
         ),
     ],
@@ -1109,6 +1124,16 @@ def test_multinomial_samples(M, p, size, test_fn):
     )
 
 
+def test_multinomial_rng():
+    test_M = np.array([10, 20], dtype=np.int64)
+    test_p = np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX)
+
+    with pytest.raises(ValueError, match="shape mismatch.*"):
+        # The independent dimension's shape is missing from size (i.e. should
+        # be `(1, 2)`)
+        multinomial.rng_fn(None, test_M, test_p, size=(1,))
+
+
 @pytest.mark.parametrize(
     "p, size, test_fn",
     [
15 changes: 12 additions & 3 deletions tests/tensor/random/test_opt.py
@@ -12,6 +12,7 @@
 from aesara.tensor.elemwise import DimShuffle
 from aesara.tensor.random.basic import (
     dirichlet,
+    multinomial,
     multivariate_normal,
     normal,
     poisson,
@@ -120,12 +121,20 @@ def test_inplace_optimization():
             np.array([[0], [10], [100]], dtype=config.floatX),
             np.diag(np.array([1e-6], dtype=config.floatX)),
         ],
-        [2, 3],
+        [2, 3, 3],
     ),
     (
         dirichlet,
         [np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)],
-        [2, 3],
+        [2, 3, 3],
     ),
+    (
+        multinomial,
+        [
+            np.array([10, 20], dtype="int64"),
+            np.array([[0.999, 0.001], [0.001, 0.999]], dtype=config.floatX),
+        ],
+        [3, 2],
+    ),
 ],
 )
@@ -288,7 +297,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
             np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
             np.eye(2).astype(config.floatX) * 1e-6,
         ),
-        (3,),
+        (3, 2),
         1e-3,
     ),
 ],
