revert change on llama
Ivan-Zhou committed Feb 13, 2024
1 parent 8d82a9c commit b905c9f
Showing 1 changed file with 25 additions and 7 deletions.
src/levanter/models/llama.py: 32 changes (25 additions & 7 deletions)
@@ -47,6 +47,9 @@ class LlamaConfig(HFCompatConfig):
         intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008.
         num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32.
         num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32.
+        num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer.
+            Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA.
+            Note that num_heads must be divisible by this number. Defaults to 32.
         activation_function (str, optional): activation function for the hidden layer. Defaults to "silu".
         rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding.
     """
@@ -56,6 +59,7 @@ class LlamaConfig(HFCompatConfig):
     intermediate_dim: int = 11008
     num_layers: int = 32
     num_heads: int = 32
+    num_kv_heads: int = 32
     activation_function: str = "silu"
     initializer_range: float = 0.02
     layer_norm_epsilon: float = 1e-5
@@ -76,10 +80,16 @@ class LlamaConfig(HFCompatConfig):
     KeyPos = property(lambda self: self.Pos.alias("key_position"))
     Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim))
     Heads = property(lambda self: Axis(name="heads", size=self.num_heads))
+    KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads))
     Layers = property(lambda self: Axis(name="layers", size=self.num_layers))
     Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
     HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))
 
+    def __post_init__(self):
+        assert (
+            self.num_heads % self.num_kv_heads == 0
+        ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."
+
     @cached_classproperty
     def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]:  # type: ignore
         return HFCheckpointConverter(
@@ -98,6 +108,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             intermediate_dim=hf_config.intermediate_size,
             num_layers=hf_config.num_hidden_layers,
             num_heads=hf_config.num_attention_heads,
+            num_kv_heads=hf_config.num_key_value_heads,
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
@@ -123,6 +134,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             intermediate_size=self.intermediate_dim,
             num_hidden_layers=self.num_layers,
             num_attention_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
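
Taken together, the two config hunks above round-trip num_kv_heads through the Hugging Face field num_key_value_heads. A hedged sketch of that round trip (attribute names are taken from the kwargs above; the HF-side config object is assumed, and vocab_size is an illustrative value):

# Hypothetical round trip between the levanter config and its HF counterpart.
cfg = LlamaConfig(num_heads=32, num_kv_heads=8)
hf_cfg = cfg.to_hf_config(vocab_size=32000)
assert hf_cfg.num_key_value_heads == 8

cfg2 = LlamaConfig.from_hf_config(hf_cfg)
assert cfg2.num_kv_heads == 8
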
@@ -264,10 +276,14 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module):
     def init(config: LlamaConfig, *, key) -> "LlamaAttention":
         use_bias = config.use_bias
         Embed = config.Embed
+        QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)
+
         k_q, k_k, k_v, k_o = jrandom.split(key, 4)
-        q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias)
-        k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias)
-        v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias)
+        q_proj = hnn.Linear.init(
+            In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias
+        )
+        k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias)
+        v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias)
         o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias)
         rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos)
         return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb)
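
The grouped projections above are where GQA saves parameters: q_proj still covers all num_heads query heads, while k_proj and v_proj scale with num_kv_heads only. A back-of-the-envelope count with hypothetical 7B-style sizes (plain Python arithmetic, not the haliax code above; hidden_dim is not shown in this diff):

# Hypothetical sizes for illustration only.
hidden_dim, num_heads, num_kv_heads = 4096, 32, 8
head_size = hidden_dim // num_heads

q_params = hidden_dim * num_heads * head_size           # q_proj: kv_heads * q_heads_per_group == num_heads
kv_params = 2 * hidden_dim * num_kv_heads * head_size   # k_proj + v_proj shrink with num_kv_heads
o_params = num_heads * head_size * hidden_dim           # o_proj is unchanged
print(q_params, kv_params, o_params)  # 16777216 8388608 16777216
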
@@ -277,9 +293,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> NamedArray:
         key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)
 
         # reorder heads and position for better training throughput
-        q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size"))
-        k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size"))
-        v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size"))
+        q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
+        k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
+        v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))
 
         cos, sin = self.rotary_emb(seq_len=x.axis_size("position"))
 
@@ -305,6 +321,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> NamedArray:
             flash_block_size=c.flash_attention_block_size,
         )
 
+        attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
+
         if self.config.upcast_attn:
             attn_output = attn_output.astype(x.dtype)
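
To see how the rearranged (kv_heads, q_heads_per_group) axes and the final flatten_axes call fit together, here is a plain NumPy sketch of grouped-query attention over the same layout (illustrative sizes; not the levanter/haliax implementation, and no masking or flash attention):

# Plain NumPy illustration of grouped-query attention.
import numpy as np

num_heads, num_kv_heads, head_size, seq = 32, 8, 128, 16
q_heads_per_group = num_heads // num_kv_heads  # 4 query heads share each K/V head

rng = np.random.default_rng(0)
q = rng.standard_normal((num_kv_heads, q_heads_per_group, seq, head_size))
k = rng.standard_normal((num_kv_heads, seq, head_size))
v = rng.standard_normal((num_kv_heads, seq, head_size))

# Each group of query heads attends against the single K/V head it shares (broadcast over kv_heads).
scores = np.einsum("gqsh,gth->gqst", q, k) / np.sqrt(head_size)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
attn = np.einsum("gqst,gth->gqsh", weights, v)

# Counterpart of attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads") above.
attn = attn.reshape(num_heads, seq, head_size)
print(attn.shape)  # (32, 16, 128)
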

@@ -574,7 +592,7 @@ def _rotate_half(x: NamedArray) -> NamedArray:
 
 
 def _apply_rotary_pos_emb(
-    q: NamedArray,  # [batch, position, heads, head_size]
+    q: NamedArray,  # [batch, position, kv_heads, q_heads_per_group, head_size]
     k: NamedArray,  # [batch, position, kv_heads, head_size]
     cos: NamedArray,  # [position, head_size]
     sin: NamedArray,  # [position, head_size]
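
The final hunk only updates the shape comment for the grouped query axis; cos and sin still broadcast over all head axes. A generic NumPy sketch of rotary embeddings with those shapes (base 10000 assumed; not the levanter implementation):

# Generic rotary-embedding sketch; shapes follow the comments in the hunk above.
import numpy as np

def rotate_half(x):
    # Swap the two halves of the head_size axis, negating the second half.
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

batch, position, kv_heads, q_heads_per_group, head_size = 2, 16, 4, 2, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
angles = np.outer(np.arange(position), inv_freq)         # [position, head_size // 2]
cos = np.cos(np.concatenate([angles, angles], axis=-1))  # [position, head_size]
sin = np.sin(np.concatenate([angles, angles], axis=-1))

q = np.zeros((batch, position, kv_heads, q_heads_per_group, head_size))
k = np.zeros((batch, position, kv_heads, head_size))

# cos/sin broadcast from [position, head_size]; singleton head axes keep the position axis aligned.
q_rot = q * cos[:, None, None, :] + rotate_half(q) * sin[:, None, None, :]
k_rot = k * cos[:, None, :] + rotate_half(k) * sin[:, None, :]
print(q_rot.shape, k_rot.shape)  # (2, 16, 4, 2, 8) (2, 16, 4, 8)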