Skip to content

Commit

Permalink
fix(ivy): Extends ivy.pow to work for all input and exponent cases
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTz committed Sep 11, 2023
1 parent e227737 commit f713068
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 25 deletions.
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ def pow(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
return np.power(x1, x2, out=out)
if ivy.is_int_dtype(x1) and ivy.any(x2 < 0):
return np.float_power(x1, x2, casting='unsafe').astype(x1.dtype)
return np.power(x1, x2)


pow.support_native_out = True
Expand Down
9 changes: 8 additions & 1 deletion ivy/functional/backends/tensorflow/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,14 @@ def pow(
if x2.dtype.is_unsigned:
x2 = tf.cast(x2, tf.float64)
return tf.cast(tf.experimental.numpy.power(x1, x2), promoted_type)
return tf.experimental.numpy.power(x1, x2)
orig_x1_dtype = None
if ivy.is_int_dtype(x1) and ivy.any(x2 < 0):
orig_x1_dtype = x1.dtype
x1 = tf.cast(x1, tf.float32)
ret = tf.experimental.numpy.power(x1, x2)
if orig_x1_dtype is not None:
return tf.cast(ret, orig_x1_dtype)
return ret


@with_unsupported_dtypes({"2.13.0 and below": ("bfloat16", "complex")}, backend_version)
Expand Down
7 changes: 0 additions & 7 deletions ivy/functional/ivy/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5201,13 +5201,6 @@ def pow(
(the exponent), where ``x2_i`` is the corresponding element of the input array
``x2``.
.. note::
If both ``x1`` and ``x2`` have integer data types, the result of ``pow`` when
``x2_i`` is negative (i.e., less than zero) is unspecified and thus
implementation-dependent. If ``x1`` has an integer data type and ``x2`` has a
floating-point data type, behavior is implementation-dependent (type promotion
between data type "kinds" (integer versus floating-point) is unspecified).
**Special cases**
For floating-point operands,
Expand Down
17 changes: 1 addition & 16 deletions ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ def cast_filter(dtype1_x1_dtype2):
)
)
dtype2 = dtype2[0]
if "int" in dtype2:
x2 = ivy.nested_map(
x2[0], lambda x: abs(x), include_derived={"list": True}, shallow=False
)
return [dtype1, dtype2], [x1, x2]


Expand Down Expand Up @@ -1593,19 +1589,8 @@ def test_positive(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
)
def test_pow(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
input_dtype, x = dtype_and_x

# bfloat16 is not supported by numpy
assume("bfloat16" not in input_dtype)

# Make sure x2 isn't a float when x1 is integer
assume(
not (ivy.is_int_dtype(input_dtype[0] and ivy.is_float_dtype(input_dtype[1])))
)

# Make sure x2 is non-negative when both is integer
if ivy.is_int_dtype(input_dtype[1]) and ivy.is_int_dtype(input_dtype[0]):
x[1] = np.abs(x[1])

assume(not ("bfloat16" in input_dtype))
x[0] = not_too_close_to_zero(x[0])
x[1] = not_too_close_to_zero(x[1])
helpers.test_function(
Expand Down

0 comments on commit f713068

Please sign in to comment.