diff --git a/sharktank/sharktank/layers/mmdit.py b/sharktank/sharktank/layers/mmdit.py index 0c970ab35..1557883ae 100644 --- a/sharktank/sharktank/layers/mmdit.py +++ b/sharktank/sharktank/layers/mmdit.py @@ -41,7 +41,7 @@ def attention(q, k, v, pe): q=q, k=k, v=v, a=None, is_causal=True, scale=None ) x = ops.permute(x, (0, 2, 1, 3)) - x = x.view(x.shape[0], x.shape[1], -1) + x = x.reshape(x.shape[0], x.shape[1], -1) return x