From f713068bd05ec2ec5ae573bff2d622ac58ef568b Mon Sep 17 00:00:00 2001 From: AnnaTz <111577222+AnnaTz@users.noreply.github.com> Date: Mon, 11 Sep 2023 13:54:14 +0100 Subject: [PATCH] fix(ivy): Extends ivy.pow to work for all input and exponent cases --- ivy/functional/backends/numpy/elementwise.py | 4 +++- .../backends/tensorflow/elementwise.py | 9 ++++++++- ivy/functional/ivy/elementwise.py | 7 ------- .../test_core/test_elementwise.py | 17 +---------------- 4 files changed, 12 insertions(+), 25 deletions(-) diff --git a/ivy/functional/backends/numpy/elementwise.py b/ivy/functional/backends/numpy/elementwise.py index 6c4459bbaaf4f..654c37a673245 100644 --- a/ivy/functional/backends/numpy/elementwise.py +++ b/ivy/functional/backends/numpy/elementwise.py @@ -611,7 +611,9 @@ def pow( out: Optional[np.ndarray] = None, ) -> np.ndarray: x1, x2 = ivy.promote_types_of_inputs(x1, x2) - return np.power(x1, x2, out=out) + if ivy.is_int_dtype(x1) and ivy.any(x2 < 0): + return np.float_power(x1, x2, casting='unsafe').astype(x1.dtype) + return np.power(x1, x2) pow.support_native_out = True diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 1e060e8e8cae3..a8dc01ba58d3f 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -620,7 +620,14 @@ def pow( if x2.dtype.is_unsigned: x2 = tf.cast(x2, tf.float64) return tf.cast(tf.experimental.numpy.power(x1, x2), promoted_type) - return tf.experimental.numpy.power(x1, x2) + orig_x1_dtype = None + if ivy.is_int_dtype(x1) and ivy.any(x2 < 0): + orig_x1_dtype = x1.dtype + x1 = tf.cast(x1, tf.float32) + ret = tf.experimental.numpy.power(x1, x2) + if orig_x1_dtype is not None: + return tf.cast(ret, orig_x1_dtype) + return ret @with_unsupported_dtypes({"2.13.0 and below": ("bfloat16", "complex")}, backend_version) diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py index 3d9eebc0aa71e..7909bb56cecd3 100644 --- a/ivy/functional/ivy/elementwise.py +++ b/ivy/functional/ivy/elementwise.py @@ -5201,13 +5201,6 @@ def pow( (the exponent), where ``x2_i`` is the corresponding element of the input array ``x2``. - .. note:: - If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when - ``x2_i`` is negative (i.e., less than zero) is unspecified and thus - implementation-dependent. If ``x1`` has an integer data type and ``x2`` has a - floating-point data type, behavior is implementation-dependent (type promotion - between data type "kinds" (integer versus floating-point) is unspecified). - **Special cases** For floating-point operands, diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 953b621b1d70f..447e5d0a34448 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -136,10 +136,6 @@ def cast_filter(dtype1_x1_dtype2): ) ) dtype2 = dtype2[0] - if "int" in dtype2: - x2 = ivy.nested_map( - x2[0], lambda x: abs(x), include_derived={"list": True}, shallow=False - ) return [dtype1, dtype2], [x1, x2] @@ -1593,19 +1589,8 @@ def test_positive(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): ) def test_pow(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x - # bfloat16 is not supported by numpy - assume("bfloat16" not in input_dtype) - - # Make sure x2 isn't a float when x1 is integer - assume( - not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1]))) - ) - - # Make sure x2 is non-negative when both is integer - if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]): - x[1] = np.abs(x[1]) - + assume(not ("bfloat16" in input_dtype)) x[0] = not_too_close_to_zero(x[0]) x[1] = not_too_close_to_zero(x[1]) helpers.test_function(