diff --git a/config/gpt2_1536_sophiah.yaml b/config/gpt2_1536_sophiah.yaml new file mode 100644 index 000000000..0d1008106 --- /dev/null +++ b/config/gpt2_1536_sophiah.yaml @@ -0,0 +1,32 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext/" + tokenizer: "gpt2" +model: + type: gpt2 + hidden_dim: 1536 + num_heads: 24 + num_layers: 48 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "openwebtext", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 2 + per_device_eval_parallelism: 8 +optimizer: + type: sophia-h + learning_rate: 2E-4 + weight_decay: 0.2 + min_lr_ratio: 0.1 + gamma: 0.01 + # sophia needs a minimum amount of warmup or it doesn't do well + warmup: 2000 diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 41f5d04ab..f85336df4 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -493,13 +493,14 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: """ # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us with hax.axis_mapping(self.parameter_axis_mapping): + opt_state = state.opt_state train_grads = _partition_trainable_params(grads, state.is_trainable)[0] trainable_model = _partition_trainable_params(model, state.is_trainable)[0] - updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) - partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) + updates, opt_state = self.optimizer.update( + train_grads, opt_state, params=trainable_model, obj_fn=partial_fn + ) model = eqx.apply_updates(model, updates) return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state)