From ddba7fd875d27e74141f9d3d47b0f55266e6e388 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 15:51:05 -0800 Subject: [PATCH] fix merge --- config/gpt2_large_sophia_g.yaml | 21 --------------------- config/gpt2_small_fast_sophia_g.yaml | 24 ------------------------ src/levanter/doremi.py | 2 +- 3 files changed, 1 insertion(+), 46 deletions(-) delete mode 100644 config/gpt2_large_sophia_g.yaml delete mode 100644 config/gpt2_small_fast_sophia_g.yaml diff --git a/config/gpt2_large_sophia_g.yaml b/config/gpt2_large_sophia_g.yaml deleted file mode 100644 index 53a1d0806..000000000 --- a/config/gpt2_large_sophia_g.yaml +++ /dev/null @@ -1,21 +0,0 @@ -data: !include data/openwebtext_source.yaml -model: - type: gpt2 - hidden_dim: 1280 - num_heads: 20 - num_layers: 36 - seq_len: 1024 - gradient_checkpointing: true - scale_attn_by_inverse_layer_idx: true -trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "sophia-g"] - - num_train_steps: 200000 - mp: p=f32,c=bfloat16 - -optimizer: - type: sophia-g - learning_rate: 2E-4 - weight_decay: 0.15 diff --git a/config/gpt2_small_fast_sophia_g.yaml b/config/gpt2_small_fast_sophia_g.yaml deleted file mode 100644 index 0f86ac503..000000000 --- a/config/gpt2_small_fast_sophia_g.yaml +++ /dev/null @@ -1,24 +0,0 @@ -data: !include data/openwebtext_source.yaml -model: - type: gpt2 - hidden_dim: 768 - num_heads: 12 - num_layers: 12 - seq_len: 1024 - gradient_checkpointing: true - scale_attn_by_inverse_layer_idx: true -trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "itest", "sophia-g"] - - mp: p=f32,c=bfloat16 - model_axis_size: 1 - per_device_parallelism: 8 - - train_batch_size: 256 - num_train_steps: 20000 -optimizer: - type: sophia-g - learning_rate: 1E-3 - weight_decay: 0.15 diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 602dae4db..d9bcd5170 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -185,7 +185,7 @@ def doremi_step(state: DoremiState, ref, batch, domains): "train/mean_proxy_loss": mean_proxy_loss, **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, # just skip domains with no excess loss - **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items() if loss > 0}, + **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items()}, }, step=state._step, )