Skip to content

Commit

Permalink
Deprecate remaining uses of Rebroadcast in favor of Unbroadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo authored and brandonwillard committed Jul 7, 2022
1 parent ac52d68 commit 7f8af9b
Show file tree
Hide file tree
Showing 18 changed files with 337 additions and 538 deletions.
2 changes: 1 addition & 1 deletion aesara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
this function digs through them.
If ``aesara.sparse`` is also there, we will look over CSM `Op`.
Expand Down
4 changes: 2 additions & 2 deletions aesara/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def clone_inputs(i):
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to remove broadcastable dimensions."
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)

raise TypeError(err_msg, err_sug)
Expand Down
5 changes: 2 additions & 3 deletions aesara/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic
from aesara.tensor.shape import Reshape, Shape, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast


__docformat__ = "restructedtext en"
Expand Down Expand Up @@ -451,7 +450,7 @@ def cond_make_inplace(fgraph, node):
Shape,
SpecifyShape,
Reshape,
basic.Rebroadcast,
Unbroadcast,
at.math.Dot,
at.math.MaxAndArgmax,
at.subtensor.Subtensor,
Expand Down
19 changes: 5 additions & 14 deletions aesara/link/jax/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
Expand All @@ -50,7 +49,7 @@
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
Expand Down Expand Up @@ -347,20 +346,12 @@ def specifyshape(x, *shape):
return specifyshape


@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op, **kwargs):
op_axis = op.axis

def rebroadcast(x):
for axis, value in op_axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
"Dimension %s in Rebroadcast's input was"
" supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
)
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

return rebroadcast
return unbroadcast


@jax_funcify.register(ViewOp)
Expand Down
19 changes: 5 additions & 14 deletions aesara/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.shape import Unbroadcast


@numba_funcify.register(AllocEmpty)
Expand Down Expand Up @@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)


@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())

@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def rebroadcast(x):
for axis, value in op_axis:
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
def unbroadcast(x):
return x

return rebroadcast
return unbroadcast


@numba_funcify.register(TensorFromScalar)
Expand Down
8 changes: 4 additions & 4 deletions aesara/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum
from aesara.tensor.shape import shape_padleft
from aesara.tensor.shape import shape_padleft, unbroadcast
from aesara.tensor.type import TensorType, integer_dtypes
from aesara.updates import OrderedUpdates

Expand Down Expand Up @@ -751,7 +751,7 @@ def wrap_into_list(x):
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(actual_arg), 0),
unbroadcast(shape_padleft(actual_arg), 0),
actual_n_steps,
)
)
Expand Down Expand Up @@ -881,7 +881,7 @@ def wrap_into_list(x):
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)

if not return_list and len(outputs) == 1:
outputs = outputs[0]
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def wrap_into_list(x):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(input.variable), 0),
unbroadcast(shape_padleft(input.variable), 0),
actual_n_steps,
)
)
Expand Down
Loading

0 comments on commit 7f8af9b

Please sign in to comment.