Adds Short-Context Qwen Support
Helw150 committed Nov 20, 2024
1 parent 02961b2 commit 6bc87c3
Showing 2 changed files with 435 additions and 0 deletions.
318 changes: 318 additions & 0 deletions src/levanter/models/qwen.py
@@ -0,0 +1,318 @@
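# Qwen (Qwen2) support for Levanter, built by subclassing the existing Llama modules.
# The overrides cover attention (bias on the Q/K/V projections, sliding-window options in the
# config) and the Hugging Face config/checkpoint mapping; the MLP, norms, and embeddings are
# reused from the Llama implementation.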
from dataclasses import dataclass
from typing import Dict, Optional, Type

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom

import haliax as hax
import haliax.nn as hnn
from haliax import Axis, NamedArray
from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split
from haliax.nn.scan import Stacked

from levanter.compat.hf_checkpoints import HFCheckpointConverter
from levanter.compat.torch_serialization import (
StateDict,
StateDictSerializationMixin,
apply_prefix,
flatten_linear_layers,
unflatten_linear_layers,
)
from levanter.logging import silence_transformer_nag
from levanter.models.attention import AttentionMask, dot_product_attention
from levanter.models.llama import (
LlamaConfig,
LlamaEmbedding,
LlamaLMHeadModel,
LlamaMlp,
LlamaRMSNorm,
LlamaTransformer,
)
from levanter.models.lm_model import LmConfig
from levanter.models.rotary import RotaryEmbeddingsConfig
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token


silence_transformer_nag()
from transformers import PretrainedConfig as HfConfig # noqa: E402
from transformers import Qwen2Config as HfQwenConfig # noqa: E402


@LmConfig.register_subclass("qwen")
@dataclass(frozen=True)
class QwenConfig(LlamaConfig):
"""Extends LlamaConfig with Qwen specific features"""

use_sliding_window: bool = False
sliding_window: Optional[int] = None
    max_window_layers: int = 0  # Sliding window applies only to layers with index >= max_window_layers
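    # NOTE: sliding-window attention is carried through the config only; QwenAttention.__call__
    # below raises if it would actually be needed, since this commit adds short-context support only.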

def __post_init__(self):
assert (
self.num_heads % self.num_kv_heads == 0
), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."

def hf_checkpoint_converter(self) -> HFCheckpointConverter["QwenConfig"]: # type: ignore
return HFCheckpointConverter(
self.__class__,
reference_checkpoint=self.reference_checkpoint,
trust_remote_code=True,
tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint,
HfConfigClass=HfQwenConfig,
)

@classmethod
def from_hf_config(cls, hf_config: HfConfig):
rope_theta = hf_config.rope_theta
rope_config = RotaryEmbeddingsConfig.from_hf_config(rope_theta, hf_config.rope_scaling)
return QwenConfig(
seq_len=hf_config.max_position_embeddings,
hidden_dim=hf_config.hidden_size,
intermediate_dim=hf_config.intermediate_size,
num_layers=hf_config.num_hidden_layers,
num_heads=hf_config.num_attention_heads,
num_kv_heads=hf_config.num_key_value_heads,
use_sliding_window=getattr(hf_config, "use_sliding_window", False),
sliding_window=getattr(hf_config, "sliding_window", None),
max_window_layers=getattr(hf_config, "max_window_layers", 0),
activation_function=hf_config.hidden_act,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
tie_word_embeddings=hf_config.tie_word_embeddings,
rope=rope_config,
)

def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfQwenConfig:
if config_overrides is None:
config_overrides = {}

rope_theta, rope_scaling = self.rope.to_hf_config()

return HfQwenConfig(
max_position_embeddings=self.seq_len,
hidden_size=self.hidden_dim,
intermediate_size=self.intermediate_dim,
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
use_sliding_window=self.use_sliding_window,
sliding_window=self.sliding_window,
max_window_layers=self.max_window_layers,
hidden_act=self.activation_function,
initializer_range=self.initializer_range,
rms_norm_eps=self.layer_norm_epsilon,
tie_word_embeddings=self.tie_word_embeddings,
vocab_size=vocab_size,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
**config_overrides,
)

@property
def model_type(self) -> Type["QwenLMHeadModel"]:
return QwenLMHeadModel

def flops_per_token(self, vocab_size: int):
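        # glu=True because Qwen uses a gated MLP (the reused LlamaMlp), so the extra gate
        # projection is included in the per-token FLOP estimate.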
return lm_flops_per_token(
hidden_dim=self.hidden_dim,
intermediate_dim=self.intermediate_dim,
num_layers=self.num_layers,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
seq_len=self.seq_len,
vocab_size=vocab_size,
glu=True,
)


# Attention for Qwen: same structure as LlamaAttention, but with bias on the Q/K/V projections (and none on o_proj)
class QwenAttention(eqx.Module, StateDictSerializationMixin):
config: QwenConfig = eqx.static_field()
q_proj: hnn.Linear
k_proj: hnn.Linear
v_proj: hnn.Linear
o_proj: hnn.Linear

@staticmethod
def init(config: QwenConfig, *, key) -> "QwenAttention":
Embed = config.Embed
QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)
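        # Grouped-query attention: query heads are laid out as (kv_heads, q_heads_per_group, head_size)
        # so each group of query heads shares one key/value head; the two axes are flattened back into
        # a single "heads" axis after attention.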

k_q, k_k, k_v, k_o = jrandom.split(key, 4)
q_proj = hnn.Linear.init(
In=Embed,
Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize),
key=k_q,
use_bias=True, # Qwen always uses bias in attention
out_first=True,
)
k_proj = hnn.Linear.init(
In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=True, out_first=True
)
v_proj = hnn.Linear.init(
In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=True, out_first=True
)
o_proj = hnn.Linear.init(
In=(config.Heads, config.HeadSize),
Out=Embed,
key=k_o,
use_bias=False, # Qwen doesn't use bias in o_proj
out_first=True,
)
return QwenAttention(config, q_proj, k_proj, v_proj, o_proj)

@named_call
def __call__(
self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], layer_idx: int = 0, *, key=None
) -> NamedArray:
key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)

# QKV projections
q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))

# Apply rotary embeddings
rot_embs = self.config.rope.build(self.config.HeadSize, q.resolve_axis("position"))
q, k = rot_embs(self.config.HeadSize, q, k)

k = k.rename({"position": "key_position"})
v = v.rename({"position": "key_position"})

        # Sliding-window attention would kick in here (when enabled and layer_idx >= max_window_layers),
        # but it is not yet supported, so raise rather than silently fall back to full attention.
if (
self.config.use_sliding_window
and self.config.sliding_window is not None
and layer_idx >= self.config.max_window_layers
):
raise ValueError("Sliding Window Attention is not currently supported.")

# Perform attention
attn_output = dot_product_attention(
"position",
"key_position",
"head_size",
q,
k,
v,
mask,
attention_dtype=jnp.float32 if self.config.upcast_attn else x.dtype,
use_flash=self.config.use_flash_attention,
attn_backend=self.config.attn_backend,
flash_block_size=self.config.flash_attention_block_size,
)

attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
attn_output = attn_output.astype(x.dtype)

attn_output = self.o_proj(attn_output, key=key_o)
return attn_output

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
        # unflatten the linear layers of the HF state_dict to match the shape of QwenAttention
d = {}
d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True))
d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True))
d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True))
d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True))

return super().from_state_dict(d, prefix)

def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
        # flatten the linear layers of QwenAttention to match the shape of the HF state_dict
my_dict: StateDict = {}
super().update_state_dict(my_dict, prefix)

my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, True))
my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, True))
my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, True))
my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, True))

state_dict.update(my_dict)
return state_dict


# Decoder layer for Qwen: QwenAttention wired into the same pre-norm residual structure as Llama
class QwenDecoderLayer(eqx.Module):
config: QwenConfig = eqx.static_field()
self_attn: QwenAttention
mlp: LlamaMlp # Can reuse Llama MLP as structure is similar
input_layernorm: LlamaRMSNorm
post_attention_layernorm: LlamaRMSNorm

@staticmethod
def init(config: QwenConfig, *, key) -> "QwenDecoderLayer":
k_attn, k_mlp = jrandom.split(key, 2)

attn = QwenAttention.init(config, key=k_attn)
mlp = LlamaMlp.init(
config.Embed,
config.Mlp,
config.activation_function,
key=k_mlp,
use_bias=config.use_bias,
)
ln_1 = config.mk_LayerNorm(config.Embed)
ln_2 = config.mk_LayerNorm(config.Embed)

return QwenDecoderLayer(config, attn, mlp, ln_1, ln_2)

@named_call
def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray:
k_attn, k_mlp = maybe_rng_split(key, 2)

residual = x
x = self.input_layernorm(x)
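        # NOTE: layer_idx is not threaded through the stacked layers, so QwenAttention sees its
        # default of 0 here.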
attn_output = self.self_attn(x=x, mask=mask, key=k_attn)
x = residual + attn_output

residual = x
x = self.post_attention_layernorm(x)
mlp_output = self.mlp(x, key=k_mlp)
output = residual + mlp_output
return output


# Transformer stack for Qwen: identical to LlamaTransformer except that it stacks QwenDecoderLayer
class QwenTransformer(LlamaTransformer):
config: QwenConfig = eqx.static_field()
layers: BlockFoldable[QwenDecoderLayer]
norm: LlamaRMSNorm

@staticmethod
def init(config: QwenConfig, *, key) -> "QwenTransformer":
S = Stacked
if not config.scan_layers:
from haliax.nn.scan import BlockSeq

S = BlockSeq

        # Initialize the stack of decoder layers (Stacked, or BlockSeq when scan_layers is disabled)
layers = S.init(config.Layers, QwenDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)(
config,
key=shaped_rng_split(key, config.num_layers),
)

ln_f = config.mk_LayerNorm(config.Embed)
return QwenTransformer(config, layers, ln_f)


# LM head model for Qwen: QwenTransformer underneath the reused Llama embedding and LM-head logic
class QwenLMHeadModel(LlamaLMHeadModel):
transformer: QwenTransformer
embeddings: LlamaEmbedding # Can reuse Llama embeddings
lm_head: Optional[hnn.Linear]

@classmethod
def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
k_t, k_emb = jrandom.split(key, 2)
transformer = QwenTransformer.init(config, key=k_t)
embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
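        # With tied word embeddings no separate lm_head is allocated; the parent LlamaLMHeadModel
        # is then expected to reuse the embedding matrix for unembedding.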
if config.tie_word_embeddings:
lm_head = None
else:
lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)

return QwenLMHeadModel(transformer, embeddings, lm_head)
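
As a usage note (not part of the diff above): a minimal sketch of constructing a QwenConfig by hand and initializing the model, using only the constructors shown in this file. The dimensions and the vocabulary axis below are toy values for illustration, not a real Qwen checkpoint's.

import jax.random as jrandom

from haliax import Axis

from levanter.models.qwen import QwenConfig, QwenLMHeadModel

# Toy dimensions; num_heads must be divisible by num_kv_heads (enforced in __post_init__).
config = QwenConfig(
    seq_len=512,
    hidden_dim=128,
    intermediate_dim=256,
    num_layers=2,
    num_heads=4,
    num_kv_heads=2,
)
Vocab = Axis("vocab", 1000)
model = QwenLMHeadModel.init(Vocab, config, key=jrandom.PRNGKey(0))

The LmConfig.register_subclass("qwen") decorator is what makes the model selectable by the name "qwen" in Levanter training configs.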