From 7f8af9bc28755d93dca3afff2534a8a5f5ecbd80 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 9 May 2022 12:59:29 +0200 Subject: [PATCH] Deprecate remaining uses of Rebroadcast in favor of Unbroadcast --- aesara/__init__.py | 2 +- aesara/compile/function/pfunc.py | 4 +- aesara/ifelse.py | 5 +- aesara/link/jax/dispatch.py | 19 +- aesara/link/numba/dispatch/tensor_basic.py | 19 +- aesara/scan/basic.py | 8 +- aesara/tensor/basic.py | 219 +-------------------- aesara/tensor/basic_opt.py | 109 ++++------ aesara/tensor/shape.py | 105 ++++++++++ aesara/tensor/subtensor_opt.py | 23 +-- tests/link/test_jax.py | 13 +- tests/link/test_numba.py | 45 ++--- tests/scan/test_printing.py | 30 +-- tests/tensor/test_basic.py | 79 +------- tests/tensor/test_basic_opt.py | 69 ++++--- tests/tensor/test_shape.py | 63 ++++++ tests/tensor/test_subtensor_opt.py | 58 +++--- tests/test_rop.py | 5 +- 18 files changed, 337 insertions(+), 538 deletions(-) diff --git a/aesara/__init__.py b/aesara/__init__.py index d9404c11e1..39eef9a041 100644 --- a/aesara/__init__.py +++ b/aesara/__init__.py @@ -147,7 +147,7 @@ def _as_symbolic(x, **kwargs) -> Variable: def get_scalar_constant_value(v): """Return the constant scalar (i.e. 0-D) value underlying variable `v`. - If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast + If `v` is the output of dim-shuffles, fills, allocs, cast, etc. this function digs through them. If ``aesara.sparse`` is also there, we will look over CSM `Op`. diff --git a/aesara/compile/function/pfunc.py b/aesara/compile/function/pfunc.py index c608157c8b..94be2db2c6 100644 --- a/aesara/compile/function/pfunc.py +++ b/aesara/compile/function/pfunc.py @@ -204,8 +204,8 @@ def clone_inputs(i): err_sug = ( "If the difference is related to the broadcast pattern," " you can call the" - " tensor.unbroadcast(var, axis_to_unbroadcast[, ...])" - " function to remove broadcastable dimensions." + " tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])" + " function to mask broadcastable dimensions." 
) raise TypeError(err_msg, err_sug) diff --git a/aesara/ifelse.py b/aesara/ifelse.py index bdc26e0ca4..cc6f01b14b 100644 --- a/aesara/ifelse.py +++ b/aesara/ifelse.py @@ -23,8 +23,7 @@ from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.op import _NoPythonOp from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer -from aesara.tensor import basic -from aesara.tensor.shape import Reshape, Shape, SpecifyShape +from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast __docformat__ = "restructedtext en" @@ -451,7 +450,7 @@ def cond_make_inplace(fgraph, node): Shape, SpecifyShape, Reshape, - basic.Rebroadcast, + Unbroadcast, at.math.Dot, at.math.MaxAndArgmax, at.subtensor.Subtensor, diff --git a/aesara/link/jax/dispatch.py b/aesara/link/jax/dispatch.py index 26647a14bb..e0b8e188f3 100644 --- a/aesara/link/jax/dispatch.py +++ b/aesara/link/jax/dispatch.py @@ -29,7 +29,6 @@ Eye, Join, MakeVector, - Rebroadcast, ScalarFromTensor, TensorFromScalar, ) @@ -50,7 +49,7 @@ from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.random.op import RandomVariable -from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular from aesara.tensor.subtensor import ( AdvancedIncSubtensor, @@ -347,20 +346,12 @@ def specifyshape(x, *shape): return specifyshape -@jax_funcify.register(Rebroadcast) -def jax_funcify_Rebroadcast(op, **kwargs): - op_axis = op.axis - - def rebroadcast(x): - for axis, value in op_axis.items(): - if value and x.shape[axis] != 1: - raise ValueError( - "Dimension %s in Rebroadcast's input was" - " supposed to be 1 (got %s instead)" % (axis, x.shape[axis]) - ) +@jax_funcify.register(Unbroadcast) +def jax_funcify_Unbroadcast(op, **kwargs): + def unbroadcast(x): return x - return rebroadcast + return unbroadcast @jax_funcify.register(ViewOp) diff --git a/aesara/link/numba/dispatch/tensor_basic.py b/aesara/link/numba/dispatch/tensor_basic.py index 3f1662e919..e329c793bd 100644 --- a/aesara/link/numba/dispatch/tensor_basic.py +++ b/aesara/link/numba/dispatch/tensor_basic.py @@ -14,10 +14,10 @@ Eye, Join, MakeVector, - Rebroadcast, ScalarFromTensor, TensorFromScalar, ) +from aesara.tensor.shape import Unbroadcast @numba_funcify.register(AllocEmpty) @@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}): return numba_basic.numba_njit(makevector_fn) -@numba_funcify.register(Rebroadcast) -def numba_funcify_Rebroadcast(op, **kwargs): - # Make sure op_axis only has ints. 
This way we can avoid literal_unroll - # which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215 - op_axis = tuple((axis, int(value)) for axis, value in op.axis.items()) - +@numba_funcify.register(Unbroadcast) +def numba_funcify_Unbroadcast(op, **kwargs): @numba_basic.numba_njit - def rebroadcast(x): - for axis, value in op_axis: - if value and x.shape[axis] != 1: - raise ValueError( - ("Dimension in Rebroadcast's input was supposed to be 1") - ) + def unbroadcast(x): return x - return rebroadcast + return unbroadcast @numba_funcify.register(TensorFromScalar) diff --git a/aesara/scan/basic.py b/aesara/scan/basic.py index 2a53283ef2..81c42cdc1f 100644 --- a/aesara/scan/basic.py +++ b/aesara/scan/basic.py @@ -14,7 +14,7 @@ from aesara.tensor.basic import get_scalar_constant_value from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.math import minimum -from aesara.tensor.shape import shape_padleft +from aesara.tensor.shape import shape_padleft, unbroadcast from aesara.tensor.type import TensorType, integer_dtypes from aesara.updates import OrderedUpdates @@ -751,7 +751,7 @@ def wrap_into_list(x): # defined in scan utils sit_sot_scan_inputs.append( expand_empty( - at.unbroadcast(shape_padleft(actual_arg), 0), + unbroadcast(shape_padleft(actual_arg), 0), actual_n_steps, ) ) @@ -881,7 +881,7 @@ def wrap_into_list(x): # this will represent only a slice and it will have one # dimension less. if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: - outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0) + outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) if not return_list and len(outputs) == 1: outputs = outputs[0] @@ -1010,7 +1010,7 @@ def wrap_into_list(x): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( expand_empty( - at.unbroadcast(shape_padleft(input.variable), 0), + unbroadcast(shape_padleft(input.variable), 0), actual_n_steps, ) ) diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 751daa2a31..1ed3d855da 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -10,7 +10,7 @@ from collections.abc import Sequence from functools import partial from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union from typing import cast as type_cast import numpy as np @@ -44,6 +44,7 @@ from aesara.tensor.shape import ( Shape, Shape_i, + Unbroadcast, shape, shape_padaxis, shape_padleft, @@ -254,7 +255,7 @@ def get_scalar_constant_value( ): """Return the constant scalar(0-D) value underlying variable `v`. - If `v` is the output of dimshuffles, fills, allocs, rebroadcasts, + If `v` is the output of dimshuffles, fills, allocs, etc, cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise and some pattern with Subtensor, this function digs through them. @@ -323,7 +324,7 @@ def get_scalar_constant_value( ( Alloc, DimShuffle, - Rebroadcast, + Unbroadcast, # outputguard is only used in debugmode but we # keep it here to avoid problems with old pickels. 
compile.ops.OutputGuard, @@ -495,7 +496,7 @@ def get_scalar_constant_value( gp_broadcastable = grandparent.type.broadcastable ndim = grandparent.type.ndim if grandparent.owner and isinstance( - grandparent.owner.op, Rebroadcast + grandparent.owner.op, Unbroadcast ): ggp_broadcastable = grandparent.owner.inputs[0].broadcastable l = [ @@ -616,185 +617,6 @@ def c_code_cache_version(self): scalar_from_tensor = ScalarFromTensor() -class Rebroadcast(COp): - """ - Change the input's broadcastable fields in some predetermined way. - - See Also - -------- - unbroadcast - - Notes - ----- - Works inplace and works for CudaNdarrayType. - - Examples - -------- - ``Rebroadcast((0, True), (1, False))(x)`` would make `x` broadcastable in - axis 0 and not broadcastable in axis 1. - - """ - - view_map = {0: [0]} - _f16_ok = True - # Mapping from Type to C code (and version) to use. - # In the C code, the name of the input variable is %(iname)s, - # the output variable is %(oname)s. - c_code_and_version: Dict = {} - - check_input = False - __props__ = ("axis",) - _f16_ok = True - - def __init__(self, *axis): - # Sort them to make sure we merge all possible case. - items = sorted(axis) - self.axis = dict(items) - for axis, broad in self.axis.items(): - if not isinstance(axis, (np.integer, int)): - raise TypeError(f"Rebroadcast needs integer axes. Got {axis}") - - if not isinstance(broad, (np.bool_, bool)): - raise TypeError( - f"Rebroadcast needs bool for new broadcast pattern. Got {broad}" - ) - - def __hash__(self): - # Need special __hash__ as dict aren't hashable. - # no ambiguity because each item key is unique - items = sorted(self.axis.items()) - return hash((type(self), tuple(items))) - - def __str__(self): - return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}" - - def make_node(self, x): - if self.axis.keys() and (x.ndim <= max(self.axis.keys())): - raise ValueError("Trying to rebroadcast non-existent dimension") - t = x.type.clone( - shape=[self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)] - ) - return Apply(self, [x], [t()]) - - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ - for axis, value in self.axis.items(): - if value and x.shape[axis] != 1: - raise ValueError( - f"Dimension {axis} in Rebroadcast's input was" - f" supposed to be 1 (got {x.shape[axis]} instead)" - ) - out[0] = x - - def grad(self, inp, grads): - (x,) = inp - (gz,) = grads - # restore the broadcasting pattern of the input - return ( - Rebroadcast( - *[ - (axis, x.type.broadcastable[axis]) - for axis, value in self.axis.items() - ] - )(gz), - ) - - def infer_shape(self, fgraph, node, ishapes): - assert len(ishapes) == 1 - l = [] - one = aesara.tensor.basic.constant(1) - for ax in range(len(ishapes[0])): - if self.axis.get(ax, False): - l.append(one) - else: - l.append(ishapes[0][ax]) - - return [tuple(l)] - - def R_op(self, inputs, eval_points): - if eval_points[0] is None: - return [None] - return self(*eval_points, return_list=True) - - def c_code(self, node, nodename, inp, out, sub): - (iname,) = inp - (oname,) = out - fail = sub["fail"] - - itype = node.inputs[0].type.__class__ - if itype in self.c_code_and_version: - code, version = self.c_code_and_version[itype] - final_code = "" - for axis, value in self.axis.items(): - if value: - final_code += code % locals() - return ( - final_code - + f""" - Py_XDECREF({oname}); - {oname} = {iname}; - Py_XINCREF({oname}); - """ - ) - raise NotImplementedError() - - def c_code_cache_version(self): - version = [] - # If 
any of the c code is unversioned, we have to return () - # Else, we will return a list of (type name, version) pairs. - for t, (c, v) in sorted( - self.c_code_and_version.items(), key=lambda pair: str(pair[0]) - ): - if not v: - warnings.warn( - f"Type {t} has C code for Rebroadcast, but it " - "has no version. You should add a 'version' " - "keyword arg when calling " - "register_rebroadcast_c_code.", - stacklevel=2, - ) - return () - version.append((str(t), v)) - - if version: - version.append(1) - return tuple(version) - - -def register_rebroadcast_c_code(typ, code, version=()): - """ - Tell Rebroadcast how to generate C code for an Aesara Type. - - typ : Aesara type - It must be the Aesara class itself and not an instance of the class. - code : C code - That checks if the dimension %(axis)s is of shape 1 for the Aesara type - 'typ'. Use %(iname)s and %(oname)s for the input and output C variable - names respectively, and %(axis)s for the axis that we need to check. - This code is put in a loop for all axes. - version - A number indicating the version of the code, for cache. - - """ - Rebroadcast.c_code_and_version[typ] = (code, version) - - -register_rebroadcast_c_code( - TensorType, - """ - if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){ - PyErr_Format(PyExc_ValueError, - "Dimension %(axis)s in Rebroadcast's input was" - " supposed to be 1 (got %%d instead)", - PyArray_DIMS(%(iname)s)[%(axis)s]); - %(fail)s - } - """, - version=1, -) - - # to be removed as we get the epydoc routine-documenting thing going # -JB 20080924 def _conversion(real_value: Op, name: str) -> Op: @@ -2254,36 +2076,6 @@ def c_code(self, node, name, inputs, outputs, sub): ) -def unbroadcast(x, *axes): - """ - Make the input impossible to broadcast in the specified axes. - - For example, unbroadcast(x, 0) will make the first dimension - of x not broadcastable. When performing the function, if the length - of x along that dimension is not 1, a ValueError will be raised. - - We apply the opt here not to pollute the graph - - Parameters - ---------- - x : tensor_like - Input aesara tensor. - axis : an int or an iterable object such as list or tuple of int values - The dimension along which the tensor x should be unbroadcastable. - If the length of x along these dimensions is not 1, a ValueError will - be raised. - - Returns - ------- - tensor - A aesara tensor, which is unbroadcastable along the specified dimensions. - - """ - x = as_tensor_variable(x) - rval = Rebroadcast(*[(axis, False) for axis in axes])(x) - return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) - - class Join(COp): r""" Concatenate multiple `TensorVariable`\s along some axis. 
@@ -4195,7 +3987,6 @@ def take_along_axis(arr, indices, axis=0): "stack", "roll", "join", - "unbroadcast", "split", "transpose", "extract_constant", diff --git a/aesara/tensor/basic_opt.py b/aesara/tensor/basic_opt.py index dc94ee078e..1760537381 100644 --- a/aesara/tensor/basic_opt.py +++ b/aesara/tensor/basic_opt.py @@ -48,7 +48,6 @@ AllocEmpty, Join, MakeVector, - Rebroadcast, ScalarFromTensor, Split, TensorFromScalar, @@ -77,9 +76,11 @@ Shape, Shape_i, SpecifyShape, + Unbroadcast, shape_i, shape_padleft, specify_shape, + unbroadcast, ) from aesara.tensor.sort import TopKOp from aesara.tensor.subtensor import Subtensor, get_idx_list @@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): @register_useless @register_canonicalize @register_specialize -@local_optimizer([Rebroadcast]) -def local_useless_rebroadcast(fgraph, node): - """Remove `Rebroadcast` if it does not actually change the broadcasting pattern.""" - if isinstance(node.op, Rebroadcast): +@local_optimizer([Unbroadcast]) +def local_useless_unbroadcast(fgraph, node): + """Remove `Unbroadcast` if it does not actually change the broadcasting pattern. + + TODO: Implement equivalent rewrite for SpecifyShape + """ + if isinstance(node.op, Unbroadcast): x = node.inputs[0] if x.broadcastable == node.outputs[0].broadcastable: # No broadcastable flag was modified @@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node): return [x] else: # Keep the flags that modify something - new_axis = {} - for dim, bc in node.op.axis.items(): - if x.broadcastable[dim] != bc: - new_axis[dim] = bc - if new_axis == node.op.axis: + new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1) + if new_axes == node.op.axes: # All flags are useful - return + return None else: - r = Rebroadcast(*new_axis.items())(x) + r = unbroadcast(x, *new_axes) # Copy over stacktrace from previous output copy_stack_trace(node.outputs, r) return [r] @@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node): @register_canonicalize @register_specialize -@local_optimizer([Rebroadcast]) -def local_rebroadcast_lift(fgraph, node): +@local_optimizer([Unbroadcast]) +def local_unbroadcast_lift(fgraph, node): """ - Lifts Rebroadcast through unary Elemwise operations, - and merges consecutive Rebroadcasts. + Lifts `Unbroadcast` through unary Elemwise operations, + and merges consecutive `Unbroadcast`s. - Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x)) - Rebroadcast(Rebroadcast(x)) => Rebroadcast(x) + Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x)) + Unbroadcast(Unbroadcast(x)) => Unbroadcast(x) + TODO: Implement equivalent Elemwise lift for SpecifyShape """ op = node.op - if not isinstance(op, Rebroadcast): + if not isinstance(op, Unbroadcast): return False inp = node.inputs[0] inode = inp.owner if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: - # It may happen that `input` has no client because this optimization - # is called from `apply_rebroadcast_opt`, which in particular is used - # by the `unbroadcast` function before we are in the actual function - # compilation phase. if len(fgraph.clients.get(inp, ())) == 1: - rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0]) - # Copy over stacktrace from previous output (after rebroadcasting) - # to new output, because an error in the new graph right after - # rebroadcasting must have been caused by the previous rebroadcasting. 
- copy_stack_trace(node.outputs, rebroadcasted) + unbroadcasted = unbroadcast(inode.inputs[0], *op.axes) + copy_stack_trace(node.outputs, unbroadcasted) - rval = inode.op.make_node(rebroadcasted).outputs + rval = inode.op.make_node(unbroadcasted).outputs - # Copy over stacktrace from previous output (after rebroadcasting) + # Copy over stacktrace from previous output (after unbroadcasting) # and input (after elemwise operation) to new output, because an # error in the new graph could have been caused by either of the # two ops. copy_stack_trace(node.outputs + node.inputs, rval) - return rval - if inode and isinstance(inode.op, Rebroadcast): - # the "axis" specification in the outer Rebroadcast overrides - # the axis of the inner one - axis = inode.op.axis.copy() - axis.update(op.axis) - iinput = inode.inputs[0] - - rval = [Rebroadcast(*list(axis.items()))(iinput)] - # Copy over stacktrace from previous output (after second rebroadcast) - # and from previous input (after first rebroadcast op) because an error in - # the new graph could have been caused by either of the two - # rebroadcast ops. + if inode and isinstance(inode.op, Unbroadcast): + # Merge axis of each unbroadcast + axis = tuple(set(inode.op.axes).union(set(op.axes))) + iinput = inode.inputs[0] + rval = [unbroadcast(iinput, *axis)] + # Copy over stacktrace from previous output (after second unbroadcasting) + # and from previous input (after first unbroadcasting) because an error in + # the new graph could have been caused by either of the two Unbroadcast ops. copy_stack_trace(node.outputs + node.inputs, rval) return rval -def apply_rebroadcast_opt(rval): - """ - Apply as many times as required the optimization local_useless_rebroadcast - and local_rebroadcast_lift. - - Parameters - ---------- - rval: a Variable - - Returns - ------- - A Variable (the same if no optimization can be applied) - - """ - - fg = FunctionGraph([], []) - changed = True - while changed and rval.owner: - changed = False - rval2 = local_useless_rebroadcast.transform(fg, rval.owner) - if rval2: - assert len(rval2) == 1 - rval = rval2[0] - changed = True - if rval.owner: - rval2 = local_rebroadcast_lift.transform(fg, rval.owner) - if rval2: - assert len(rval2) == 1 - rval = rval2[0] - changed = True - return rval - - @register_specialize @register_canonicalize @register_useless diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index 7a70696b24..f6ed3590c5 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes): shape_info = [1 if i in axes else None for i in range(len(x.type.shape))] return specify_shape(x, shape_info) + + +class Unbroadcast(COp): + """ + Mask static broadcastable dimensions of input as `None` + + See Also + -------- + unbroadcast + + + Examples + -------- + ``Unbroadcast((1,))(x)`` would make `x` second static dimension be `None` + + """ + + view_map = {0: [0]} + _f16_ok = True + # Mapping from Type to C code (and version) to use. + # In the C code, the name of the input variable is %(iname)s, + # the output variable is %(oname)s. + c_code_and_version: Dict = {} + + check_input = False + __props__ = ("axes",) + _f16_ok = True + + def __init__(self, *axis): + # Sort them to make sure we merge all possible case. + items = tuple(sorted(axis)) + self.axes = items + for axis in self.axes: + if not isinstance(axis, (np.integer, int)): + raise TypeError(f"Unbroadcast needs integer axes. 
Got {axis}") + + def __str__(self): + return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}" + + def make_node(self, x): + x = as_tensor_variable(x) + if x.type.ndim <= max(self.axes): + raise ValueError("Trying to unbroadcast of non-existent dimension") + shape = [ + None if (sh == 1 and i in self.axes) else sh + for i, sh in enumerate(x.type.shape) + ] + return Apply(self, [x], [x.type.clone(shape=shape)()]) + + def perform(self, node, inp, out_): + (x,) = inp + (out,) = out_ + out[0] = x + + def grad(self, inp, grads): + (x,) = inp + (gz,) = grads + # restore the broadcasting pattern of the input + return [specify_shape(gz, x.type.shape)] + + def infer_shape(self, fgraph, node, ishapes): + assert len(ishapes) == 1 + return [tuple(ishapes[0])] + + def R_op(self, inputs, eval_points): + if eval_points[0] is None: + return [None] + return self(*eval_points, return_list=True) + + def c_code(self, node, nodename, inp, out, sub): + (iname,) = inp + (oname,) = out + + return f""" + Py_XDECREF({oname}); + {oname} = {iname}; + Py_XINCREF({oname}); + """ + + def c_code_cache_version(self): + return (3,) + + +def unbroadcast(x, *axes): + """ + Mask static broadcastable dimensions of input as `None` + + Parameters + ---------- + x : tensor_like + Input aesara tensor. + axis : an int or an iterable object such as list or tuple of int values + The broadcastable dimensions of x that should be unbroadcasted. + + Returns + ------- + tensor + A aesara tensor, with static broadcastable dimensions masked as `None` + + """ + x = as_tensor_variable(x) + unbroadcasted_axes = [axis for axis in axes if x.type.shape[axis] == 1] + if not unbroadcasted_axes: + return x + return Unbroadcast(*unbroadcasted_axes)(x) diff --git a/aesara/tensor/subtensor_opt.py b/aesara/tensor/subtensor_opt.py index 2118c0ce6d..bdf777ecfd 100644 --- a/aesara/tensor/subtensor_opt.py +++ b/aesara/tensor/subtensor_opt.py @@ -14,7 +14,6 @@ ARange, Join, MakeVector, - Rebroadcast, ScalarFromTensor, TensorFromScalar, alloc, @@ -50,9 +49,11 @@ from aesara.tensor.shape import ( Shape, SpecifyShape, + Unbroadcast, shape_padleft, shape_tuple, specify_shape, + unbroadcast, ) from aesara.tensor.sharedvar import TensorSharedVariable from aesara.tensor.subtensor import ( @@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node): Handles the following unary ops: elemwise(x,...)[idx] -> elemwise(x[idx],...) when x,... 
are broadcasted scalar or not broadcasted at all - rebroadcast(x)[idx] => rebroadcast(x[idx]) + Unbroadcast(x)[idx] => Unbroadcast(x[idx]) """ if isinstance(node.op, Subtensor): @@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node): copy_stack_trace([node.outputs[0], node.inputs[0]], ret) return [ret] - if isinstance(u.owner.op, Rebroadcast): - # make sure that Rebroadcast has only 1 input - assert len(u.owner.inputs) == 1 - + if isinstance(u.owner.op, Unbroadcast): # Subtensor might reduce dim., adapt broadcast pattern accordingly - new_axis = [] + old_axes = u.owner.op.axes + new_axes = [] # loop through indices being subtensor-ed # i indexes broadcastable pattern before subtensor # j indexes broadcastable pattern after subtensor j = 0 for (i, x) in enumerate(node.op.idx_list): - # if its not a slice, it will reduce the dimension, should + # if it is not a slice, it will reduce the dimension, should # not appear in the broascastable dimensions if isinstance(x, slice): - new_axis += [(j, u.broadcastable[i])] + if i in old_axes: + new_axes.append(j) j += 1 # now keep the broadcastable pattern of all # items not appearing in subtensor list for i in range(len(node.op.idx_list), len(u.broadcastable)): - new_axis += [(j, u.broadcastable[i])] + if i in old_axes: + new_axes.append(j) j += 1 subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) # Copy over previous output stacktrace copy_stack_trace(node.outputs[0], subt_x) - rbcast_subt_x = Rebroadcast(*new_axis)(subt_x) + rbcast_subt_x = unbroadcast(subt_x, *new_axes) # Copy over previous output stacktrace # and stacktrace from previous unary operation copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py index ca9d28ab7b..4c392efd57 100644 --- a/tests/link/test_jax.py +++ b/tests/link/test_jax.py @@ -39,7 +39,7 @@ from aesara.tensor.nnet.basic import SoftmaxGrad from aesara.tensor.random.basic import RandomVariable, normal from aesara.tensor.random.utils import RandomStream -from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape +from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape from aesara.tensor.type import ( dscalar, dvector, @@ -201,20 +201,11 @@ def test_jax_compile_ops(): compare_jax_and_py(x_fg, []) x_np = np.zeros((20, 1, 1)) - x = at.Rebroadcast((0, False), (1, True), (2, False))(at.as_tensor_variable(x_np)) + x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np)) x_fg = FunctionGraph([], [x]) compare_jax_and_py(x_fg, []) - with config.change_flags(compute_test_value="off"): - x = at.Rebroadcast((0, True), (1, False), (2, False))( - at.as_tensor_variable(x_np) - ) - x_fg = FunctionGraph([], [x]) - - with pytest.raises(ValueError): - compare_jax_and_py(x_fg, []) - x = ViewOp()(at.as_tensor_variable(x_np)) x_fg = FunctionGraph([], [x]) diff --git a/tests/link/test_numba.py b/tests/link/test_numba.py index 70b6ec80fd..a92abe4445 100644 --- a/tests/link/test_numba.py +++ b/tests/link/test_numba.py @@ -40,7 +40,7 @@ from aesara.tensor import subtensor as at_subtensor from aesara.tensor.elemwise import Elemwise from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum -from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast class MyType(Type): @@ -769,39 +769,18 @@ def test_ScalarFromTensor(v): ) -@pytest.mark.parametrize( - "v, axis, fails", - [ - ( - set_test_value(at.matrix(), 
np.array([[1.0]], dtype=config.floatX)), - [(0, True), (1, True)], - False, - ), - ( - set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - [(0, True), (1, False)], - False, - ), - ( - set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - [(0, True), (1, True)], - True, - ), - ], -) -def test_Rebroadcast(v, axis, fails): - g = atb.Rebroadcast(*axis)(v) +def test_Unbroadcast(): + v = set_test_value(at.row(), np.array([[1.0, 2.0]], dtype=config.floatX)) + g = Unbroadcast(0)(v) g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if not fails else pytest.raises(ValueError) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) @pytest.mark.parametrize( diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 1497462b8a..4046555563 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -36,7 +36,7 @@ def test_debugprint_sitsot(): | | | | | |k [id D] | | | | | |Subtensor{int64} [id H] | | | | | |Shape [id I] - | | | | | | |Rebroadcast{(0, False)} [id J] + | | | | | | |Unbroadcast{0} [id J] | | | | | | |InplaceDimShuffle{x,0} [id K] | | | | | | |Elemwise{second,no_inplace} [id L] | | | | | | |A [id M] @@ -45,9 +45,9 @@ def test_debugprint_sitsot(): | | | | | |ScalarConstant{0} [id P] | | | | |Subtensor{int64} [id Q] | | | | |Shape [id R] - | | | | | |Rebroadcast{(0, False)} [id J] + | | | | | |Unbroadcast{0} [id J] | | | | |ScalarConstant{1} [id S] - | | | |Rebroadcast{(0, False)} [id J] + | | | |Unbroadcast{0} [id J] | | | |ScalarFromTensor [id T] | | | |Subtensor{int64} [id H] | | |A [id M] (outer_in_non_seqs-0) @@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info(): | | | | | |k [id D] | | | | | |Subtensor{int64} [id H] | | | | | |Shape [id I] - | | | | | | |Rebroadcast{(0, False)} [id J] + | | | | | | |Unbroadcast{0} [id J] | | | | | | |InplaceDimShuffle{x,0} [id K] | | | | | | |Elemwise{second,no_inplace} [id L] | | | | | | |A [id M] @@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info(): | | | | | |ScalarConstant{0} [id P] | | | | |Subtensor{int64} [id Q] | | | | |Shape [id R] - | | | | | |Rebroadcast{(0, False)} [id J] + | | | | | |Unbroadcast{0} [id J] | | | | |ScalarConstant{1} [id S] - | | | |Rebroadcast{(0, False)} [id J] + | | | |Unbroadcast{0} [id J] | | | |ScalarFromTensor [id T] | | | |Subtensor{int64} [id H] | | |A [id M] @@ -261,7 +261,7 @@ def compute_A_k(A, k): > | | | | | | |*3- [id BF] -> [id X] (inner_in_non_seqs-1) > | | | | | | |Subtensor{int64} [id BJ] > | | | | | | |Shape [id BK] - > | | | | | | | |Rebroadcast{(0, False)} [id BL] + > | | | | | | | |Unbroadcast{0} [id BL] > | | | | | | | |InplaceDimShuffle{x,0} [id BM] > | | | | | | | |Elemwise{second,no_inplace} [id BN] > | | | | | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0) @@ -270,9 +270,9 @@ def compute_A_k(A, k): > | | | | | | |ScalarConstant{0} [id BR] > | | | | | |Subtensor{int64} [id BS] > | | | | | |Shape [id BT] - > | | | | | | |Rebroadcast{(0, False)} [id BL] + > | | | | | | |Unbroadcast{0} [id BL] > | | | | | |ScalarConstant{1} [id BU] - > | | | | |Rebroadcast{(0, False)} [id BL] + > | | | | |Unbroadcast{0} [id BL] > | | | | |ScalarFromTensor [id BV] > | | | | |Subtensor{int64} [id BJ] > | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) @@ 
-350,7 +350,7 @@ def compute_A_k(A, k): > | | | | | | |*3- [id BB] (inner_in_non_seqs-1) > | | | | | | |Subtensor{int64} [id BL] > | | | | | | |Shape [id BM] - > | | | | | | | |Rebroadcast{(0, False)} [id BN] + > | | | | | | | |Unbroadcast{0} [id BN] > | | | | | | | |InplaceDimShuffle{x,0} [id BO] > | | | | | | | |Elemwise{second,no_inplace} [id BP] > | | | | | | | |*2- [id BA] (inner_in_non_seqs-0) @@ -359,9 +359,9 @@ def compute_A_k(A, k): > | | | | | | |ScalarConstant{0} [id BS] > | | | | | |Subtensor{int64} [id BT] > | | | | | |Shape [id BU] - > | | | | | | |Rebroadcast{(0, False)} [id BN] + > | | | | | | |Unbroadcast{0} [id BN] > | | | | | |ScalarConstant{1} [id BV] - > | | | | |Rebroadcast{(0, False)} [id BN] + > | | | | |Unbroadcast{0} [id BN] > | | | | |ScalarFromTensor [id BW] > | | | | |Subtensor{int64} [id BL] > | | | |*2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) @@ -487,7 +487,7 @@ def test_debugprint_mitmot(): | | | | | | | |k [id G] | | | | | | | |Subtensor{int64} [id K] | | | | | | | |Shape [id L] - | | | | | | | | |Rebroadcast{(0, False)} [id M] + | | | | | | | | |Unbroadcast{0} [id M] | | | | | | | | |InplaceDimShuffle{x,0} [id N] | | | | | | | | |Elemwise{second,no_inplace} [id O] | | | | | | | | |A [id P] @@ -496,9 +496,9 @@ def test_debugprint_mitmot(): | | | | | | | |ScalarConstant{0} [id S] | | | | | | |Subtensor{int64} [id T] | | | | | | |Shape [id U] - | | | | | | | |Rebroadcast{(0, False)} [id M] + | | | | | | | |Unbroadcast{0} [id M] | | | | | | |ScalarConstant{1} [id V] - | | | | | |Rebroadcast{(0, False)} [id M] + | | | | | |Unbroadcast{0} [id M] | | | | | |ScalarFromTensor [id W] | | | | | |Subtensor{int64} [id K] | | | | |A [id P] (outer_in_non_seqs-0) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 77687b9061..f764c5aaf4 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -34,7 +34,6 @@ Join, MakeVector, PermuteRowElements, - Rebroadcast, ScalarFromTensor, Split, TensorFromScalar, @@ -86,7 +85,6 @@ triu, triu_indices, triu_indices_from, - unbroadcast, vertical_stack, zeros_like, ) @@ -104,7 +102,6 @@ dscalar, dscalars, dtensor3, - dtensor4, dvector, fmatrix, fscalar, @@ -337,7 +334,7 @@ def _numpy_second(x, y): ) -# Partial un broadcast of a dimshuffled input +# Partial unbroadcast of a dimshuffled input TestAllocDimshuffleGradBroadcast = makeBroadcastTester( name="Allocb4GradTester", op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3), @@ -3223,80 +3220,6 @@ def test_too_big(self): constant()[[val, val]] -class TestBroadcast: - def test_unbroadcast(self): - # test that the unbroadcast fct don't insert not needed broadcast - # and fuse consecutive Rebroadcast op - - x = matrix() - assert unbroadcast(x, 0) is x - assert unbroadcast(x, 1) is x - assert unbroadcast(x, 1, 0) is x - assert unbroadcast(x, 0, 1) is x - - x = row() - assert unbroadcast(x, 0) is not x - assert unbroadcast(x, 1) is x - assert unbroadcast(x, 1, 0) is not x - assert unbroadcast(x, 0, 1) is not x - - # The first broadcast is remove the broadcast, so the second - # should not make one - assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x - - # Test that consecutive Rebroadcast op are fused - x = TensorType(dtype="float64", shape=(True, True))() - assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x - - def test_infer_shape(self): - x = matrix() - y = unbroadcast(x, 0) - f = aesara.function([x], y.shape) - assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all() - topo = f.maker.fgraph.toposort() 
- if config.mode != "FAST_COMPILE": - assert len(topo) == 3 - assert isinstance(topo[0].op, Shape_i) - assert isinstance(topo[1].op, Shape_i) - assert isinstance(topo[2].op, MakeVector) - - x = row() - y = unbroadcast(x, 0) - f = aesara.function([x], y.shape) - assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all() - topo = f.maker.fgraph.toposort() - if config.mode != "FAST_COMPILE": - assert len(topo) == 2 - assert isinstance(topo[0].op, Shape_i) - assert isinstance(topo[1].op, MakeVector) - - -class TestRebroadcast(utt.InferShapeTester): - def test_rebroadcast(self): - rng = np.random.default_rng(3453) - # Rebroadcast - adtens4 = dtensor4() - adict = [(0, False), (1, True), (2, False), (3, True)] - adtens4_val = rng.random((2, 1, 3, 1)).astype(config.floatX) - self._compile_and_check( - [adtens4], - [Rebroadcast(*adict)(adtens4)], - [adtens4_val], - Rebroadcast, - warn=False, - ) - - adtens4_bro = TensorType("float64", (True, True, True, False))() - bdict = [(0, True), (1, False), (2, False), (3, False)] - adtens4_bro_val = rng.random((1, 1, 1, 3)).astype(config.floatX) - self._compile_and_check( - [adtens4_bro], - [Rebroadcast(*bdict)(adtens4_bro)], - [adtens4_bro_val], - Rebroadcast, - ) - - def test_len(): for shape_ in [(5,), (3, 4), (7, 4, 6)]: x = tensor(dtype="floatX", shape=(False,) * len(shape_)) diff --git a/tests/tensor/test_basic_opt.py b/tests/tensor/test_basic_opt.py index 6fba0c788c..d8e07d0a06 100644 --- a/tests/tensor/test_basic_opt.py +++ b/tests/tensor/test_basic_opt.py @@ -28,7 +28,6 @@ Alloc, Join, MakeVector, - Rebroadcast, ScalarFromTensor, Split, TensorFromScalar, @@ -40,7 +39,6 @@ ) from aesara.tensor.basic_opt import ( ShapeFeature, - apply_rebroadcast_opt, assert_op, local_alloc_sink_dimshuffle, local_dimshuffle_lift, @@ -92,9 +90,11 @@ Reshape, Shape_i, SpecifyShape, + Unbroadcast, reshape, shape, specify_shape, + unbroadcast, ) from aesara.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -1898,18 +1898,46 @@ def test_local_useless_tile(self): f(data) -class TestRebroadcast: - def test_local_useless_rebroadcast(self): - mode = get_default_mode().including("canonicalize") - v1 = vector() - v2 = vector() - j = at.join(0, v1, v2) - f = function([v1, v2], j, mode=mode) - f([1, 2], [3, 4, 5]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Rebroadcast)]) == 0 +class TestUnbroadcast: + def setup_method(self): + self.mode = get_default_mode().including("canonicalize") + + def test_local_useless_unbroadcast(self): + x1 = tensor("float64", shape=(1, 2)) + x2 = tensor("float64", shape=(2, 1)) + unbroadcast_op = Unbroadcast(0) + + f = function([x1], unbroadcast_op(x1), mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 1 + ) + + f = function([x2], unbroadcast_op(x2), mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 0 + ) + + def test_local_unbroadcast_lift(self): + x = tensor("float64", shape=(1, 1)) + y = unbroadcast(at.exp(unbroadcast(x, 0)), 1) + + assert ( + sum( + isinstance(node.op, Unbroadcast) + for node in FunctionGraph([x], [y], copy_inputs=False).toposort() + ) + == 2 + ) + + f = function([x], y, mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 1 + ) - assert check_stack_trace(f, ops_to_check="all") + np.testing.assert_almost_equal(f([[1]]), np.exp([[1]])) class TestUselessElemwise: @@ -3167,21 +3195,6 @@ def test_local_useless_alloc(): 
assert isinstance(topo[-1].op, Alloc) -def test_apply_rebroadcast_opt(): - # Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`. - # This is called by in `apply_rebroadcast_opt`. - a = vector(dtype="float32") - b = tensor("float64", [True]) - x = b.astype(a.dtype) - - broadcastable = (False,) - axis = [(i, broadcastable[i]) for i in range(len(broadcastable))] - rval = Rebroadcast(*axis)(x) - - res = apply_rebroadcast_opt(rval) - assert res is rval - - @pytest.mark.parametrize("return_index", [False]) @pytest.mark.parametrize("return_counts", [False]) @pytest.mark.parametrize("return_inverse", [False]) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index ee866a28cd..153ea2ef49 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -17,12 +17,14 @@ Reshape, Shape_i, SpecifyShape, + Unbroadcast, _specify_shape, reshape, shape, shape_i, specify_broadcastable, specify_shape, + unbroadcast, ) from aesara.tensor.subtensor import Subtensor from aesara.tensor.type import ( @@ -36,6 +38,7 @@ lscalar, matrix, scalar, + tensor, tensor3, vector, ) @@ -594,3 +597,63 @@ def test_get_vector_length(): # Test `SpecifyShape` x = specify_shape(ivector(), (10,)) assert get_vector_length(x) == 10 + + +class TestUnbroadcast: + def test_basic(self): + x = matrix() + assert unbroadcast(x, 0) is x + assert unbroadcast(x, 1) is x + assert unbroadcast(x, 1, 0) is x + assert unbroadcast(x, 0, 1) is x + + x = row() + assert unbroadcast(x, 0) is not x + assert unbroadcast(x, 1) is x + assert unbroadcast(x, 1, 0) is not x + assert unbroadcast(x, 0, 1) is not x + + assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x + + def test_infer_shape(self): + x = matrix() + y = unbroadcast(x, 0) + f = aesara.function([x], y.shape) + assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all() + topo = f.maker.fgraph.toposort() + if config.mode != "FAST_COMPILE": + assert len(topo) == 3 + assert isinstance(topo[0].op, Shape_i) + assert isinstance(topo[1].op, Shape_i) + assert isinstance(topo[2].op, MakeVector) + + x = row() + y = unbroadcast(x, 0) + f = aesara.function([x], y.shape) + assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all() + topo = f.maker.fgraph.toposort() + if config.mode != "FAST_COMPILE": + assert len(topo) == 2 + assert isinstance(topo[0].op, Shape_i) + assert isinstance(topo[1].op, MakeVector) + + def test_error_checks(self): + with pytest.raises(TypeError, match="needs integer axes"): + Unbroadcast(0.0) + + with pytest.raises(ValueError, match="^Trying to unbroadcast"): + Unbroadcast(1)(vector()) + + +class TestUnbroadcastInferShape(utt.InferShapeTester): + def test_basic(self): + rng = np.random.default_rng(3453) + adtens4 = tensor("float64", shape=(1, 1, 1, None)) + adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX) + self._compile_and_check( + [adtens4], + [Unbroadcast(0, 2)(adtens4)], + [adtens4_val], + Unbroadcast, + warn=False, + ) diff --git a/tests/tensor/test_subtensor_opt.py b/tests/tensor/test_subtensor_opt.py index 3116481ff9..0d7dd1ffe1 100644 --- a/tests/tensor/test_subtensor_opt.py +++ b/tests/tensor/test_subtensor_opt.py @@ -16,16 +16,10 @@ from aesara.graph.type import Type from aesara.raise_op import Assert from aesara.tensor import inplace -from aesara.tensor.basic import ( - Alloc, - MakeVector, - Rebroadcast, - _convert_to_int8, - make_vector, -) +from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector from aesara.tensor.elemwise import DimShuffle, Elemwise from 
aesara.tensor.math import Dot, add, dot, exp, sqr -from aesara.tensor.shape import SpecifyShape, _shape, shape, specify_shape +from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape from aesara.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -843,61 +837,61 @@ def test_basic_7(self): f([1, 2, 3], 4) # let debugmode test something def test_basic_8(self): - # Test that Subtensor(Rebroadcast(x)) gets optimized into - # Rebroadcast(Subtensor(x)). + # Test that Subtensor(Unbroadcast(x)) gets optimized into + # Unbroadcast(Subtensor(x)). # test basic case - x = matrix("x") + x = row("x") xval = np.random.random((1, 10)).astype(config.floatX) - assert x.broadcastable == (False, False) - newx = Rebroadcast((0, True), (1, False))(x) - assert newx.broadcastable == (True, False) + assert x.broadcastable == (True, False) + newx = Unbroadcast(0)(x) + assert newx.broadcastable == (False, False) f1 = function([x], newx[:2, :5], mode=mode_opt) # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=[Subtensor, Rebroadcast]) + assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast]) prog = f1.maker.fgraph.toposort() assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Rebroadcast) + assert isinstance(prog[1].op, Unbroadcast) assert (f1(xval) == xval[:2, :5]).all() - # corner case 1: rebroadcast changes dims which are dropped through subtensor - y = tensor4("x") + # corner case 1: Unbroadcast changes dims which are dropped through subtensor + y = tensor("float64", shape=(1, 10, 1, 3), name="x") yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) - assert y.broadcastable == (False, False, False, False) - newy = Rebroadcast((0, True), (2, True))(y) - assert newy.broadcastable == (True, False, True, False) + assert y.broadcastable == (True, False, True, False) + newy = Unbroadcast(0, 2)(y) + assert newy.broadcastable == (False, False, False, False) f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f2, ops_to_check=[Subtensor, Rebroadcast]) + assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast]) prog = f2.maker.fgraph.toposort() assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Rebroadcast) + assert isinstance(prog[1].op, Unbroadcast) assert (f2(yval) == yval[:, 3, 0, :]).all() # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern f3 = function([y], newy[:, 3, 0], mode=mode_opt) # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f3, ops_to_check=[Subtensor, Rebroadcast]) + assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast]) prog = f3.maker.fgraph.toposort() assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Rebroadcast) + assert isinstance(prog[1].op, Unbroadcast) assert (f3(yval) == yval[:, 3, 0]).all() - # corner case 3: subtensor idx_list is shorter than rebroadcast.axis - z = tensor4("x") + # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis + z = tensor("float64", shape=(4, 10, 3, 1), name="x") zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) - assert z.broadcastable == (False, False, False, False) - newz = Rebroadcast((3, True))(z) - assert newz.broadcastable == (False, False, False, True) + assert z.broadcastable == (False, False, False, True) + newz = Unbroadcast(3)(z) + assert 
newz.broadcastable == (False, False, False, False) f4 = function([z], newz[:, 3, 0], mode=mode_opt) # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f4, ops_to_check=[Subtensor, Rebroadcast]) + assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast]) prog = f4.maker.fgraph.toposort() assert isinstance(prog[0].op, Subtensor) - assert isinstance(prog[1].op, Rebroadcast) + assert isinstance(prog[1].op, Unbroadcast) assert (f4(zval) == zval[:, 3, 0]).all() diff --git a/tests/test_rop.py b/tests/test_rop.py index 8b95d13122..d1f85307f6 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -26,6 +26,7 @@ from aesara.tensor.math import argmax, dot from aesara.tensor.math import max as at_max from aesara.tensor.nnet import conv, conv2d +from aesara.tensor.shape import unbroadcast from aesara.tensor.signal.pool import Pool from aesara.tensor.type import TensorType, matrix, vector from tests import unittest_tools as utt @@ -237,11 +238,11 @@ def test_dimshuffle(self): # vector self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,)) - def test_rebroadcast(self): + def test_unbroadcast(self): # I need the sum, because the setup expects the output to be a # vector self.check_rop_lop( - at.unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,) + unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,) ) @pytest.mark.slow
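
---

Note (not part of the diff above): the following is a minimal migration sketch, assuming this patch is applied. It illustrates the user-facing change the commit describes: the removed `Rebroadcast((axis, bool), ...)` API is replaced by `aesara.tensor.shape.unbroadcast(x, *axes)` / the `Unbroadcast` `Op`, which only relax static shape information and never raise at runtime. The `row`/`matrix` variables are just illustrative inputs mirroring the tests above.

```python
# Illustrative sketch only; assumes the patch in this commit is applied.
import aesara
import aesara.tensor as at
from aesara.tensor.shape import Unbroadcast, unbroadcast

x = at.row("x")                       # static shape (1, None)
assert x.broadcastable == (True, False)

# Before this patch (API removed above):
#   y = at.Rebroadcast((0, False))(x)   # (axis, new_broadcastable) pairs
# After this patch:
y = unbroadcast(x, 0)                 # mask the static 1 in dim 0 as None
assert y.broadcastable == (False, False)
assert isinstance(y.owner.op, Unbroadcast)

# Axes whose static length is not already 1 are ignored by the helper,
# so unbroadcasting a fully unknown-shape tensor is a no-op:
z = at.matrix("z")                    # static shape (None, None)
assert unbroadcast(z, 0, 1) is z

# Unlike Rebroadcast, Unbroadcast performs no runtime check and simply
# returns a view of its input (view_map = {0: [0]}).
f = aesara.function([x], y)
f([[1.0, 2.0]])                       # returns the input array unchanged
```

The same substitution applies inside graphs: the gradient of `Unbroadcast` restores the input's static shape via `specify_shape`, rather than re-applying a `Rebroadcast` as the removed implementation did.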