Skip to content

Commit

Permalink
Add tensor conversion to flatnonzero, nonzero_values, tile, `in…
Browse files Browse the repository at this point in the history
…verse_permutation`, and `diag`
  • Loading branch information
ltoniazzi authored Sep 2, 2022
1 parent c2ed818 commit e40c827
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
44 changes: 25 additions & 19 deletions aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,9 +940,10 @@ def flatnonzero(a):
nonzero_values : Return the non-zero elements of the input array
"""
if a.ndim == 0:
_a = as_tensor_variable(a)
if _a.ndim == 0:
raise ValueError("Nonzero only supports non-scalar arrays.")
return nonzero(a.flatten(), return_matrix=False)[0]
return nonzero(_a.flatten(), return_matrix=False)[0]


def nonzero_values(a):
Expand Down Expand Up @@ -1324,9 +1325,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
tensor
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
"""
_x = as_tensor_variable(x)
if dtype is None:
dtype = x.dtype
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype)
dtype = _x.dtype
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


def infer_broadcastable(shape):
Expand Down Expand Up @@ -2773,8 +2775,9 @@ def tile(x, reps, ndim=None):
"""
from aesara.tensor.math import ge

if ndim is not None and ndim < x.ndim:
raise ValueError("ndim should be equal or larger than x.ndim")
_x = as_tensor_variable(x)
if ndim is not None and ndim < _x.ndim:
raise ValueError("ndim should be equal or larger than _x.ndim")

# If reps is a scalar, integer or vector, we convert it to a list.
if not isinstance(reps, (list, tuple)):
Expand All @@ -2799,8 +2802,8 @@ def tile(x, reps, ndim=None):
# assert that reps.shape[0] does not exceed ndim
offset = assert_op(offset, ge(offset, 0))

# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
# if reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as _x.
reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
reps = reps_

Expand All @@ -2817,17 +2820,17 @@ def tile(x, reps, ndim=None):
):
raise ValueError("elements of reps must be scalars of integer dtype")

# If reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x
# If reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as _x
reps = list(reps)
if ndim is None:
ndim = builtins.max(len(reps), x.ndim)
ndim = builtins.max(len(reps), _x.ndim)
if len(reps) < ndim:
reps = [1] * (ndim - len(reps)) + reps

_shape = [1] * (ndim - x.ndim) + [x.shape[i] for i in range(x.ndim)]
_shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
alloc_shape = reps + _shape
y = alloc(x, *alloc_shape)
y = alloc(_x, *alloc_shape)
shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
shuffle_ind = shuffle_ind.transpose().flatten()
y = y.dimshuffle(*shuffle_ind)
Expand Down Expand Up @@ -3288,8 +3291,9 @@ def inverse_permutation(perm):
Each row of input should contain a permutation of the first integers.
"""
_perm = as_tensor_variable(perm)
return permute_row_elements(
arange(perm.shape[-1], dtype=perm.dtype), perm, inverse=True
arange(_perm.shape[-1], dtype=_perm.dtype), _perm, inverse=True
)


Expand Down Expand Up @@ -3575,12 +3579,14 @@ def diag(v, k=0):
"""

if v.ndim == 1:
return AllocDiag(k)(v)
elif v.ndim >= 2:
return diagonal(v, offset=k)
_v = as_tensor_variable(v)

if _v.ndim == 1:
return AllocDiag(k)(_v)
elif _v.ndim >= 2:
return diagonal(_v, offset=k)
else:
raise ValueError("Input must has v.ndim >= 1.")
raise ValueError("Number of dimensions of `v` must be greater than one.")


def stacklists(arg):
Expand Down
40 changes: 38 additions & 2 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,12 @@ def check(m):
rand2d[:4] = 0
check(rand2d)

# Test passing a list
m = [1, 2, 0]
out = flatnonzero(m)
f = function([], out)
assert np.array_equal(f(), np.flatnonzero(m))

@config.change_flags(compute_test_value="raise")
def test_nonzero_values(self):
def check(m):
Expand Down Expand Up @@ -1449,8 +1455,6 @@ def test_roll(self):

assert (out == want).all()

# Pass a list to make sure `a` is converted to a
# TensorVariable by roll
a = [1, 2, 3, 4, 5, 6]
b = roll(a, get_shift(2))
want = np.array([5, 6, 1, 2, 3, 4])
Expand Down Expand Up @@ -2221,6 +2225,20 @@ def run_tile(x, x_, reps, use_symbolic_reps):
== np.tile(x_, (2, 3, 4, 6))
)

# Test passing a float
x = scalar()
x_val = 1.0
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
)

# Test when x is a list
x = matrix()
x_val = [[1.0, 2.0], [3.0, 4.0]]
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
)

# Test when reps is integer, scalar or vector.
# Test 1,2,3,4-dimensional cases.
# Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5].
Expand Down Expand Up @@ -2794,6 +2812,12 @@ def test_dim1(self):
assert np.all(p_val[inv_val] == np.arange(10))
assert np.all(inv_val[p_val] == np.arange(10))

# Test passing a list
p = [2, 4, 3, 0, 1]
inv = at.inverse_permutation(p)
f = aesara.function([], inv)
assert np.array_equal(f(), np.array([3, 4, 0, 2, 1]))

def test_dim2(self):
# Test the inversion of several permutations at a time
# Each row of p is a different permutation to inverse
Expand Down Expand Up @@ -3449,6 +3473,12 @@ def test_diag(self):
with pytest.raises(ValueError):
diag(xx)

# Test passing a list
xx = [[1, 2], [3, 4]]
g = diag(xx)
f = function([], g)
assert np.array_equal(f(), np.diag(xx))

def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())

Expand Down Expand Up @@ -4136,6 +4166,12 @@ def test_identity_like_dtype():
m_out_float = identity_like(m, dtype=np.float64)
assert m_out_float.dtype == "float64"

# Test passing list
m = [[0, 1], [1, 3]]
out = at.identity_like(m)
f = aesara.function([], out)
assert np.array_equal(f(), np.eye(2))


def test_atleast_Nd():
ary1 = dscalar()
Expand Down

0 comments on commit e40c827

Please sign in to comment.