From aa46a0eb96d578cf83b1fe1e637bd1fb4652d315 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Sat, 2 Mar 2024 00:30:47 +0000 Subject: [PATCH] fix: torch frontend max/min to support dim and keepdim as arg or kwarg (#28469) --- .../frontends/torch/reduction_ops.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/torch/reduction_ops.py b/ivy/functional/frontends/torch/reduction_ops.py index 03d6ec6d91fb8..564bbca97f9cc 100644 --- a/ivy/functional/frontends/torch/reduction_ops.py +++ b/ivy/functional/frontends/torch/reduction_ops.py @@ -100,11 +100,25 @@ def logsumexp(input, dim, keepdim=False, *, out=None): @numpy_to_torch_style_args @to_ivy_arrays_and_back +@with_unsupported_dtypes( + {"2.2 and below": ("complex64", "complex128")}, + "torch", +) def max(*input, dim=None, keepdim=False, out=None): if len(input) == 1: input = input[0] elif len(input) == 2: - return torch_frontend.maximum(*input) + input_0 = input[0] + input_1 = input[1] + if ivy.is_array(input_1): + return torch_frontend.maximum(*input) + else: + input = input_0 + dim = input_1 + else: + input = input[0] + dim = input[1] + keepdim = input[2] if dim is None: return ivy.max(input, axis=dim, keepdims=keepdim, out=out) elif out is not None: @@ -173,7 +187,17 @@ def min(*input, dim=None, keepdim=False, out=None): if len(input) == 1: input = input[0] elif len(input) == 2: - return torch_frontend.minimum(*input) + input_0 = input[0] + input_1 = input[1] + if ivy.is_array(input_1): + return torch_frontend.minimum(*input) + else: + input = input_0 + dim = input_1 + else: + input = input[0] + dim = input[1] + keepdim = input[2] if dim is None: return ivy.min(input, axis=dim, keepdims=keepdim, out=out) elif out is not None: