Skip to content

Commit

Permalink
fix: torch frontend max/min to support dim and keepdim as arg or kwarg (
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong authored Mar 2, 2024
1 parent f67a128 commit aa46a0e
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions ivy/functional/frontends/torch/reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,25 @@ def logsumexp(input, dim, keepdim=False, *, out=None):

@numpy_to_torch_style_args
@to_ivy_arrays_and_back
@with_unsupported_dtypes(
{"2.2 and below": ("complex64", "complex128")},
"torch",
)
def max(*input, dim=None, keepdim=False, out=None):
if len(input) == 1:
input = input[0]
elif len(input) == 2:
return torch_frontend.maximum(*input)
input_0 = input[0]
input_1 = input[1]
if ivy.is_array(input_1):
return torch_frontend.maximum(*input)
else:
input = input_0
dim = input_1
else:
input = input[0]
dim = input[1]
keepdim = input[2]
if dim is None:
return ivy.max(input, axis=dim, keepdims=keepdim, out=out)
elif out is not None:
Expand Down Expand Up @@ -173,7 +187,17 @@ def min(*input, dim=None, keepdim=False, out=None):
if len(input) == 1:
input = input[0]
elif len(input) == 2:
return torch_frontend.minimum(*input)
input_0 = input[0]
input_1 = input[1]
if ivy.is_array(input_1):
return torch_frontend.minimum(*input)
else:
input = input_0
dim = input_1
else:
input = input[0]
dim = input[1]
keepdim = input[2]
if dim is None:
return ivy.min(input, axis=dim, keepdims=keepdim, out=out)
elif out is not None:
Expand Down

0 comments on commit aa46a0e

Please sign in to comment.