From b905c9fabc88951aa6aa2fa2eadf1e7249e3a35f Mon Sep 17 00:00:00 2001
From: Ivan Zhou
Date: Tue, 13 Feb 2024 04:43:03 +0000
Subject: [PATCH] revert change on llama

---
 src/levanter/models/llama.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index 1e05c2735..1f6afbe38 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -47,6 +47,9 @@ class LlamaConfig(HFCompatConfig):
         intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008.
         num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32.
         num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32.
+        num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer.
+            Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA.
+            Note that num_heads must be divisible by this number. Defaults to 32.
         activation_function (str, optional): activation function for the hidden layer. Defaults to "silu".
         rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding.
     """
@@ -56,6 +59,7 @@ class LlamaConfig(HFCompatConfig):
     intermediate_dim: int = 11008
     num_layers: int = 32
     num_heads: int = 32
+    num_kv_heads: int = 32
     activation_function: str = "silu"
     initializer_range: float = 0.02
     layer_norm_epsilon: float = 1e-5
@@ -76,10 +80,16 @@ class LlamaConfig(HFCompatConfig):
     KeyPos = property(lambda self: self.Pos.alias("key_position"))
     Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim))
     Heads = property(lambda self: Axis(name="heads", size=self.num_heads))
+    KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads))
     Layers = property(lambda self: Axis(name="layers", size=self.num_layers))
     Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
     HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))
 
+    def __post_init__(self):
+        assert (
+            self.num_heads % self.num_kv_heads == 0
+        ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."
+
     @cached_classproperty
     def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]:  # type: ignore
         return HFCheckpointConverter(
@@ -98,6 +108,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             intermediate_dim=hf_config.intermediate_size,
             num_layers=hf_config.num_hidden_layers,
             num_heads=hf_config.num_attention_heads,
+            num_kv_heads=hf_config.num_key_value_heads,
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
@@ -123,6 +134,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             intermediate_size=self.intermediate_dim,
             num_hidden_layers=self.num_layers,
             num_attention_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
@@ -264,10 +276,14 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module):
     def init(config: LlamaConfig, *, key) -> "LlamaAttention":
         use_bias = config.use_bias
         Embed = config.Embed
+        QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)
+
         k_q, k_k, k_v, k_o = jrandom.split(key, 4)
-        q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias)
-        k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias)
-        v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias)
+        q_proj = hnn.Linear.init(
+            In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias
+        )
+        k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias)
+        v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias)
         o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias)
         rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos)
         return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb)
@@ -277,9 +293,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
         key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)
 
         # reorder heads and position for better training throughput
-        q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size"))
-        k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size"))
-        v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size"))
+        q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
+        k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
+        v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))
 
         cos, sin = self.rotary_emb(seq_len=x.axis_size("position"))
 
@@ -305,6 +321,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
                 flash_block_size=c.flash_attention_block_size,
             )
 
+        attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
+
         if self.config.upcast_attn:
             attn_output = attn_output.astype(x.dtype)
 
@@ -574,7 +592,7 @@ def _rotate_half(x: NamedArray) -> NamedArray:
 
 
 def _apply_rotary_pos_emb(
-    q: NamedArray,  # [batch, position, heads, head_size]
+    q: NamedArray,  # [batch, position, kv_heads, q_heads_per_group, head_size]
     k: NamedArray,  # [batch, position, kv_heads, head_size]
     cos: NamedArray,  # [position, head_size]
     sin: NamedArray,  # [position, head_size]
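
For reference, the head layout this patch introduces can be illustrated with a minimal plain-JAX sketch of grouped-query attention. This is hypothetical standalone code, not part of the diff and not using Levanter's Haliax named-array API: queries carry (kv_heads, q_heads_per_group) axes, keys and values carry only kv_heads, each kv head serves num_heads // num_kv_heads query heads, and the two query-head axes are flattened back into a single "heads" axis at the end, mirroring the flatten_axes call in the patch. Masking and dropout are omitted for brevity.

    import jax
    import jax.numpy as jnp

    def gqa_attention(q, k, v):
        # q: [batch, kv_heads, q_heads_per_group, position, head_size]
        # k, v: [batch, kv_heads, key_position, head_size]
        scores = jnp.einsum("bghqd,bgkd->bghqk", q, k) / jnp.sqrt(q.shape[-1])
        weights = jax.nn.softmax(scores, axis=-1)
        # each kv head is shared by all query heads in its group
        return jnp.einsum("bghqk,bgkd->bghqd", weights, v)

    batch, num_heads, num_kv_heads, seq_len, head_size = 2, 8, 4, 16, 32
    group = num_heads // num_kv_heads  # q_heads_per_group; requires divisibility
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (batch, num_kv_heads, group, seq_len, head_size))
    k = jax.random.normal(kk, (batch, num_kv_heads, seq_len, head_size))
    v = jax.random.normal(kv, (batch, num_kv_heads, seq_len, head_size))

    out = gqa_attention(q, k, v)
    # merge (kv_heads, q_heads_per_group) back into one "heads" axis,
    # analogous to flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
    out = out.reshape(batch, num_kv_heads * group, seq_len, head_size)
    print(out.shape)  # (2, 8, 16, 32)

With num_kv_heads == num_heads this reduces to standard multi-head attention, and with num_kv_heads == 1 it reduces to multi-query attention, matching the docstring added in the first hunk.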