diff --git a/src/levanter/models/flash_attention.py b/src/levanter/models/flash_attention.py
index b1adea0b3..000de57ee 100644
--- a/src/levanter/models/flash_attention.py
+++ b/src/levanter/models/flash_attention.py
@@ -18,7 +18,7 @@
 from levanter.models.attention import AttentionMask, materialize_mask
 
 
-# TODO: tune
+# TODO: Tune
 BLOCK_SIZE = 128
 
 
diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py
index 475483b20..5822fab30 100644
--- a/src/levanter/models/gpt2.py
+++ b/src/levanter/models/gpt2.py
@@ -195,7 +195,6 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la
             prng=k_drop,
             attention_dtype=jnp.float32 if self.config.upcast_attn else None,
         )
-        print(f"\n\nATTENTION OUTPUT: {attn_output}\n\n")
         attn_output = self.c_proj(attn_output, key=k_out)
 
         if self.config.upcast_attn: