Skip to content

Commit

Permalink
fix: avoid shape mismatches in certain cases by axes for max and the …
Browse files Browse the repository at this point in the history
…value updating loop
  • Loading branch information
Ishticode committed Feb 19, 2024
1 parent e272942 commit c07ced7
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2257,7 +2257,7 @@ def adaptive_max_pool3d(
)

if not (adaptive_d or adaptive_h or adaptive_w):
ret = ivy.max(vals, axis=(-4, -3, -1))
ret = ivy.max(vals, axis=(-3, -1))
ret = ivy.squeeze(ret, axis=0) if squeeze else ret
return ret

Expand All @@ -2273,14 +2273,13 @@ def adaptive_max_pool3d(

ret = None
for i, j, k in itertools.product(
range(vals.shape[-3]), range(vals.shape[-2]), range(vals.shape[-1])
range(vals.shape[-4]), range(vals.shape[-2]), range(vals.shape[-1])
):
if ret is None:
ret = vals[..., :, i, j, k]
ret = vals[..., i, :, j, k]
else:
ret = ivy.maximum(ret, vals[..., :, i, j, k])
ret = ivy.maximum(ret, vals[..., i, :, j, k])
pooled_output = ret.astype(vals.dtype)

pooled_output = ivy.squeeze(pooled_output, axis=0) if squeeze else pooled_output
return pooled_output

Expand Down

0 comments on commit c07ced7

Please sign in to comment.