Merge remote-tracking branch 'origin/dev' into doremi
dlwh committed Jan 23, 2024
2 parents efbdd31 + 6148381 commit 3e3c9da
Showing 11 changed files with 570 additions and 21 deletions.
1 change: 1 addition & 0 deletions config/llama2_nano.yaml
@@ -9,6 +9,7 @@ model:
type: llama
hidden_dim: 32
num_heads: 4
num_kv_heads: 4
num_layers: 2
trainer:
tracker:
Binary file added docs/figures/finetune_func_cm_full_weight.png
Binary file added docs/figures/finetune_func_cm_lora.png
427 changes: 427 additions & 0 deletions docs/tutorials/Fine-Tuning-Semantic-Parsing.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mkdocs.yml
@@ -94,6 +94,7 @@ nav:
- 'Tutorials':
- "Fine-Tuning.md"
- "LoRA.md"
- "tutorials/Fine-Tuning-Semantic-Parsing.md"
- "Hardware-Agnostic-Training.md"
- 'Developer Guide':
- 'dev/Port-Models.md'
2 changes: 0 additions & 2 deletions src/levanter/lora.py
@@ -500,12 +500,10 @@ def to_hf_config(config: LoraConfig, base_model_name_or_path: Optional[str] = No
return {
"base_model_name_or_path": base_model_name_or_path,
"bias": "none", # TODO: support bias
"enable_lora": None,
"fan_in_fan_out": False, # TODO: support fan_in_fan_out
"inference_mode": True, # TODO: support inference_mode
"lora_alpha": config.alpha,
"lora_dropout": 0.00, # TODO: support dropout
"merge_weights": False,
"modules_to_save": None, # TODO: support modules_to_save?
"peft_type": "LORA",
"r": config.r,
32 changes: 25 additions & 7 deletions src/levanter/models/llama.py
@@ -47,6 +47,9 @@ class LlamaConfig(HFCompatConfig):
intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008.
num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32.
num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32.
num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer.
Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA.
Note that num_heads must be divisible by this number. Defaults to 32.
activation_function (str, optional): activation function for the hidden layer. Defaults to "silu".
rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding.
"""
@@ -56,6 +59,7 @@ class LlamaConfig(HFCompatConfig):
intermediate_dim: int = 11008
num_layers: int = 32
num_heads: int = 32
num_kv_heads: int = 32
activation_function: str = "silu"
initializer_range: float = 0.02
layer_norm_epsilon: float = 1e-5
@@ -76,10 +80,16 @@ class LlamaConfig(HFCompatConfig):
KeyPos = property(lambda self: self.Pos.alias("key_position"))
Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim))
Heads = property(lambda self: Axis(name="heads", size=self.num_heads))
KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads))
Layers = property(lambda self: Axis(name="layers", size=self.num_layers))
Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))

def __post_init__(self):
assert (
self.num_heads % self.num_kv_heads == 0
), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."

@cached_classproperty
def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore
return HFCheckpointConverter(
@@ -98,6 +108,7 @@ def from_hf_config(cls, hf_config: HfConfig):
intermediate_dim=hf_config.intermediate_size,
num_layers=hf_config.num_hidden_layers,
num_heads=hf_config.num_attention_heads,
num_kv_heads=hf_config.num_key_value_heads,
activation_function=hf_config.hidden_act,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
@@ -123,6 +134,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
intermediate_size=self.intermediate_dim,
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
hidden_act=self.activation_function,
initializer_range=self.initializer_range,
rms_norm_eps=self.layer_norm_epsilon,
@@ -264,10 +276,14 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module):
def init(config: LlamaConfig, *, key) -> "LlamaAttention":
use_bias = config.use_bias
Embed = config.Embed
QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)

k_q, k_k, k_v, k_o = jrandom.split(key, 4)
q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias)
k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias)
v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias)
q_proj = hnn.Linear.init(
In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias
)
k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias)
v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias)
o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias)
rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos)
return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb)
@@ -277,9 +293,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)

# reorder heads and position for better training throughput
q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size"))
q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))

cos, sin = self.rotary_emb(seq_len=x.axis_size("position"))

@@ -305,6 +321,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
flash_block_size=c.flash_attention_block_size,
)

attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")

if self.config.upcast_attn:
attn_output = attn_output.astype(x.dtype)

@@ -574,7 +592,7 @@ def _rotate_half(x: NamedArray) -> NamedArray:


def _apply_rotary_pos_emb(
q: NamedArray, # [batch, position, heads, head_size]
q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size]
k: NamedArray, # [batch, position, kv_heads, head_size]
cos: NamedArray, # [position, head_size]
sin: NamedArray, # [position, head_size]
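
For readers skimming the attention changes above, the following is a minimal, shape-level sketch of grouped-query attention in plain JAX, not levanter/haliax code; all sizes are illustrative and the softmax scaling is the usual 1/sqrt(head_size):

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# Illustrative sizes: 4 query heads grouped over 2 KV heads (GQA).
num_heads, num_kv_heads, head_size, seq_len = 4, 2, 8, 16
q_heads_per_group = num_heads // num_kv_heads  # must divide evenly, as the new __post_init__ asserts

kq, kk, kv = jrandom.split(jrandom.PRNGKey(0), 3)
# q keeps every head but is laid out as (kv_heads, q_heads_per_group, ...);
# k and v have only num_kv_heads heads, shared by all queries in a group.
q = jrandom.normal(kq, (num_kv_heads, q_heads_per_group, seq_len, head_size))
k = jrandom.normal(kk, (num_kv_heads, seq_len, head_size))
v = jrandom.normal(kv, (num_kv_heads, seq_len, head_size))

# Broadcasting k and v over the q_heads_per_group axis is what implements the sharing.
scores = jnp.einsum("gqnh,gmh->gqnm", q, k) / jnp.sqrt(head_size)
weights = jax.nn.softmax(scores, axis=-1)
out = jnp.einsum("gqnm,gmh->gqnh", weights, v)

# Flatten (kv_heads, q_heads_per_group) back into "heads", mirroring
# attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads") above.
out = out.reshape(num_heads, seq_len, head_size)
print(out.shape)  # (4, 16, 8)
```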
33 changes: 31 additions & 2 deletions src/levanter/optim/config.py
@@ -1,12 +1,17 @@
import abc
import re
import warnings
from dataclasses import dataclass
from typing import Optional

import draccus
import equinox as eqx
import jax
import optax
from jax import numpy as jnp

from levanter.utils.jax_utils import leaf_key_paths


@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
@@ -20,6 +25,9 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
cooldown: float = 0.0
"""fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
lr_schedule: str = "cosine" # constant, cosine, linear
weight_decay_modules: Optional[list[str] | str] = None
"""A regex or a list of strings to identify where to mask weight.
For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`"""

@classmethod
def default_choice_name(cls) -> Optional[str]:
@@ -29,6 +37,28 @@ def default_choice_name(cls) -> Optional[str]:
def build(self, num_train_steps: int):
raise NotImplementedError

def build_weight_decay_mask(self):
if self.weight_decay_modules is None:
return None
else:
# mask based on regex or module path
def _apply_on(x, key_path):
if isinstance(self.weight_decay_modules, str):
compiled_regex = re.compile(self.weight_decay_modules)
return compiled_regex.match(key_path) is not None
else:
return any(key_path.__contains__(target) for target in self.weight_decay_modules)

def mask_fn(model):
return jax.tree_util.tree_map(
_apply_on,
model,
leaf_key_paths(model, is_leaf=eqx.is_array),
is_leaf=eqx.is_array,
)

return mask_fn

def lr_scheduler(self, num_train_steps):
warmup_steps = self._convert_warmup(num_train_steps)
cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
@@ -118,8 +148,7 @@ def _optimizer(learning_rate):
components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon))

if self.weight_decay > 0:
# TODO: add weight decay masking??
components.append(optax.add_decayed_weights(self.weight_decay))
components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))

# - learning rate for descent
components.append(optax.scale(-learning_rate))
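
As a standalone sketch of the masking mechanism added above, here is how a regex-derived boolean mask can feed `optax.add_decayed_weights`; this is not levanter code, the parameter tree and pattern are made up, and it assumes a JAX version that provides `jax.tree_util.tree_map_with_path`:

```python
import re

import jax
import jax.numpy as jnp
import optax

# A toy parameter tree; only "mlp/weight" should be decayed.
params = {
    "mlp": {"weight": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "layer_norm": {"scale": jnp.ones((4,))},
}

decay_pattern = re.compile(r".*mlp.*weight")

def path_to_str(path) -> str:
    # Turn a key path like (DictKey('mlp'), DictKey('weight')) into "mlp.weight".
    return ".".join(str(getattr(p, "key", p)) for p in path)

# Boolean mask with the same structure as params: True where decay applies.
mask = jax.tree_util.tree_map_with_path(
    lambda path, _: decay_pattern.match(path_to_str(path)) is not None, params
)

tx = optax.chain(
    optax.add_decayed_weights(1e-2, mask=mask),
    optax.sgd(1e-3),
)
state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.zeros_like, params)
updates, state = tx.update(grads, state, params)
# Only mlp/weight picks up a decay term; the bias and layer-norm scale stay at zero.
print(jax.tree_util.tree_map(lambda u: float(jnp.abs(u).sum()), updates))
```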
2 changes: 1 addition & 1 deletion src/levanter/optim/sophia.py
@@ -135,7 +135,7 @@ def _optimizer(learning_rate, gamma) -> SecondOrderTransformation:
# Algorithm 3, step 11 (Note, this comes after clipping b/c it's not supposed to be clipped)
# In the paper, it comes as a prior step, but doesn't get clipped
if self.weight_decay > 0:
components.append(optax.add_decayed_weights(self.weight_decay))
components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))

# - learning rate for descent
components.append(optax.scale(-learning_rate))
26 changes: 17 additions & 9 deletions tests/test_llama.py
@@ -131,11 +131,12 @@ def named_array_to_tensor(named_array):

@skip_if_no_torch
@pytest.mark.parametrize("use_flash", [True, False])
def test_llama_attention(use_flash):
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
def test_llama_attention(use_flash, num_kv_heads):
import torch
from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention

config = _get_llama_config(use_flash=use_flash)
config = _get_llama_config(use_flash=use_flash, num_kv_heads=num_kv_heads)

attention = LlamaAttention.init(config=config, key=random.PRNGKey(0))

@@ -181,11 +182,12 @@ def test_llama_rms_norm():


@skip_if_no_torch
def test_llama_decoder_layer():
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
def test_llama_decoder_layer(num_kv_heads):
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer as HFLlamaDecoderLayer

llama_config = _get_llama_config()
llama_config = _get_llama_config(num_kv_heads=num_kv_heads)
key = random.PRNGKey(0)
llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key)

@@ -208,8 +210,9 @@ def test_llama_decoder_layer():
).all(), f"{hf_out[0]} != {out}"


def test_llama_lm_head_model():
llama_config = _get_llama_config()
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
def test_llama_lm_head_model(num_kv_heads):
llama_config = _get_llama_config(num_kv_heads=num_kv_heads)
Batch = hax.Axis("batch", 2)
Vocab = hax.Axis("vocab", 1000)
Pos = llama_config.Pos
@@ -222,7 +225,8 @@ def test_llama_lm_head_model():


@skip_if_no_torch
def test_llama_roundtrip():
@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
def test_llama_roundtrip(num_kv_heads):
import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM

@@ -232,6 +236,7 @@ def test_llama_roundtrip():
seq_len=128,
hidden_dim=16,
num_heads=4,
num_kv_heads=num_kv_heads,
gradient_checkpointing=False,
)
Vocab = hax.Axis("vocab", 1000)
@@ -279,7 +284,7 @@ def compute(input):
assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}"


def _get_llama_config(use_flash=False) -> LlamaConfig:
def _get_llama_config(use_flash=False, num_kv_heads=4) -> LlamaConfig:
rope_scaling = {
"type": "linear",
"factor": 2.0,
@@ -288,6 +293,7 @@ def _get_llama_config(use_flash=False) -> LlamaConfig:
seq_len=128,
hidden_dim=16,
num_heads=4,
num_kv_heads=num_kv_heads,
rope_scaling=rope_scaling,
gradient_checkpointing=False, # disable for tests so debugging is easier
use_flash_attention=use_flash,
@@ -312,11 +318,13 @@ def test_llama_configs(config_file):
check_load_config(config_class, config_file)


def test_pass_different_length_seq():
@pytest.mark.parametrize("num_kv_heads", [1, 2])
def test_pass_different_length_seq(num_kv_heads):
config = LlamaConfig(
seq_len=32,
hidden_dim=16,
intermediate_dim=32,
num_heads=2,
num_kv_heads=num_kv_heads,
)
check_model_works_with_seqlen(LlamaLMHeadModel, config, 16)
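
A small usage sketch of the new `num_kv_heads` field, assuming levanter is installed and using only names that appear in the diff (the concrete sizes are arbitrary):

```python
from levanter.models.llama import LlamaConfig

# 4 query heads grouped over 2 KV heads: GQA with two queries per KV head.
config = LlamaConfig(seq_len=32, hidden_dim=16, intermediate_dim=32, num_heads=4, num_kv_heads=2)
assert config.num_heads % config.num_kv_heads == 0  # enforced by the new __post_init__
print(config.Heads, config.KVHeads, config.HeadSize)  # axis sizes: heads=4, kv_heads=2, head_size=4
```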
67 changes: 67 additions & 0 deletions tests/test_weight_decay_mask.py
@@ -0,0 +1,67 @@
import equinox as eqx
import jax
import jax.random as jrandom

import haliax as hax

from levanter.models.gpt2 import Gpt2Config
from levanter.optim import AdamConfig


def test_weight_decay_masking():
def tree_at_mask(params):
# let's mask all leaves as False
params = jax.tree_util.tree_map(lambda _: False, params)

def apply_weight_decay(tree):
# there is no weight decay performed in LayerNorms and bias
nodes = []

# apply on embedding
nodes.append(tree.embeddings.token_embeddings.array)
nodes.append(tree.embeddings.position_embeddings.array)

# apply on attention
nodes.append(tree.transformer.blocks.stacked.attn.c_attn.weight.array)
nodes.append(tree.transformer.blocks.stacked.attn.c_proj.weight.array)

# apply on MLP
nodes.append(tree.transformer.blocks.stacked.mlp.c_fc.weight.array)
nodes.append(tree.transformer.blocks.stacked.mlp.c_proj.weight.array)

return nodes

# apply weight decay when necessary
params = eqx.tree_at(
where=apply_weight_decay,
pytree=params,
replace_fn=lambda _: True,
)

return params

gpt_config = Gpt2Config()
Vocab = hax.Axis("vocab", 100)
model = gpt_config.build(Vocab, key=jrandom.PRNGKey(0))
string_list_config = AdamConfig(
weight_decay_modules=[
"attn.c_attn.weight",
"attn.c_proj.weight",
"mlp.c_fc.weight",
"mlp.c_proj.weight",
"token_embeddings",
"position_embeddings",
]
)
regex_config = AdamConfig(
weight_decay_modules=r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings",
)
# masking using `equinox.tree_at`
true_mask = tree_at_mask(model)
# masking using list of module path
list_string_mask = string_list_config.build_weight_decay_mask()(model)

regex_mask = regex_config.build_weight_decay_mask()(model)

assert eqx.tree_equal(list_string_mask, true_mask)
assert eqx.tree_equal(regex_mask, true_mask)
