diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index ff6ba8876..701e0ea09 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -94,15 +94,15 @@ def dot_product_attention(
     k_ = haliax.rearrange(key, ("batch", "key_position", "heads", "head_size")).array
     v_ = haliax.rearrange(value, ("batch", "key_position", "heads", "head_size")).array
-    print(f"\nShape of q_ is now {q_.shape}\n")
+    # print(f"\nShape of q_ is now {q_.shape}\n")

     qkv = jnp.stack((q_, k_, v_), axis=2)
-    print(f"\nQKV shape is now: {qkv.shape}")
+    # print(f"\nQKV shape is now: {qkv.shape}")

     # scaling_factor = jax.lax.rsqrt(float(query.axis_size(Key)))
     scaling_factor = 1.0
-    print(f"\n\nScaling factor: {scaling_factor}\n\n")
+    # print(f"\n\nScaling factor: {scaling_factor}\n\n")

     is_training = not inference

     # TODO: bias type is probably also configurable
@@ -113,7 +113,7 @@ def dot_product_attention(
     # TODO: We have a mask type we can use to configure this
     attn_mask_type = AttnMaskType.CAUSAL_MASK
-    print(f"\nQuery Shape: {query.shape}\n")
+    # print(f"\nQuery Shape: {query.shape}\n")
     batch_size = query.shape["batch"]
     seq_len = query.shape["position"]
     mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
@@ -122,15 +122,15 @@ def dot_product_attention(
     # if mask:
     #     fused_attn_mask = mask.array
-    print("\n\n")
-    print(f"qkv type: {type(qkv)}")
-    print(f"bias type: {type(bias)}")
-    print(f"mask type: {type(mask)}")
-    print(f"seed type: {type(prng)}")
-    print(f"attn_bias_Type type: {type(attn_bias_type)}")
-    print(f"attn_mask_type type: {type(attn_mask_type)}")
-    print(f"scaling factor type: {type(scaling_factor)}")
-    print("\n\n")
+    # print("\n\n")
+    # print(f"qkv type: {type(qkv)}")
+    # print(f"bias type: {type(bias)}")
+    # print(f"mask type: {type(mask)}")
+    # print(f"seed type: {type(prng)}")
+    # print(f"attn_bias_Type type: {type(attn_bias_type)}")
+    # print(f"attn_mask_type type: {type(attn_mask_type)}")
+    # print(f"scaling factor type: {type(scaling_factor)}")
+    # print("\n\n")

     attn_output = self_fused_attn(
         qkv=qkv,  # jnp.ndarray,
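
Note (not part of the patch): below is a minimal, unfused reference sketch of the attention computation that the packed inputs above describe, assuming the QKV layout `[batch, seq_len, 3, heads, head_size]` produced by `jnp.stack((q_, k_, v_), axis=2)` and a causal mask built with `jnp.tril` as in the diff. The helper `reference_causal_attention` is a hypothetical name for illustration only; it is not how `self_fused_attn` is invoked.

```python
import jax
import jax.numpy as jnp


def reference_causal_attention(qkv: jnp.ndarray, scaling_factor: float = 1.0) -> jnp.ndarray:
    """Unfused sketch: qkv is [batch, seq_len, 3, heads, head_size], packed as in the diff."""
    # Unpack the stacked projections; each is [batch, seq_len, heads, head_size].
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

    # Scaled attention logits: [batch, heads, q_pos, k_pos].
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scaling_factor

    # Causal (lower-triangular) mask, analogous to the jnp.tril mask in the patch.
    seq_len = qkv.shape[1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(causal, logits, jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values: back to [batch, seq_len, heads, head_size].
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

With `scaling_factor = 1.0` as in the patch, the logits are left unscaled; using `1 / sqrt(head_size)` instead would correspond to the commented-out `jax.lax.rsqrt(float(query.axis_size(Key)))` line.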