
Commit

commented out debug print statements
Virginia Adams committed Feb 6, 2024
1 parent cf8e873 commit c5df598
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/levanter/models/attention.py
@@ -94,15 +94,15 @@ def dot_product_attention(
k_ = haliax.rearrange(key, ("batch", "key_position", "heads", "head_size")).array
v_ = haliax.rearrange(value, ("batch", "key_position", "heads", "head_size")).array

print(f"\nShape of q_ is now {q_.shape}\n")
# print(f"\nShape of q_ is now {q_.shape}\n")

qkv = jnp.stack((q_, k_, v_), axis=2)

print(f"\nQKV shape is now: {qkv.shape}")
# print(f"\nQKV shape is now: {qkv.shape}")

# scaling_factor = jax.lax.rsqrt(float(query.axis_size(Key)))
scaling_factor = 1.0
print(f"\n\nScaling factor: {scaling_factor}\n\n")
# print(f"\n\nScaling factor: {scaling_factor}\n\n")
is_training = not inference

# TODO: bias type is probably also configurable
@@ -113,7 +113,7 @@ def dot_product_attention(

# TODO: We have a mask type we can use to configure this
attn_mask_type = AttnMaskType.CAUSAL_MASK
print(f"\nQuery Shape: {query.shape}\n")
# print(f"\nQuery Shape: {query.shape}\n")
batch_size = query.shape["batch"]
seq_len = query.shape["position"]
mask = jnp.tril(jnp.ones((batch_size, seq_len, seq_len)))
@@ -122,15 +122,15 @@ def dot_product_attention(
# if mask:
# fused_attn_mask = mask.array

print("\n\n")
print(f"qkv type: {type(qkv)}")
print(f"bias type: {type(bias)}")
print(f"mask type: {type(mask)}")
print(f"seed type: {type(prng)}")
print(f"attn_bias_Type type: {type(attn_bias_type)}")
print(f"attn_mask_type type: {type(attn_mask_type)}")
print(f"scaling factor type: {type(scaling_factor)}")
print("\n\n")
# print("\n\n")
# print(f"qkv type: {type(qkv)}")
# print(f"bias type: {type(bias)}")
# print(f"mask type: {type(mask)}")
# print(f"seed type: {type(prng)}")
# print(f"attn_bias_Type type: {type(attn_bias_type)}")
# print(f"attn_mask_type type: {type(attn_mask_type)}")
# print(f"scaling factor type: {type(scaling_factor)}")
# print("\n\n")

attn_output = self_fused_attn(
qkv=qkv, # jnp.ndarray,
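
For context on what the code in this diff feeds into: the `jnp.stack((q_, k_, v_), axis=2)` above produces a qkv tensor of shape [batch, seq_len, 3, heads, head_size], which is handed to the fused attention kernel along with the causal mask and `scaling_factor`. The snippet below is a minimal, un-fused JAX sketch of the same causal scaled-dot-product computation; the function name, the boolean-mask implementation, and the large-negative fill value are illustrative assumptions, not code from the Levanter repository or the fused kernel itself.

# Illustrative reference only -- not part of this commit. Assumes the qkv
# layout built above: [batch, seq_len, 3, heads, head_size], stacked on axis 2.
import jax
import jax.numpy as jnp


def reference_causal_attention(qkv: jnp.ndarray, scaling_factor: float = 1.0) -> jnp.ndarray:
    """Un-fused causal scaled-dot-product attention over a stacked qkv tensor."""
    # Split the stacked tensor back into q, k, v: each [batch, seq, heads, head_size].
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    # Attention scores: [batch, heads, query_position, key_position].
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scaling_factor
    seq_len = q.shape[1]
    # Lower-triangular causal mask, analogous to the jnp.tril mask in the diff.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    # Weighted sum of values: back to [batch, seq, heads, head_size].
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)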
