From be294b476fd092e7c8b85f9a64e2cd9c6c6de4f6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 9 Feb 2024 23:55:47 -0800 Subject: [PATCH 1/5] Generic Tracker Interface (#459) * Trackers let us abstract out TB vs wandb --- README.md | 3 +- config/backpack.yaml | 2 +- config/gpt2_1536.yaml | 2 +- config/gpt2_20b.yaml | 2 +- config/gpt2_7b.yaml | 2 +- config/gpt2_large.yaml | 4 +- config/gpt2_medium.yaml | 2 +- config/gpt2_micro.yaml | 2 +- config/gpt2_nano.yaml | 3 +- config/gpt2_nano_tb.yaml | 25 +++ config/gpt2_small.yaml | 4 +- config/gpt2_small_fast.yaml | 7 +- config/gpt2_small_fast_mix.yaml | 2 +- config/gpt2_small_fast_pile.yaml | 2 +- config/gpt2_small_fast_wiki.yaml | 2 +- config/gpt2_small_sophiah.yaml | 2 +- config/gpt2_xl.yaml | 2 +- config/llama2_7b.yaml | 3 +- config/llama2_7b_continued.yaml | 3 +- config/llama2_nano.yaml | 2 +- config/lora/mpt_biomed.yaml | 3 +- config/mpt_7b_continued.yaml | 22 --- config/mpt_7b_continued_biomedlm.yaml | 27 --- docs/Configuration-Guide.md | 84 ++++++++- docs/Training-On-Your-Data.md | 3 +- docs/dev/Trackers.md | 104 +++++++++++ examples/alpaca-lora/alpaca_lora.py | 92 +++++----- examples/alpaca/alpaca.py | 2 +- examples/gsm8k-lora/gsm8k_lora.py | 93 +++++----- mkdocs.yml | 2 +- pyproject.toml | 2 +- src/levanter/__init__.py | 2 + src/levanter/callbacks.py | 61 ++++--- src/levanter/data/sharded_dataset.py | 5 +- src/levanter/logging.py | 237 ++------------------------ src/levanter/main/cache_dataset.py | 13 +- src/levanter/main/eval_lm.py | 2 +- src/levanter/main/lora_lm.py | 13 +- src/levanter/main/train_lm.py | 56 +++--- src/levanter/main/viz_logprobs.py | 6 +- src/levanter/tracker/__init__.py | 29 ++++ src/levanter/tracker/helpers.py | 75 ++++++++ src/levanter/tracker/tensorboard.py | 81 +++++++++ src/levanter/tracker/tracker.py | 117 +++++++++++++ src/levanter/tracker/tracker_fns.py | 235 +++++++++++++++++++++++++ src/levanter/tracker/wandb.py | 199 +++++++++++++++++++++ src/levanter/trainer.py | 120 ++++++++++--- tests/test_eval_lm.py | 2 +- tests/test_export_to_hf.py | 3 +- tests/test_logging.py | 4 +- tests/test_tracker.py | 80 +++++++++ tests/test_train_lm.py | 2 +- tests/test_viz_lm.py | 8 +- 53 files changed, 1368 insertions(+), 492 deletions(-) create mode 100644 config/gpt2_nano_tb.yaml delete mode 100644 config/mpt_7b_continued.yaml delete mode 100644 config/mpt_7b_continued_biomedlm.yaml create mode 100644 docs/dev/Trackers.md create mode 100644 src/levanter/tracker/__init__.py create mode 100644 src/levanter/tracker/helpers.py create mode 100644 src/levanter/tracker/tensorboard.py create mode 100644 src/levanter/tracker/tracker.py create mode 100644 src/levanter/tracker/tracker_fns.py create mode 100644 src/levanter/tracker/wandb.py create mode 100644 tests/test_tracker.py diff --git a/README.md b/README.md index 5a6b89cf6..13097d7dd 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/backpack.yaml b/config/backpack.yaml index 5b6cef3cb..493be77a3 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -10,7 +10,7 @@ model: num_senses: 16 sense_intermediate_scale: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "backpack" ] diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml index 50ccbd882..a3633bf65 100644 --- a/config/gpt2_1536.yaml +++ b/config/gpt2_1536.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_20b.yaml b/config/gpt2_20b.yaml index 76bf6ba96..6f5f40e1b 100644 --- a/config/gpt2_20b.yaml +++ b/config/gpt2_20b.yaml @@ -12,7 +12,7 @@ model: use_bias: false fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml index affb67aa5..36a3d4fd2 100644 --- a/config/gpt2_7b.yaml +++ b/config/gpt2_7b.yaml @@ -11,7 +11,7 @@ model: resid_pdrop: 0.0 fcm_prob: 0.15 trainer: - wandb: + tracker: project: "levanter" tags: ["pile", "gpt2"] diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml index 525a92c99..8a8aea8d7 100644 --- a/config/gpt2_large.yaml +++ b/config/gpt2_large.yaml @@ -8,13 +8,13 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 16 + per_device_parallelism: -1 optimizer: learning_rate: 2E-4 weight_decay: 0.1 diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml index 9ea4408bc..47e21799c 100644 --- a/config/gpt2_medium.yaml +++ b/config/gpt2_medium.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_micro.yaml b/config/gpt2_micro.yaml index 274ecddaa..0a8283e78 100644 --- a/config/gpt2_micro.yaml +++ b/config/gpt2_micro.yaml @@ -6,7 +6,7 @@ model: num_heads: 8 num_layers: 4 trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] diff --git a/config/gpt2_nano.yaml b/config/gpt2_nano.yaml index 993302670..1ad0ceb3b 100644 --- a/config/gpt2_nano.yaml +++ b/config/gpt2_nano.yaml @@ -14,8 +14,7 @@ trainer: - every: 50 save_interval: 5m - per_device_eval_parallelism: 1 - per_device_parallelism: 1 + per_device_parallelism: -1 train_batch_size: 32 tensor_parallel_axes: ["mlp", "heads"] diff --git a/config/gpt2_nano_tb.yaml b/config/gpt2_nano_tb.yaml new file mode 100644 index 000000000..f6847d693 --- /dev/null +++ b/config/gpt2_nano_tb.yaml @@ -0,0 +1,25 @@ +data: + id: dlwh/wikitext_103_detokenized +model: + type: gpt2 + hidden_dim: 32 + num_heads: 4 + num_layers: 2 +trainer: + mp: f32 + num_train_steps: 100 + + checkpointer: + keep: + - every: 50 + save_interval: 5m + + per_device_parallelism: -1 + train_batch_size: 32 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + tracker: + type: tensorboard + logdir: tb_logs/ diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml index 74d0e031a..b3e0295af 100644 --- a/config/gpt2_small.yaml +++ b/config/gpt2_small.yaml @@ -8,13 +8,13 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 model_axis_size: 1 - per_device_parallelism: 4 + per_device_parallelism: -1 train_batch_size: 512 optimizer: diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml index 4c8434f38..6242a37bc 100644 --- a/config/gpt2_small_fast.yaml +++ b/config/gpt2_small_fast.yaml @@ -8,9 +8,10 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: - project: "levanter" - tags: [ "openwebtext", "gpt2", "itest"] + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] mp: p=f32,c=bfloat16 model_axis_size: 1 diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml index 0785e9103..ca9fa2ca6 100644 --- a/config/gpt2_small_fast_mix.yaml +++ b/config/gpt2_small_fast_mix.yaml @@ -21,7 +21,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext+wiki", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index f30743c1d..a0336da45 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2", "itest"] diff --git a/config/gpt2_small_fast_wiki.yaml b/config/gpt2_small_fast_wiki.yaml index 407d8705b..a25736434 100644 --- a/config/gpt2_small_fast_wiki.yaml +++ b/config/gpt2_small_fast_wiki.yaml @@ -9,7 +9,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "itest"] diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml index 1dd5824c3..fd82ab226 100644 --- a/config/gpt2_small_sophiah.yaml +++ b/config/gpt2_small_sophiah.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2", "sophia-h"] diff --git a/config/gpt2_xl.yaml b/config/gpt2_xl.yaml index 8230b56a5..026fc077e 100644 --- a/config/gpt2_xl.yaml +++ b/config/gpt2_xl.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "openwebtext", "gpt2"] mp: p=f32,c=bfloat16 diff --git a/config/llama2_7b.yaml b/config/llama2_7b.yaml index 68931f3fa..b4ebe705f 100644 --- a/config/llama2_7b.yaml +++ b/config/llama2_7b.yaml @@ -11,7 +11,8 @@ model: # initialize_from_hf: "meta-llama/Llama-2-7b-hf" # use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["openwebtext", "llama"] diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml index e03be7168..edb72a7e4 100644 --- a/config/llama2_7b_continued.yaml +++ b/config/llama2_7b_continued.yaml @@ -6,7 +6,8 @@ model: initialize_from_hf: true use_hf_model_config: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["pile", "llama2"] diff --git a/config/llama2_nano.yaml b/config/llama2_nano.yaml index c3ae4cdb8..58415022e 100644 --- a/config/llama2_nano.yaml +++ b/config/llama2_nano.yaml @@ -12,7 +12,7 @@ model: num_kv_heads: 4 num_layers: 2 trainer: - wandb: + tracker: project: "levanter" tags: ["openwebtext", "llama"] mp: p=f32 diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml index f49267ca1..6b19d0ab5 100644 --- a/config/lora/mpt_biomed.yaml +++ b/config/lora/mpt_biomed.yaml @@ -11,7 +11,8 @@ lora: alpha: 32.0 target_modules: ["Wqkv"] trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: ["mpt", "lora", "pubmed"] diff --git a/config/mpt_7b_continued.yaml b/config/mpt_7b_continued.yaml deleted file mode 100644 index a7eaf800b..000000000 --- a/config/mpt_7b_continued.yaml +++ /dev/null @@ -1,22 +0,0 @@ -data: !include data/pile_source_old.yaml -model: - type: mpt -initialize_from_hf: true -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pile", "mpt"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 4 - per_device_eval_parallelism: 4 - - train_batch_size: 1024 - num_train_steps: 10000 - steps_per_eval: 500 -optimizer: - learning_rate: 1.2e-4 - weight_decay: 0.1 diff --git a/config/mpt_7b_continued_biomedlm.yaml b/config/mpt_7b_continued_biomedlm.yaml deleted file mode 100644 index 44961df46..000000000 --- a/config/mpt_7b_continued_biomedlm.yaml +++ /dev/null @@ -1,27 +0,0 @@ -data: - train_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_train.{1..128}-of-128.jsonl.gz" - validation_urls: - - "gs://pubmed-mosaic/pubmed-sharded/pubmedRandomized_val.{1..8}-of-8.jsonl.gz" - cache_dir: "gs://pubmed-mosaic/tokenized/pubmed-sharded-neox/" - tokenizer: "EleutherAI/gpt-neox-20b" -model: - type: mpt -initialize_from_hf: "mosaicml/mpt-7b@68e1a8e0ebb9b30f3c45c1ef6195980f29063ae2" -use_hf_model_config: true -trainer: - wandb: - project: "levanter" - tags: ["pubmed", "mpt", "continued"] - - mp: p=f32,c=bfloat16 - - model_axis_size: 1 - per_device_parallelism: 8 - - train_batch_size: 2048 - num_train_steps: 50000 - steps_per_eval: 1000 -optimizer: - learning_rate: 1.2e-5 - weight_decay: 0.1 diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index 607129e1a..bdb09e4f1 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -35,7 +35,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" tags: [ "openwebtext", "gpt2"] @@ -179,12 +180,34 @@ The default step-based checkpoint policy is to save a checkpoint every 10,000 st -## WandB +## Trackers and Logging -We mostly use wandb for logging, including using wandb for allocating the run id. We may change this. -These all live in a nested object `wandb` inside `trainer`. Most of these are the same as the corresponding `wandb.init` -parameters. +We mostly use [W&B](https://wandb.ai/site) for tracking values and other metadata about a run. However, we also support +Tensorboard and a few other trackers. You can also use multiple trackers at once, or even write your own. +See [Trackers](dev/Trackers.md) for more information. + +### W&B + +Wandb is the default tracker and is installed by default. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +Because wandb is the default, you can also just do: + +```yaml +trainer: + tracker: + project: my-project + entity: my-entity +``` + | Parameter | Description | Default | @@ -206,6 +229,35 @@ of your main script. To use it, you must also set the right environment variables. Something like `XLA_FLAGS="--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*`. We will automatically parse out the env variable. +### Tensorboard + +Tensorboard is also supported. To use it, you can configure it in your config file: + +```yaml +trainer: + tracker: + type: tensorboard + logdir: logs +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + ## Ray Config Levanter will by default automatically start a Ray cluster with all @@ -277,8 +329,26 @@ We won't go into detail here. You can see the auto-generated docs below. ::: levanter.checkpoint.Checkpointer -### Wandb -::: levanter.logging.WandbConfig +### Trackers and Metrics + +See also [Trackers](dev/Trackers.md) for more information. Basic configuration is shown below. + +#### Single Tracker + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + + + +::: levanter.tracker.wandb.WandbConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + ### Distributed and Ray diff --git a/docs/Training-On-Your-Data.md b/docs/Training-On-Your-Data.md index edf33e0af..4c543b04f 100644 --- a/docs/Training-On-Your-Data.md +++ b/docs/Training-On-Your-Data.md @@ -214,7 +214,8 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: + type: wandb project: "levanter" # TODO tags: ["gpt2"] diff --git a/docs/dev/Trackers.md b/docs/dev/Trackers.md new file mode 100644 index 000000000..1f1677d52 --- /dev/null +++ b/docs/dev/Trackers.md @@ -0,0 +1,104 @@ +# Trackers and Metrics + +Logging values and other metadata about a run is a core requirement for any ML framework. +Until recently, Levanter had a hard dependency on [W&B](https://wandb.ai/site) for tracking such values. + +In the latest version, we introduce the [levanter.tracker.Tracker][] interface, which allows you to use any tracking backend you want. +The interface name is taken from the [HuggingFace Accelerate](https://github.com/huggingface/accelerate/blob/0f2686c8d3e6d949c4b7efa15d7f2dee44f7ce91/src/accelerate/tracking.py#L395) +framework. + +Given Levanter's historical dependency on W&B, the interface is designed to look similar to W&B's API. +The methods currently exposed are: + +* [levanter.tracker.current_tracker][]: returns the current tracker instance or sets it. +* [levanter.tracker.log_metrics][]: logs a dictionary of metrics for a given step. +* [levanter.tracker.log_summary][]: logs a dictionary of "summary" information, analogous to W&B's version. +* [levanter.tracker.get_tracker][]: returns a tracker with the given name. +* [levanter.tracker.jit_log_metrics][]: a version of [levanter.tracker.log_metrics][] that works inside JAX jit. + +A basic example of using the tracker interface is shown below: + +```python +import wandb +from levanter.tracker import current_tracker, log_metrics, log_summary +from levanter.tracker.wandb import WandbTracker + +with current_tracker(WandbTracker(wandb.init())): + for step in range(100): + log_metrics({"loss": 100 -0.01 * step}, step=step) + + log_summary({"best_loss": 0.0}) +``` + +A more typical example would be to use it in a config file, as we do with Trainer: + +```yaml +trainer: + tracker: + type: wandb + project: my-project + entity: my-entity +``` + +### Multiple Trackers + +In some cases, you may want to use multiple trackers at once. +For example, you may want to use both W&B and Tensorboard. + +To do this, you can use the [levanter.tracker.tracker.CompositeTracker][] class, or, if using a config file, you +can specify multiple trackers: + +```yaml +trainer: + tracker: + - type: wandb + project: my-project + entity: my-entity + - type: tensorboard + logdir: logs +``` + +## Adding your own tracker + +To add your own tracker, you need to implement the [levanter.tracker.Tracker][] interface. +You will also want to register your config with TrackerConfig as a "choice" in the choice type. +Follow the pattern for Tensorboard and W&B. + +TODO: expand this section. + + +## API Reference + +### Core Functions + +::: levanter.tracker.current_tracker + +::: levanter.tracker.log_metrics + +::: levanter.tracker.log_summary + +::: levanter.tracker.get_tracker + +::: levanter.tracker.jit_log_metrics + +### Trackers + +::: levanter.tracker.Tracker + +::: levanter.tracker.tracker.CompositeTracker + +::: levanter.tracker.tracker.NoopTracker + +::: levanter.tracker.tensorboard.TensorboardTracker + +::: levanter.tracker.wandb.WandbTracker + +### Tracker Config + +::: levanter.tracker.TrackerConfig + +::: levanter.tracker.tracker.NoopConfig + +::: levanter.tracker.tensorboard.TensorboardConfig + +::: levanter.tracker.wandb.WandbConfig diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index a4380a92b..0e7c5790e 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -8,7 +8,6 @@ import jax.random as jrandom import transformers -import wandb import haliax as hax @@ -49,7 +48,7 @@ class TrainArgs(alpaca.TrainArgs): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -101,53 +100,58 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + trainer.add_default_hooks() + state = trainer.initial_state(training_key, model=model) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } ) - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 36a6dd943..a20f357fe 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -194,7 +194,7 @@ def get_prompts(prompt_path) -> dict: def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 5e4927d2f..7361ed864 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -9,7 +9,6 @@ import jax.random as jrandom import numpy as np import transformers -import wandb import haliax as hax @@ -127,7 +126,7 @@ def format_output(ex): def train(config: TrainArgs): - config.trainer.initialize(config) + levanter.initialize(config) # Since Levanter has different implementations of models from HF, we need to convert the HF checkpoint. # This class is a wrapper around the HF checkpoint converter that also downloads the checkpoint if necessary. @@ -169,53 +168,59 @@ def loraize_hf_model(model): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) - # end major difference from Alpaca - trainer.add_default_hooks() - state = trainer.initial_state(training_key, model=model) - - # log some info about the model - all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params - logger.info(f"Total parameter count: {all_param_count}") - logger.info(f"Trainable parameter count: {just_lora_params}") - logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") - - # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for - # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large - # datasets. We use replicated here since the dataset is small. - loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) - - if state.step != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) # type: ignore - - # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights - if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) - trainer.add_hook( - save_peft_checkpoint_callback( - full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload - ), - every=config.hf_save_steps, - ) + with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: + state = trainer.initial_state(training_key, model=model) + + # log some info about the model + all_param_count = parameter_count(state.model) + just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - # Save merged HF checkpoints if requested - if config.merged_hf_save_path is not None: - full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) - trainer.add_hook( - save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), - every=config.hf_save_steps, + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } ) - trainer.train(state, loader) + logger.info(f"Total parameter count: {all_param_count}") + logger.info(f"Trainable parameter count: {just_lora_params}") + logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") + + # Levanter has two kinds of data loaders: sharded and replicated. Replicated is simpler and allows for + # single pass training. Sharded only loads a subset of the data on each device, and is more efficient for large + # datasets. We use replicated here since the dataset is small. + loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + trainer.add_default_hooks() + + if state.step != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) # type: ignore + + # Save HF PEFT checkpoints periodically (and at the end of training), which is just the lora weights + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + trainer.add_hook( + save_peft_checkpoint_callback( + full_save_path, config.lora, config.model_name_or_path, tokenizer, config.hf_upload + ), + every=config.hf_save_steps, + ) + + # Save merged HF checkpoints if requested + if config.merged_hf_save_path is not None: + full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id) + trainer.add_hook( + save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) if __name__ == "__main__": diff --git a/mkdocs.yml b/mkdocs.yml index 568716ac4..28fdb9849 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -98,7 +98,7 @@ nav: - "Hardware-Agnostic-Training.md" - 'Developer Guide': - 'dev/Port-Models.md' -# - 'dev/Trackers.md' + - 'dev/Trackers.md' - 'FAQ' : 'faq.md' - Other: - 'Levanter-1.0-Release.md' diff --git a/pyproject.toml b/pyproject.toml index 14f010c1b..a717d9d97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "transformers>=4.22.0", "optax", "wandb", - "draccus>=0.6", + "draccus>=0.7.1", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets==2.16.1", diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 30c32a712..548a113a0 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -5,5 +5,7 @@ import levanter.logging as logging import levanter.models as models import levanter.optim as optim +import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization +from levanter.trainer import initialize diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2292c714a..b0244e0e3 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -1,5 +1,5 @@ import copy -import logging +import logging as pylogging import os import re import subprocess @@ -11,20 +11,24 @@ import humanfriendly import jax -import wandb from tqdm import tqdm -from levanter.logging import WandbConfig, log_optimizer_hyperparams, save_xla_dumps_to_wandb +import levanter.tracker +from levanter.logging import save_xla_dumps_to_wandb +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils.jax_utils import jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -logger = logging.getLogger(__name__) +logger = pylogging.getLogger(__name__) def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 + total_load_time = 0.0 + total_loss_time = 0.0 n = 0 if name is not None: @@ -33,10 +37,20 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n desc = "eval" pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) - for batch in pbar: + iter_ = iter(pbar) + while True: + time_in = time.time() + batch = next(iter_, None) + if batch is None: + break + load_time = time.time() - time_in + total_load_time += load_time loss = loss_fn(model, batch) total_loss += loss.item() n += 1 + loss_time = time.time() - time_in - load_time + total_loss_time += loss_time + pbar.set_postfix(loss=total_loss / n) if max_batches is not None and n >= max_batches: @@ -45,6 +59,9 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n + # logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") + # logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") + return total_loss @@ -57,11 +74,10 @@ def compute_validation_loss( def compute_loss(info: StepInfo): loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) - if wandb.run is not None: - prefix = "eval" - if name: - prefix += "/" + name - wandb.log({f"{prefix}/loss": loss}, step=info.step) + prefix = "eval" + if name: + prefix += "/" + name + levanter.tracker.log_metrics({f"{prefix}/loss": loss}, step=info.step) if name: logger.info(f"{name} validation loss: {loss:.3f}") @@ -73,12 +89,14 @@ def compute_loss(info: StepInfo): return compute_loss -def log_to_wandb(step: StepInfo): - wandb.log({"train/loss": step.loss, "global_step": step.step}, step=step.step) +def log_step_info(step: StepInfo): + levanter.tracker.log_metrics({"train/loss": step.loss, "global_step": step.step}, step=step.step) log_optimizer_hyperparams(step.opt_state, step=step.step, prefix="optim") def wandb_xla_logger(config: WandbConfig): + import wandb + last_mtime = wandb.run and wandb.run.start_time or time.time() def log_xla_to_wandb(step: StepInfo): @@ -108,14 +126,14 @@ def log_performance_stats(step_info: StepInfo): # log these totals because it's useful for comparing different seqlens, batch sizes, etc total_tokens = tokens_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_tokens"): total_tokens}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_tokens"): total_tokens}, step=step_info.step) if flops_per_example: total_flops = flops_per_example * batch_size * step_info.step - wandb.log({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) + levanter.tracker.log_metrics({wrap_key("total_gflops"): total_flops / 1e9}, step=step_info.step) if step_info.step_duration != 0.0: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("examples_per_second"): float(batch_size) / step_info.step_duration, wrap_key("tokens_per_second"): float(tokens_per_example) / step_info.step_duration * batch_size, @@ -125,7 +143,7 @@ def log_performance_stats(step_info: StepInfo): ) if flops_per_example is not None: - wandb.log( + levanter.tracker.log_metrics( { wrap_key("gflops_per_second"): flops_per_example / 1e9 / step_info.step_duration * batch_size, }, @@ -152,7 +170,7 @@ def update_pbar(step: StepInfo): def log_memory_usage(sample_interval: float = 1.0, log_individual_devices: bool = False): """ - Logs memory usage to wandb. This runs a loop that samples memory usage every `sample_interval` seconds. + Logs memory usage. This runs a loop that samples memory usage every `sample_interval` seconds. We only log when hooks are invoked, so there's not much point in running this much more frequently than you invoke the hook. @@ -218,7 +236,7 @@ def log_memory_usage(step: StepInfo): match = regex.search(by_kind) if match: memory_usage = humanfriendly.parse_size(match.group(1)) - wandb.log({"memory/total": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({"memory/total": memory_usage / 1e6}, step=step.step) # this works for the "kind" and the individual devices regex = re.compile(r"([\d.]+[a-zA-Z]+) \(([\d.]+)%\): ([\w\d:_]+)") @@ -229,14 +247,14 @@ def log_memory_usage(step: StepInfo): for match in regex.finditer(per_device): memory_usage = humanfriendly.parse_size(match.group(1)) device_name = match.group(3) - wandb.log({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/device/{device_name}": memory_usage / 1e6}, step=step.step) # now, get the memory usage per kind. # same regex as above for match in regex.finditer(by_kind): memory_usage = match.group(1) memory_usage = humanfriendly.parse_size(memory_usage) - wandb.log({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) + levanter.tracker.log_metrics({f"memory/{match.group(3)}": memory_usage / 1e6}, step=step.step) return log_memory_usage @@ -262,6 +280,9 @@ def compute_and_viz_log_probs(step: StepInfo): path = os.path.join(html_dir, f"step_{step}.html") viz_probs(path, model, tokenizer, log_prob_fn, test_data, max_docs=max_docs) + # TODO: convert to generic logging + import wandb + wandb.log({"log_probs": wandb.Html(path)}, step=step.step) return compute_and_viz_log_probs diff --git a/src/levanter/data/sharded_dataset.py b/src/levanter/data/sharded_dataset.py index 0ec178e08..d162882ac 100644 --- a/src/levanter/data/sharded_dataset.py +++ b/src/levanter/data/sharded_dataset.py @@ -147,9 +147,10 @@ class WrappedHFDataset(ShardedDataset[dict]): kwargs are passed to load_dataset """ - def __init__(self, id, *, split, **kwargs): + def __init__(self, id, *, split, streaming: bool = True, **kwargs): self.id = id self.split = split + self.streaming = streaming self.kwargs = kwargs self._shard_names = self._compute_shard_names() @@ -184,7 +185,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: def _load_dataset(self): # obnoxiously, the dataset loading stuff doesn't work with ray because of multiprocessing # so we have to do this hacky thing where we load the dataset in the worker - return datasets.load_dataset(self.id, split=self.split, **self.kwargs) + return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs) class TextUrlDataset(ShardedDataset[str]): diff --git a/src/levanter/logging.py b/src/levanter/logging.py index 4fbb4a618..78588669f 100644 --- a/src/levanter/logging.py +++ b/src/levanter/logging.py @@ -1,57 +1,32 @@ import contextlib -import dataclasses -import logging import logging as pylogging import os -import tempfile import time -import warnings -from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import List, Union -import draccus import jax -import wandb -from draccus import field -from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from optax import MultiStepsState -from levanter.utils import jax_utils -from levanter.utils.jax_utils import jnp_to_python +pylogger = pylogging.getLogger(__name__) -logger = pylogging.getLogger(__name__) - -def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): - if isinstance(opt_state, MultiStepsState): - opt_state = opt_state.inner_opt_state - - def wrap_key(key): - if prefix: - return f"{prefix}/{key}" - return key - - if hasattr(opt_state, "hyperparams"): - params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} - wandb.log(params, step=step) - - -def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: +def init_logging(log_dir: Union[str, Path], run_id: str, level: int = pylogging.INFO) -> None: """ Initialize logging.Logger with the appropriate name, console, and file handlers. :param path: Path for writing log file :param level: Default logging level """ + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + path = log_dir / f"{run_id}.log" + process_index = jax.process_index() log_format = f"%(asctime)s - {process_index} - %(name)s - %(filename)s:%(lineno)d - %(levelname)s :: %(message)s" # use ISO 8601 format for timestamps, except no TZ, because who cares date_format = "%Y-%m-%dT%H:%M:%S" - os.makedirs(os.path.dirname(path), exist_ok=True) - handlers: List[pylogging.Handler] = [pylogging.FileHandler(path, mode="a"), pylogging.StreamHandler()] # Create Root Logger w/ Base Formatting @@ -64,13 +39,21 @@ def init_logger(path: Union[str, Path], level: int = pylogging.INFO) -> None: def save_xla_dumps_to_wandb(initial_time: float): import os + from levanter.tracker.wandb import is_wandb_available + + if not is_wandb_available(): + pylogger.warning("Wandb is not available, so we can't save XLA dumps") + return + + import wandb + # attempt to parse xla_flags to see if we're dumping assembly files flags = os.getenv("XLA_FLAGS", None) if flags is not None and "xla_dump_to" in flags: # parse the path # this isn't robust to quotes path = flags.split("xla_dump_to=")[1].split(" ")[0] - logger.info(f"Found xla_dump_to={path}, logging to wandb") + pylogger.info(f"Found xla_dump_to={path}, logging to wandb") if wandb.run: # only want to save the files that were generated during this run # XLA_FLAGS has to be set before the first jax call, so we can't just set it in the middle of the run @@ -82,7 +65,7 @@ def include_file(path: str): wandb.run.log_code(root=path, name="xla_dumps", include_fn=include_file) else: - logger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") + pylogger.warning("XLA_FLAGS is not set to dump to a path, so we can't save the dumps to wandb") @contextlib.contextmanager @@ -100,23 +83,6 @@ def fn(): end = time.time() -@contextlib.contextmanager -def log_time_to_wandb(name: str, *, step=None): - with capture_time() as fn: - yield fn - wandb.log({name: fn()}, step=step) - - -def jittable_wandb_log(data, *, step=None): - """uses jax effect callback to log to wandb from the host""" - if is_wandb_available(): - jax.debug.callback(wandb.log, data, step=step) - - -def is_wandb_available(): - return wandb is not None and wandb.run is not None - - def silence_transformer_nag(): # this is a hack to silence the transformers' "None of PyTorch, TensorFlow 2.0 or Flax have been found..." thing # which is annoying and not useful @@ -125,172 +91,3 @@ def silence_transformer_nag(): os.environ["TRANSFORMERS_VERBOSITY"] = "error" import transformers # noqa: F401 - - -@dataclass -class WandbConfig: - """ - Configuration for wandb. - """ - - entity: Optional[str] = None # An entity is a username or team name where you send runs - project: Optional[str] = None # The name of the project where you are sending the enw run. - name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. - tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. - id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project - group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. - mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be online. - resume: Optional[Union[bool, str]] = None # - """ - Set the resume behavior. Options: "allow", "must", "never", "auto" or None. - By default, if the new run has the same ID as a previous run, this run overwrites that data. - Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) - document for more details. - """ - - save_code: Union[bool, str] = True - """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we - typically don't run from the root of the repo).""" - - save_xla_dumps: bool = False - """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" - - def init(self, run_id: Optional[str], hparams=None, **extra_hparams): - import wandb - - if run_id is not None and self.id is not None and run_id != self.id: - warnings.warn( - f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" - " config." - ) - - id = self.id - if id is None: - id = run_id - - if hparams is None: - hparams_to_save = {} - elif dataclasses.is_dataclass(hparams): - hparams_to_save = dataclasses.asdict(hparams) - else: - hparams_to_save = dict(hparams) - - if extra_hparams: - hparams_to_save.update(extra_hparams) - - # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled - # however, we do share information about the run id, so that we can link to it from the other workers - mode = self.mode - if jax.process_index() != 0: - mode = "disabled" - - if isinstance(self.save_code, str): - code_dir = self.save_code - elif self.save_code: - code_dir = WandbConfig._infer_experiment_git_root() or "." # type: ignore - else: - code_dir = None - - other_settings = dict() - if code_dir is not None: - logger.info(f"Setting wandb code_dir to {code_dir}") - other_settings["code_dir"] = code_dir - other_settings["git_root"] = code_dir - # for some reason, wandb isn't populating the git commit, so we do it here - try: - repo = Repo(code_dir) - other_settings["git_commit"] = repo.head.commit.hexsha - hparams_to_save["git_commit"] = repo.head.commit.hexsha - except (NoSuchPathError, InvalidGitRepositoryError): - logger.warning(f"Could not find git repo at {code_dir}") - pass - - r = wandb.init( - entity=self.entity, - project=self.project, - name=self.name, - tags=self.tags, - id=id, - group=self.group, - resume=self.resume, - mode=mode, - config=hparams_to_save, - settings=other_settings, - allow_val_change=True, - ) - - assert r is not None - - if jax.process_count() > 1: - # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things - metadata_to_share = dict( - entity=r.entity, - project=r.project, - name=r.name, - tags=r.tags, - id=r.id, - group=r.group, - ) - metadata_to_share = jax_utils.multihost_broadcast_sync( - metadata_to_share, is_source=jax.process_index() == 0 - ) - - if jax.process_index() != 0: - assert r.mode == "disabled" - for k, v in metadata_to_share.items(): - setattr(r, k, v) - - logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") - - if dataclasses.is_dataclass(hparams): - with tempfile.TemporaryDirectory() as tmpdir: - config_path = os.path.join(tmpdir, "config.yaml") - with open(config_path, "w") as f: - draccus.dump(hparams, f, encoding="utf-8") - if wandb.run is not None: - wandb.run.log_artifact(str(config_path), name="config.yaml", type="config") - - # generate a pip freeze - with tempfile.TemporaryDirectory() as tmpdir: - requirements_path = os.path.join(tmpdir, "requirements.txt") - requirements = _generate_pip_freeze() - with open(requirements_path, "w") as f: - f.write(requirements) - if wandb.run is not None: - wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") - - wandb.summary["num_devices"] = jax.device_count() - wandb.summary["num_hosts"] = jax.process_count() - wandb.summary["backend"] = jax.default_backend() - - @staticmethod - def _infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: - # sniff out the main directory (since we typically don't run from the root of the repo) - # we'll walk the stack and directories for the files in the stack the until we're at a git root - import os - import traceback - - stack = traceback.extract_stack() - # start from the top of the stack and work our way down since we want to hit the main file first - top_git_root = None - for frame in stack: - dirname = os.path.dirname(frame.filename) - # bit hacky but we want to skip anything that's in the python env - if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): - continue - # see if it's under a git root - try: - repo = Repo(dirname, search_parent_directories=True) - top_git_root = repo.working_dir - break - except (NoSuchPathError, InvalidGitRepositoryError): - logger.debug(f"Skipping {dirname} since it's not a git root") - pass - return top_git_root - - -def _generate_pip_freeze(): - from importlib.metadata import distributions - - dists = distributions() - return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 0b0636f4b..9ee6614ca 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -1,14 +1,13 @@ import logging import os -from dataclasses import dataclass - -import wandb +from dataclasses import dataclass, field import levanter from levanter.data.shard_cache import LoggingMetricsMonitor, RichMetricsMonitor, build_cache from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logger +from levanter.logging import init_logging +from levanter.tracker import NoopConfig, TrackerConfig logger = logging.getLogger(__name__) @@ -16,19 +15,17 @@ @dataclass class RayCachedLMDatasetConfig(LMDatasetConfig, RayConfig): - pass + tracker: TrackerConfig = field(default_factory=NoopConfig) @levanter.config.main() def main(args: RayCachedLMDatasetConfig): """Caches two different kinds of datasets. It can cache a dataset from a list of urls, or a dataset from a hf dataset""" - init_logger("cache_dataset.log") + init_logging(".", "cache_dataset.log") args.initialize() tokenizer = args.the_tokenizer - wandb.init(mode="offline") - for split in ["train", "validation"]: print(f"Caching {split} to {args.cache_dir}.") # connect or start the actor diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 6262eb428..ab6d9d6b9 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -41,7 +41,7 @@ class EvalLmConfig: def main(config: EvalLmConfig): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer Batch = Axis("batch", config.trainer.eval_batch_size) diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 93d60588a..babe7d2fa 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -4,7 +4,6 @@ from typing import Optional import jax.random as jrandom -import wandb import haliax.random @@ -47,6 +46,7 @@ class LoraLmConfig: def main(config: LoraLmConfig): + levanter.initialize(config) tokenizer = config.data.the_tokenizer converter = HFCheckpointConverter.from_hf(config.initialize_from_hf, trust_remote_code=config.trust_remote_code) @@ -55,7 +55,6 @@ def main(config: LoraLmConfig): converter = converter.replaced(tokenizer=tokenizer) - config.trainer.initialize(config) model_config = converter.default_config # randomness in jax is tightly controlled by "keys" which are the states of the random number generators @@ -96,8 +95,14 @@ def compute_loss(model, example: LmExample, key=None): all_param_count = parameter_count(state.model) just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) - wandb.summary["parameter_count"] = all_param_count - wandb.summary["trainable_parameter_count"] = just_lora_params + levanter.tracker.log_summary( + { + "parameter_count": all_param_count, + "trainable_parameter_count": just_lora_params, + "fraction_trainable": just_lora_params * 1.0 / all_param_count, + } + ) + logger.info(f"Total parameter count: {all_param_count}") logger.info(f"Trainable parameter count: {just_lora_params}") logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}") diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index f5b6e83b4..42c415b75 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -5,7 +5,6 @@ from typing import Optional, Union import jax.random as jrandom -import wandb import haliax as hax from haliax import Axis @@ -76,39 +75,40 @@ def main(config: TrainLmConfig): else: converter = None - # initialize training config *after* we've done the hf stuff b/c we might have changed the model config - config.trainer.initialize(config) + levanter.initialize(config) - # randomness in jax is tightly controlled by "keys" which are the states of the random number generators - # this makes deterministic training pretty easy - seed = config.trainer.seed - data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) - - # some axes we need - Batch = config.trainer.TrainBatch - EvalBatch = config.trainer.EvalBatch - Pos = config.model.Pos - KeyPos = config.model.KeyPos - - # We have two axis_mappings: one for storing the model and optimizer states, and one for compute - # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh - compute_axis_mapping = config.trainer.compute_axis_mapping - parameter_axis_mapping = config.trainer.parameter_axis_mapping + optimizer = config.optimizer.build(config.trainer.num_train_steps) def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - optimizer = config.optimizer.build(config.trainer.num_train_steps) - # Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp - trainer = Trainer(config.trainer, optimizer, compute_loss) - - eval_datasets = config.data.validation_sets(Pos.size) - train_dataset = CausalLmDataset( - config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id - ) + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker + with Trainer(config.trainer, optimizer, compute_loss) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + eval_datasets = config.data.validation_sets(Pos.size) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id + ) - with trainer.device_mesh: # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. @@ -135,7 +135,7 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): else: logger.info("No checkpoint found. Starting from scratch.") - wandb.summary["parameter_count"] = parameter_count(state.model) + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) # boilerplate hooks and such trainer.add_default_hooks() diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index 370b20d59..b992cd3f5 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -36,12 +36,11 @@ class VizGpt2Config: def main(config: VizGpt2Config): - config.trainer.initialize(config) + levanter.initialize(config) tokenizer = config.data.the_tokenizer - EvalBatch = Axis("batch", config.trainer.eval_batch_size) - # some axes we use outside the model proper + EvalBatch = config.trainer.EvalBatch Pos = config.model.Pos KeyPos = config.model.KeyPos @@ -53,7 +52,6 @@ def main(config: VizGpt2Config): # some axes we use outside the model proper Pos = config.model.Pos - KeyPos = config.model.KeyPos compute_axis_mapping = config.trainer.compute_axis_mapping parameter_axis_mapping = config.trainer.parameter_axis_mapping diff --git a/src/levanter/tracker/__init__.py b/src/levanter/tracker/__init__.py new file mode 100644 index 000000000..69156c6a6 --- /dev/null +++ b/src/levanter/tracker/__init__.py @@ -0,0 +1,29 @@ +from levanter.tracker.helpers import log_optimizer_hyperparams +from levanter.tracker.tracker import CompositeTracker, NoopConfig, NoopTracker, Tracker, TrackerConfig +from levanter.tracker.tracker_fns import ( + current_tracker, + get_tracker, + jit_log_metrics, + log_configuration, + log_hyperparameters, + log_metrics, + log_summary, + set_global_tracker, +) + + +__all__ = [ + "Tracker", + "TrackerConfig", + "CompositeTracker", + "log_optimizer_hyperparams", + "NoopTracker", + "current_tracker", + "get_tracker", + "jit_log_metrics", + "log_configuration", + "log_metrics", + "log_summary", + "log_hyperparameters", + "set_global_tracker", +] diff --git a/src/levanter/tracker/helpers.py b/src/levanter/tracker/helpers.py new file mode 100644 index 000000000..1091840c5 --- /dev/null +++ b/src/levanter/tracker/helpers.py @@ -0,0 +1,75 @@ +import dataclasses +import logging +import os +from typing import Optional + +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +import levanter.tracker +from levanter.utils.jax_utils import jnp_to_python + + +logger = logging.getLogger(__name__) + + +def log_optimizer_hyperparams(opt_state, prefix: Optional[str] = None, *, step=None): + try: + from optax._src.wrappers import MultiStepsState + + if isinstance(opt_state, MultiStepsState): + opt_state = opt_state.inner_opt_state + except ImportError: + pass + + def wrap_key(key): + if prefix: + return f"{prefix}/{key}" + return key + + if hasattr(opt_state, "hyperparams"): + params = {wrap_key(k): jnp_to_python(v) for k, v in opt_state.hyperparams.items()} + levanter.tracker.log_metrics(params, step=step) + + +def hparams_to_dict(hparams, **extra_hparams): + if hparams is None: + hparams_to_save = {} + elif dataclasses.is_dataclass(hparams): + hparams_to_save = dataclasses.asdict(hparams) + else: + hparams_to_save = dict(hparams) + if extra_hparams: + hparams_to_save.update(extra_hparams) + return hparams_to_save + + +def infer_experiment_git_root() -> Optional[str | os.PathLike[str]]: + # sniff out the main directory (since we typically don't run from the root of the repo) + # we'll walk the stack and directories for the files in the stack the until we're at a git root + import os + import traceback + + stack = traceback.extract_stack() + # start from the top of the stack and work our way down since we want to hit the main file first + top_git_root = None + for frame in stack: + dirname = os.path.dirname(frame.filename) + # bit hacky but we want to skip anything that's in the python env + if any(x in dirname for x in ["site-packages", "dist-packages", "venv", "opt/homebrew", "conda", "pyenv"]): + continue + # see if it's under a git root + try: + repo = Repo(dirname, search_parent_directories=True) + top_git_root = repo.working_dir + break + except (NoSuchPathError, InvalidGitRepositoryError): + logger.debug(f"Skipping {dirname} since it's not a git root") + pass + return top_git_root + + +def generate_pip_freeze(): + from importlib.metadata import distributions + + dists = distributions() + return "\n".join(f"{dist.name}=={dist.version}" for dist in dists) diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py new file mode 100644 index 000000000..bd3ee70ba --- /dev/null +++ b/src/levanter/tracker/tensorboard.py @@ -0,0 +1,81 @@ +import logging +import os +import typing +from dataclasses import dataclass +from typing import Any, Optional + +from levanter.tracker import Tracker, TrackerConfig + + +pylogger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + from tensorboardX import SummaryWriter # noqa: F401 + + +class TensorboardTracker(Tracker): + name: str = "tensorboard" + + def __init__(self, writer: "SummaryWriter"): + self.writer = writer + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.writer.add_hparams(hparams, {"dummy": 0}) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + del commit + for k, v in metrics.items(): + self.writer.add_scalar(k, v, step) + + def log_summary(self, metrics: dict[str, Any]): + for k, v in metrics.items(): + self.writer.add_scalar(k, v, global_step=None) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pylogger.error("TensorboardLogger does not support logging artifacts yet") + pass + + +@TrackerConfig.register_subclass("tensorboard") +@dataclass +class TensorboardConfig(TrackerConfig): + logdir: str = "tblogs" + comment: Optional[str] = "" + purge_step: Optional[int] = None + max_queue: Optional[int] = 10 + flush_secs: Optional[int] = 120 + filename_suffix: Optional[str] = "" + write_to_disk: Optional[bool] = True + + def init(self, run_id: Optional[str]) -> TensorboardTracker: + dir_to_write = self.logdir + if run_id is not None: + dir_to_write = os.path.join(dir_to_write, run_id) + + pylogger.info(f"Writing Tensorboard logs to {dir_to_write}") + + from tensorboardX import SummaryWriter # noqa: F811 + + writer = SummaryWriter( + dir_to_write, + comment=self.comment, + purge_step=self.purge_step, + max_queue=self.max_queue, + flush_secs=self.flush_secs, + filename_suffix=self.filename_suffix, + write_to_disk=self.write_to_disk, + ) + + return TensorboardTracker(writer) + + +def _flatten_nested_dict(d): + def items(): + for key, value in d.items(): + if isinstance(value, dict): + for subkey, subvalue in _flatten_nested_dict(value).items(): + yield key + "/" + subkey, subvalue + else: + yield key, value + + return dict(items()) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py new file mode 100644 index 000000000..8b6816f17 --- /dev/null +++ b/src/levanter/tracker/tracker.py @@ -0,0 +1,117 @@ +import abc +import dataclasses +import typing +from typing import Any, List, Optional + +import draccus + + +class Tracker(abc.ABC): + """ + A tracker is responsible for logging metrics, hyperparameters, and artifacts. + Meant to be used with the [levanter.tracker.current_tracker][] context manager, but can also be used directly. + + The name is borrowed from HF Accelerate. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + """ + + name: str + + @abc.abstractmethod + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + @abc.abstractmethod + def log(self, metrics: dict[str, typing.Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the tracker. Step is always required. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + pass + + @abc.abstractmethod + def log_summary(self, metrics: dict[str, Any]): + pass + + @abc.abstractmethod + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + def __enter__(self): + import levanter.tracker.tracker_fns as tracker_fns + + if hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is already set as the global tracker") + setattr(self, "_tracker_cm", tracker_fns.current_tracker(self)) + self._tracker_cm.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not hasattr(self, "_tracker_cm"): + raise RuntimeError("This tracker is not set as the global tracker") + self._tracker_cm.__exit__(exc_type, exc_val, exc_tb) + delattr(self, "_tracker_cm") + + +class CompositeTracker(Tracker): + def __init__(self, loggers: List[Tracker]): + self.loggers = loggers + + def log_hyperparameters(self, hparams: dict[str, Any]): + for tracker in self.loggers: + tracker.log_hyperparameters(hparams) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + for tracker in self.loggers: + tracker.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + for tracker in self.loggers: + tracker.log_summary(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + for tracker in self.loggers: + tracker.log_artifact(artifact_path, name=name, type=type) + + +class TrackerConfig(draccus.PluginRegistry, abc.ABC): + discover_packages_path = "levanter.tracker" + + @abc.abstractmethod + def init(self, run_id: Optional[str]) -> Tracker: + raise NotImplementedError + + @classmethod + def default_choice_name(cls) -> Optional[str]: + return "wandb" + + +class NoopTracker(Tracker): + name: str = "noop" + + def log_hyperparameters(self, hparams: dict[str, Any]): + pass + + def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): + pass + + def log_summary(self, metrics: dict[str, Any]): + pass + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + pass + + +@TrackerConfig.register_subclass("noop") +@dataclasses.dataclass +class NoopConfig(TrackerConfig): + def init(self, run_id: Optional[str]) -> Tracker: + return NoopTracker() diff --git a/src/levanter/tracker/tracker_fns.py b/src/levanter/tracker/tracker_fns.py new file mode 100644 index 000000000..e3b6a1f71 --- /dev/null +++ b/src/levanter/tracker/tracker_fns.py @@ -0,0 +1,235 @@ +import dataclasses +import logging +import os +import tempfile +import typing +import warnings +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional + +import draccus +import jax + +from levanter.tracker import CompositeTracker, Tracker +from levanter.tracker.helpers import hparams_to_dict +from levanter.tracker.tensorboard import TensorboardTracker +from levanter.tracker.wandb import WandbTracker +from levanter.utils.jax_utils import is_inside_jit + + +logger = logging.getLogger(__name__) + + +_global_tracker: Optional["Tracker"] = None + + +def log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + """ + Log metrics to the global tracker. + + Args: + metrics: Metrics to log + step: Step to log at + commit: Whether to commit the metrics. If None, uses the default for the tracker. + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + if is_inside_jit(): + # we're inside a jit, so we need to log from the host + if commit: + raise ValueError("Cannot commit from inside jit") + jit_log_metrics(metrics, step=step) + else: + # TODO: do we need to coerce to np here? + _global_tracker.log(metrics, step=step) + + +def _no_throw_log_metrics(metrics: dict[str, Any], *, step: Optional[int], commit: Optional[bool] = None): + try: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log(metrics, step=step, commit=False) + except Exception: + logger.exception("Error logging metrics") + + +def jit_log_metrics(metrics, *, step=None): + """uses jax effect callback to log to wandb from the host""" + jax.debug.callback(_no_throw_log_metrics, metrics, step=step) + + +def log_summary(metrics: dict[str, Any]): + """ + Log summary metrics to the global tracker. + + Args: + metrics: Metrics to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + _global_tracker.log_summary(metrics) + + +def log_hyperparameters(hparams: dict[str, Any]): + """ + Log hyperparameters to the global tracker. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + _global_tracker.log_hyperparameters(hparams) + + +def log_configuration(hparams: Any, config_name: Optional[str] = None): + """ + Logs a configuration object to the global tracker. If the configuration object is a dataclass, + it is dumped to a yaml file and logged as an artifact. + + Args: + hparams: Hyperparameters to log + """ + global _global_tracker + if _global_tracker is None: + raise RuntimeError("No global tracker set") + + hparams_dict = hparams_to_dict(hparams) + _global_tracker.log_hyperparameters(hparams_dict) + + if dataclasses.is_dataclass(hparams): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.yaml") + with open(config_path, "w") as f: + draccus.dump(hparams, f, encoding="utf-8") + name = config_name or "config.yaml" + _global_tracker.log_artifact(config_path, name=name, type="config") + + +def set_global_tracker(tracker: Tracker): + """ + Set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + In general, it's preferred to use the context manager returned by `current_tracker` instead of this function + except for once at the beginning of the program. + + Args: + tracker: The tracker to set as the global tracker + force: Whether to force setting the global tracker even if it is already set + + Examples: + >>> from levanter.tracker import set_global_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> set_global_tracker(WandbTracker()) + >>> log_metrics({"foo": 1}, step=0) + """ + global _global_tracker + if _global_tracker is not None: + warnings.warn("Global tracker is already set. Overwriting it.") + _global_tracker = tracker + + +@typing.overload +def current_tracker() -> "Tracker": + ... + + +@typing.overload +def current_tracker(tracker: "Tracker") -> typing.ContextManager: + """Returns a context manager for setting the global tracker""" + ... + + +def current_tracker( + tracker: Optional[Tracker] = None, +) -> Tracker | typing.ContextManager: + """ + Get or set the global tracker. Note that setting the global tracker is not thread-safe, + and using a tracker from multiple threads is only supported if the tracker itself is thread-safe. + + Args: + tracker: If provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Returns: + If no tracker is provided, returns the current global tracker. + If a tracker is provided, returns a context manager that sets the global tracker to the provided tracker when used. + + Examples: + >>> from levanter.tracker import current_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... current_tracker().log({"foo": 2}, step=1) + """ + global _global_tracker + if tracker is None: + if _global_tracker is None: + raise RuntimeError("No global tracker set") + return _global_tracker + else: + return _GlobalLoggerContextManager(tracker) + + +@typing.overload +def get_tracker(name: Literal["wandb"]) -> WandbTracker: + ... + + +@typing.overload +def get_tracker(name: Literal["tensorboard"]) -> TensorboardTracker: + ... + + +@typing.overload +def get_tracker(name: str) -> Tracker: + ... + + +def get_tracker(name: str) -> Tracker: + """ + Lookup a tracker in the current global tracker with the provided name. + + Args: + name: Name of the tracker to lookup + + Returns: + The tracker with the provided name + + Examples: + >>> from levanter.tracker import get_tracker, log_metrics + >>> from levanter.tracker.wandb import WandbTracker + >>> with current_tracker(WandbTracker()): + ... log_metrics({"foo": 1}, step=0) + ... get_tracker("wandb").log_metrics({"foo": 2}, step=1) + """ + tracker = current_tracker() + if isinstance(tracker, CompositeTracker): + for t in tracker.loggers: + if t.name == name: + return t + elif tracker.name == name: + return tracker + + raise KeyError(f"Tracker with name {name} not found") + + +class _GlobalLoggerContextManager(AbstractContextManager): + def __init__(self, tracker: "Tracker"): + self.tracker = tracker + + def __enter__(self): + global _global_tracker + self.old_tracker = _global_tracker + _global_tracker = self.tracker + + return self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_tracker + _global_tracker = self.old_tracker diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py new file mode 100644 index 000000000..d217ab000 --- /dev/null +++ b/src/levanter/tracker/wandb.py @@ -0,0 +1,199 @@ +import logging +import os +import tempfile +import typing +import warnings +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +import jax +from draccus import field +from git import InvalidGitRepositoryError, NoSuchPathError, Repo + +from levanter.tracker import Tracker +from levanter.tracker.helpers import generate_pip_freeze, infer_experiment_git_root +from levanter.tracker.tracker import TrackerConfig +from levanter.utils import jax_utils + + +if typing.TYPE_CHECKING: + import wandb + import wandb.sdk.lib.disabled + + +logger = logging.getLogger(__name__) + +WandbRun = Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.disabled.RunDisabled"] + + +class WandbTracker(Tracker): + name: str = "wandb" + run: WandbRun + + def __init__(self, run: Optional[WandbRun]): + import wandb + + if run is None: + if wandb.run is None: + logger.warning("Wandb run is not initialized. Initializing a new run.") + runx = wandb.init() + if runx is None: + raise RuntimeError("Wandb run is not initialized.") + self.run = runx + else: + self.run = wandb.run + else: + self.run = run + + def log_hyperparameters(self, hparams: dict[str, Any]): + self.run.config.update(hparams, allow_val_change=True) + + def log(self, metrics: dict[str, Any], *, step, commit=None): + if step is None and not commit: + step = self.run.step + + self.run.log(metrics, step=step, commit=commit) + + def log_summary(self, metrics: dict[str, Any]): + self.run.summary.update(metrics) + + def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): + self.run.log_artifact(artifact_path, name=name, type=type) + + +def is_wandb_available(): + try: + import wandb + except ImportError: + return False + return wandb is not None and wandb.run is not None + + +@TrackerConfig.register_subclass("wandb") +@dataclass +class WandbConfig(TrackerConfig): + """ + Configuration for wandb. + """ + + entity: Optional[str] = None # An entity is a username or team name where you send runs + project: Optional[str] = None # The name of the project where you are sending the enw run. + name: Optional[str] = None # A short display name for this run, which is how you'll identify this run in the UI. + tags: List[str] = field(default_factory=list) # Will populate the list of tags on this run in the UI. + id: Optional[str] = None # A unique ID for this run, used for resuming. It must be unique in the project + group: Optional[str] = None # Specify a group to organize individual runs into a larger experiment. + mode: Optional[str] = None # Can be "online", "offline" or "disabled". If None, it will be whatever W&B decides. + resume: Optional[Union[bool, str]] = None + """ + Set the resume behavior. Options: "allow", "must", "never", "auto" or None. + By default, if the new run has the same ID as a previous run, this run overwrites that data. + Please refer to [init](https://docs.wandb.ai/ref/python/init) and [resume](https://docs.wandb.ai/guides/runs/resuming) + document for more details. + """ + + save_code: Union[bool, str] = True + """If string, will save code from that directory. If True, will attempt to sniff out the main directory (since we + typically don't run from the root of the repo).""" + + save_xla_dumps: bool = False + """If True, will save the XLA code to wandb (as configured by XLA_FLAGS). This is useful for debugging.""" + + def init(self, run_id: Optional[str]) -> WandbTracker: + import wandb + + if run_id is not None and self.id is not None and run_id != self.id: + warnings.warn( + f"Both trainer's id {run_id} and WandB's id {self.id} are set. WandB will use the id set in its" + " config." + ) + + id = self.id + if id is None: + id = run_id + + hparams_to_save = {} + + # for distributed runs, we only want the primary worker to use wandb, so we make everyone else be disabled + # however, we do share information about the run id, so that we can link to it from the other workers + if jax.process_index() == 0: + mode = self.mode + else: + mode = "disabled" + + git_settings = self._git_settings() + + if "git_commit" in git_settings: + hparams_to_save["git_commit"] = git_settings["git_commit"] + + r = wandb.init( + entity=self.entity, + project=self.project, + name=self.name, + tags=self.tags, + id=id, + group=self.group, + resume=self.resume, + mode=mode, + config=hparams_to_save, + settings=git_settings, + allow_val_change=True, + ) + + assert r is not None + + if jax.process_count() > 1: + # we need to share wandb run information across all hosts, because we use it for checkpoint paths and things + metadata_to_share = dict( + entity=r.entity, + project=r.project, + name=r.name, + tags=r.tags, + id=r.id, + group=r.group, + ) + metadata_to_share = jax_utils.multihost_broadcast_sync( + metadata_to_share, is_source=jax.process_index() == 0 + ) + + if jax.process_index() != 0: + assert r.mode == "disabled" + for k, v in metadata_to_share.items(): + setattr(r, k, v) + + logger.info(f"Synced wandb run information from process 0: {r.name} {r.id}") + + # generate a pip freeze + with tempfile.TemporaryDirectory() as tmpdir: + requirements_path = os.path.join(tmpdir, "requirements.txt") + requirements = generate_pip_freeze() + with open(requirements_path, "w") as f: + f.write(requirements) + if wandb.run is not None: + wandb.run.log_artifact(str(requirements_path), name="requirements.txt", type="requirements") + + wandb.summary["num_devices"] = jax.device_count() + wandb.summary["num_hosts"] = jax.process_count() + wandb.summary["backend"] = jax.default_backend() + + return WandbTracker(r) + + def _git_settings(self): + other_settings = dict() + if isinstance(self.save_code, str): + code_dir = self.save_code + elif self.save_code: + code_dir = infer_experiment_git_root() or "." # type: ignore + else: + code_dir = None + if code_dir is not None: + logger.info(f"Setting wandb code_dir to {code_dir}") + other_settings["code_dir"] = code_dir + other_settings["git_root"] = code_dir + # for some reason, wandb isn't populating the git commit, so we do it here + try: + repo = Repo(code_dir) + other_settings["git_commit"] = repo.head.commit.hexsha + except (NoSuchPathError, InvalidGitRepositoryError): + logger.warning(f"Could not find git repo at {code_dir}") + pass + return other_settings diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index d9db8dc91..5577c6406 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -5,16 +5,30 @@ import os import sys import typing +import warnings from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Callable, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) import equinox as eqx import jax import jmp import numpy as np -import wandb from draccus import field from jax import ShapeDtypeStruct from jax.experimental import multihost_utils @@ -28,12 +42,16 @@ from haliax.types import Scalar import levanter.logging +import levanter.tracker +import levanter.tracker.wandb +from levanter import tracker from levanter.checkpoint import CheckpointerConfig from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched -from levanter.logging import WandbConfig, capture_time +from levanter.logging import capture_time +from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils from levanter.utils.jax_utils import is_inexact_arrayish @@ -112,8 +130,10 @@ class Trainer: config: "TrainerConfig" optimizer: GradientTransformation hooks: TrainerHooks + tracker: levanter.tracker.Tracker is_trainable_param: Optional[PyTree[FilterSpec]] _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] def __init__( self, @@ -140,6 +160,8 @@ def __init__( self.optimizer = optimizer self.is_trainable_param = is_trainable + self._cmanagers = [] + @cached_property def loss_fn(self): """ @@ -204,6 +226,34 @@ def TrainBatch(self): def EvalBatch(self): return self.config.EvalBatch + def __enter__(self): + if len(self._cmanagers) > 0: + raise RuntimeError("Trainer is already entered") + + self._cmanagers = [ + # levanter.current_tracker(self.tracker), + self.device_mesh, + hax.axis_mapping(self.parameter_axis_mapping), + ] + + for cmanager in self._cmanagers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + problems = [] + for cmanager in reversed(self._cmanagers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + self._cmanagers = [] + + if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + def initial_state( self, training_key: PRNGKeyArray, model: Optional[M] = None, model_init: Optional[Callable[[], M]] = None ) -> TrainerState: @@ -213,7 +263,6 @@ def initial_state( Returns: model, opt_state, key, resume_step """ - if model is not None and model_init is not None: raise ValueError("only one of model and model_init should be specified") elif model is None and model_init is None: @@ -306,8 +355,7 @@ def training_steps( with capture_time() as loading_time: example = next(iter_data) - # TODO: refactor logging - wandb.log({"throughput/loading_time": loading_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/loading_time": loading_time()}, step=state.step) info = self.train_step(state, example) state = info.state @@ -316,7 +364,7 @@ def training_steps( with capture_time() as hook_time: self.run_hooks(info) - wandb.log({"throughput/hook_time": hook_time()}, step=state.step) + levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=state.step) yield info @@ -337,10 +385,9 @@ def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) - self.add_hook(callbacks.log_to_wandb, every=1) + self.add_hook(callbacks.log_step_info, every=1) if eval_dataset is not None: self.add_eval_hook(eval_dataset) - self.add_hook(callbacks.wandb_xla_logger(self.config.wandb), every=self.config.steps_per_eval) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency @@ -409,7 +456,9 @@ def split_loss_fn(trainable_model, *batch, **batch_kwargs): loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) - partial_fn = lambda model: split_loss_fn(model, *batch, **batch_kwargs) + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) + + partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) model = eqx.apply_updates(model, updates) @@ -500,16 +549,27 @@ def maybe_load_checkpoint( return None +def _initialize_global_tracker(config, run_id): + if isinstance(config, Sequence): + tracker = levanter.tracker.CompositeTracker([c.init(run_id) for c in config]) + else: + tracker = config.init(run_id) + + levanter.tracker.set_global_tracker(tracker) + + @dataclass class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy - wandb: WandbConfig = field(default_factory=WandbConfig) + wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") run_base_dir: Path = Path("runs/") id: Optional[str] = None # run id. if None, will be set to a random string + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig) + # config related to partitioning batch_axis: Optional[str] = "batch" # Batch axis for data parallel. @@ -557,15 +617,6 @@ class TrainerConfig: # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes shutdown_at_exit: Union[bool, float] = False - @property - def run_name(self) -> str: - try: - import wandb - - return wandb.run and (wandb.run.name or wandb.run.id) or "unnamed" - except ImportError: - return "unnamed" - @property def TrainBatch(self): return Axis("batch", self.train_batch_size) @@ -578,7 +629,12 @@ def EvalBatch(self): def microbatch_size(self): return self.per_device_parallelism * self.data_axis_size - def initialize(self, all_config): + def __post_init__(self): + if self.wandb is not None: + warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) + self.tracker = self.wandb + + def initialize(self): """Initializes jax, wandb, logging, setting the run name/id in the process""" self._initialize_jax_config() # Can't do full logging setup until we've initialized jax b/c we use jax for rank id @@ -587,8 +643,8 @@ def initialize(self, all_config): self._validate_and_set_defaults() id = self._maybe_set_id() - levanter.logging.init_logger(f"{self.log_dir}/{id}.log") - self.wandb.init(id, all_config) + levanter.logging.init_logging(self.log_dir, f"{id}.log") + _initialize_global_tracker(self.tracker, id) self.ray.initialize() @@ -668,7 +724,7 @@ def _maybe_set_id(self): # TODO: this doesn't work with wandb sweeps. need to reconcile when we merge if "RUN_ID" in os.environ: self.id = os.environ["RUN_ID"] - elif self.wandb.id is not None: + elif self.wandb is not None and self.wandb.id is not None: self.id = self.wandb.id else: # wandb run ids are 8 characters [a-z0-9], which we'll emulate here @@ -708,5 +764,21 @@ def _validate_and_set_defaults(self): self.per_device_eval_parallelism = self.per_device_parallelism +class AllConfig(Protocol): + trainer: TrainerConfig + + +def initialize(config: TrainerConfig | AllConfig): + """Initializes jax, logging, setting the run name/id in the process. Also initializes tracking and saves config + as hyperparameters and an artifact""" + if isinstance(config, TrainerConfig): + trainer_config = config + else: + trainer_config = config.trainer + + trainer_config.initialize() + levanter.tracker.log_configuration(config) + + def _params_only(t): return eqx.filter(t, is_inexact_arrayish) diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index f1193f4f4..178069f26 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -11,8 +11,8 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index b50bde9cb..3ce092789 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -50,8 +50,7 @@ def test_export_lm_to_hf(): export_lm_to_hf.main(config) if has_torch(): - m = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") - print(m) + AutoModelForCausalLM.from_pretrained(f"{tmpdir}/output") finally: try: diff --git a/tests/test_logging.py b/tests/test_logging.py index cf99b8c35..ab7cc35f2 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -3,7 +3,7 @@ import pytest from git import InvalidGitRepositoryError, NoSuchPathError, Repo -from levanter.logging import WandbConfig +from levanter.tracker.helpers import infer_experiment_git_root def test_infer_experiment_git_root(): @@ -13,7 +13,7 @@ def test_infer_experiment_git_root(): except (InvalidGitRepositoryError, NoSuchPathError): pytest.skip("test not running in a git repo") - root = WandbConfig._infer_experiment_git_root() + root = infer_experiment_git_root() # ensure that 1) this is a git root and 2) this source file is underneath assert root is not None diff --git a/tests/test_tracker.py b/tests/test_tracker.py new file mode 100644 index 000000000..15485b83e --- /dev/null +++ b/tests/test_tracker.py @@ -0,0 +1,80 @@ +# NOTE: Do not explicitly import wandb/other trackers here, as this will cause the tests to trivially pass. +import dataclasses +from typing import Tuple + +import pytest +import yaml + +import levanter.tracker +from levanter.tracker import CompositeTracker, TrackerConfig + + +def test_tracker_plugin_stuff_works(): + assert TrackerConfig.get_choice_class("wandb") is not None + with pytest.raises(KeyError): + TrackerConfig.get_choice_class("foo") + + +def test_tracker_plugin_default_works(): + config = """ + tracker: + entity: foo + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig + + import draccus + + tconfig = draccus.decode(ConfigHolder, parsed).tracker + + assert isinstance(tconfig, TrackerConfig.get_choice_class("wandb")) + + assert tconfig.entity == "foo" # type: ignore + + +def test_tracker_plugin_multi_parsing_work(): + config = """ + tracker: + type: noop + """ + parsed = yaml.safe_load(config) + + @dataclasses.dataclass + class ConfigHolder: + tracker: TrackerConfig | Tuple[TrackerConfig, ...] + + import draccus + + from levanter.tracker.tracker import NoopConfig + + assert isinstance(draccus.decode(ConfigHolder, parsed).tracker, NoopConfig) + + config = """ + tracker: + - type: noop + - type: wandb + """ + parsed = yaml.safe_load(config) + decoded = draccus.decode(ConfigHolder, parsed).tracker + assert decoded == (NoopConfig(), TrackerConfig.get_choice_class("wandb")()) + + +def test_get_tracker_by_name(): + wandb_config = TrackerConfig.get_choice_class("wandb") + if wandb_config is None: + pytest.skip("wandb not installed") + + from levanter.tracker import NoopTracker + + wandb1 = wandb_config(mode="disabled").init(None) + tracker = CompositeTracker([wandb1, NoopTracker()]) + + with tracker: + assert levanter.tracker.get_tracker("wandb") is wandb1 + assert levanter.tracker.get_tracker("noop") is not None + + with pytest.raises(KeyError): + levanter.tracker.get_tracker("foo") diff --git a/tests/test_train_lm.py b/tests/test_train_lm.py index 3cd762d8b..f95b27efb 100644 --- a/tests/test_train_lm.py +++ b/tests/test_train_lm.py @@ -8,7 +8,7 @@ import levanter.main.train_lm as train_lm import tiny_test_corpus from levanter.distributed import RayConfig -from levanter.logging import WandbConfig +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 665c98772..29d8f943c 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -11,14 +11,18 @@ import tiny_test_corpus from levanter.checkpoint import save_checkpoint from levanter.distributed import RayConfig -from levanter.logging import WandbConfig from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel +from levanter.tracker.wandb import WandbConfig from levanter.utils.py_utils import logical_cpu_core_count def setup_module(module): ray_designated_cores = max(1, logical_cpu_core_count()) - ray.init("local", num_cpus=ray_designated_cores) + try: + ray.init("local", num_cpus=ray_designated_cores) + except AssertionError: + # don't get upset if ray is already running + pass def teardown_module(module): From 5de54dc4a2b45b3369ec2bcf2f168eca6b266554 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 10 Feb 2024 00:06:56 -0800 Subject: [PATCH 2/5] default to adding default_hooks (#460) --- examples/alpaca-lora/alpaca_lora.py | 1 - examples/alpaca/alpaca.py | 1 - examples/gsm8k-lora/gsm8k_lora.py | 2 -- src/levanter/main/lora_lm.py | 2 +- src/levanter/main/train_lm.py | 2 -- src/levanter/trainer.py | 8 +++++--- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/alpaca-lora/alpaca_lora.py b/examples/alpaca-lora/alpaca_lora.py index 0e7c5790e..87f51a7fc 100644 --- a/examples/alpaca-lora/alpaca_lora.py +++ b/examples/alpaca-lora/alpaca_lora.py @@ -103,7 +103,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): # end major difference from Alpaca with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer: - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) # log some info about the model diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index a20f357fe..cfe07a1e4 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -249,7 +249,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - trainer.add_default_hooks() state = trainer.initial_state(training_key, model=model) if state.step != 0: diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 7361ed864..febfd2013 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -195,8 +195,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): loader = trainer.replicated_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - trainer.add_default_hooks() - if state.step != 0: logger.info(f"Resuming training from step {state.step}") for i in range(state.step): diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index babe7d2fa..5120c9e22 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -114,7 +114,7 @@ def compute_loss(model, example: LmExample, key=None): train_loader = trainer.sharded_loader(train_dataset, Batch) # boilerplate hooks and such - trainer.add_default_hooks(eval_dataset) + trainer.add_eval_hook(eval_dataset) trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) if config.peft_save_path is not None: full_save_path = os.path.join(config.peft_save_path, trainer.run_id) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 42c415b75..2dbd705d5 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -138,8 +138,6 @@ def compute_loss(model: LmHeadModel, example: LmExample, key=None): levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) # boilerplate hooks and such - trainer.add_default_hooks() - if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 5577c6406..7d8661c91 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -142,6 +142,7 @@ def __init__( loss_fn: Callable, *, is_trainable: PyTree[FilterSpec] = True, + add_default_hooks: bool = True, ): """ @@ -162,6 +163,9 @@ def __init__( self._cmanagers = [] + if add_default_hooks: + self._add_default_hooks() + @cached_property def loss_fn(self): """ @@ -381,13 +385,11 @@ def train(self, state: TrainerState[M], train_loader: Iterable[X], run_hooks: bo return info - def add_default_hooks(self, eval_dataset: Optional[Iterable[X]] = None): + def _add_default_hooks(self): from levanter import callbacks self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) self.add_hook(callbacks.log_step_info, every=1) - if eval_dataset is not None: - self.add_eval_hook(eval_dataset) # engine.add_hook(callbacks.log_memory_usage(), every=1) checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency From 50f72d7a82cad46fca23bcb531c366da86047eae Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 10 Feb 2024 23:38:02 -0800 Subject: [PATCH 3/5] Misc fixes (#461) * ensure we do the final softmax in fp32 * tweak number of cores for tokenization * use the trainer better * switch to tracker in shard_cache * expose current_tracker in levanter base package * add back in logging of sophia metrics * config fixes * missed a spot * raise if there's validation set in eval_lm * is precommit being crappy? * sigh --- .github/workflows/run_pre_commit.yaml | 12 ++---------- config/gpt2_small_pile.yaml | 2 +- config/gpt2_small_pile_mixture.yaml | 2 +- examples/alpaca/alpaca.py | 4 +--- src/levanter/__init__.py | 1 + src/levanter/data/shard_cache.py | 5 +++-- src/levanter/main/eval_lm.py | 6 +++++- src/levanter/models/lm_model.py | 2 +- src/levanter/optim/sophia.py | 6 ++---- src/levanter/utils/hf_utils.py | 3 ++- 10 files changed, 19 insertions(+), 24 deletions(-) diff --git a/.github/workflows/run_pre_commit.yaml b/.github/workflows/run_pre_commit.yaml index eb7a4b214..e80facaeb 100644 --- a/.github/workflows/run_pre_commit.yaml +++ b/.github/workflows/run_pre_commit.yaml @@ -9,6 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] + jax-version: ["0.4.14"] steps: - uses: actions/checkout@v3 @@ -20,16 +21,7 @@ jobs: run: | python -m pip install --upgrade pip pip install flake8 pytest pre-commit - pip install --upgrade "jax[cpu]==0.4.11" "jaxlib==0.4.11" - # install haliax from source b/c it's changing in parallel with this repo - pip install git+https://github.com/stanford-crfm/haliax.git - pip install . -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 . --count --exit-zero --max-complexity=50 --max-line-length=127 --statistics + pip install . "jax[cpu]==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - name: "Run Pre-commit" run: | pre-commit run --all-files --show-diff-on-failure diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml index ab7503871..19512c3dd 100644 --- a/config/gpt2_small_pile.yaml +++ b/config/gpt2_small_pile.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2"] diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml index e02e4bd1f..a79ec8052 100644 --- a/config/gpt2_small_pile_mixture.yaml +++ b/config/gpt2_small_pile_mixture.yaml @@ -8,7 +8,7 @@ model: gradient_checkpointing: true scale_attn_by_inverse_layer_idx: true trainer: - wandb: + tracker: project: "levanter" tags: [ "pile", "gpt2"] diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index cfe07a1e4..de1fde555 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -230,9 +230,7 @@ def train(config: TrainArgs): def compute_loss(model: LmHeadModel, example: LmExample, key=None): return model.compute_loss(example, key=key).scalar() - trainer = Trainer(config.trainer, optimizer, compute_loss) - - with trainer.device_mesh: + with Trainer(config.trainer, optimizer, compute_loss) as trainer: # how we shard parameters across devices parameter_axis_mapping = trainer.parameter_axis_mapping diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index 548a113a0..3042154e2 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -8,4 +8,5 @@ import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization +from levanter.tracker import current_tracker from levanter.trainer import initialize diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index 569bbe711..f5faa9b36 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -16,7 +16,6 @@ import pyarrow as pa import pyarrow.parquet as pq import ray -import wandb from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from ray.actor import ActorHandle @@ -31,6 +30,8 @@ TimeRemainingColumn, ) +import levanter.tracker + from .. import logging from ..utils.ray_utils import ExceptionInfo, RefBox, current_actor_handle, ser_exc_info from ._preprocessor import BatchProcessor, BatchResult, as_record_batch, dict_from_record_batch @@ -739,7 +740,7 @@ def __call__(self, metrics: InProgressCacheMetrics): self.last_metrics = metrics self.last_time = time.time() - wandb.log(to_log, commit=self.commit) + levanter.tracker.log_metrics(to_log, step=None, commit=self.commit) class LoggerMetricsMonitor(MetricsMonitor): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index ab6d9d6b9..9b056b950 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -51,7 +51,11 @@ def main(config: EvalLmConfig): if config.eval_on_train: raw_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos) else: - raw_dataset = CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos) # type: ignore + validation_set = config.data.validation_set(Pos.size) + if validation_set is None: + raise ValueError("Can't eval on validation_set b/c there isn't one!") + + raw_dataset = CausalLmDataset(validation_set, Pos, KeyPos) # type: ignore eval_loader = ReplicatedBatchLoader(raw_dataset, config.trainer.device_mesh, Batch) compute_axis_mapping = config.trainer.compute_axis_mapping diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 0d0a7d70a..c492b2321 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -123,7 +123,7 @@ def compute_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = self(example.tokens, example.attn_mask, key=key) + logits = self(example.tokens, example.attn_mask, key=key).astype(jnp.float32) targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) return cross_entropy_loss( diff --git a/src/levanter/optim/sophia.py b/src/levanter/optim/sophia.py index 8895942a2..daa03285e 100644 --- a/src/levanter/optim/sophia.py +++ b/src/levanter/optim/sophia.py @@ -11,7 +11,7 @@ from jax.random import PRNGKey from jaxtyping import PRNGKeyArray -# import levanter.tracker +import levanter.tracker from levanter.optim.config import HessianOptConfig, OptimizerConfig from levanter.optim.util import hvp, tree_gaussian_like from levanter.utils.jax_utils import parameter_count, tree_filter_like @@ -348,9 +348,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs): updates = jax.tree_util.tree_map(lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates) stats["optim/unclipped_fraction"] = unclipped_count / parameter_count(updates) - # this doesn't work well on CPU, so skip if cpu - # if jax.lib.xla_bridge.get_backend().platform != "cpu": - # levanter.tracker.jit_log_metrics(stats, step=state.count) + levanter.tracker.jit_log_metrics(stats, step=state.count) if mu_dtype is not None: mu = jax.tree_util.tree_map(lambda t: t.astype(mu_dtype), mu) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 408a8c8da..ef77dcdfd 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -18,7 +18,8 @@ def num_cpus_used_by_tokenizer(tokenizer) -> int: else: # This is a bit hacky, but HF's fast tokenizers are parallelized under the hood. # we reserve a couple of cores just so Ray has somewhere to run the coordinator. - return min(max(1, logical_cpu_core_count() - 2), 32) + # Empirically it doesn't usually exceed 16-20, and it's useful to have some slack + return min(max(1, logical_cpu_core_count() - 2), 12) else: return 1 From 8f9a3dece8e8153f919c38b7761bd4f892fb4327 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 12 Feb 2024 09:10:32 -0800 Subject: [PATCH 4/5] Refactor TrainerState to make it a module, make save_checkpoint a nicer function (#462) * Trackers let us abstract out TB vs wandb * missed a few spots * remove old config * wip * missed some spots * more missed spots --- src/levanter/checkpoint.py | 120 ++++++---- src/levanter/data/shard_cache.py | 11 - src/levanter/main/eval_lm.py | 6 +- src/levanter/main/export_lm_to_hf.py | 5 +- src/levanter/main/lora_lm.py | 7 +- src/levanter/main/train_lm.py | 2 +- src/levanter/main/viz_logprobs.py | 5 +- src/levanter/trainer.py | 287 +++++++++++++++--------- tests/test_checkpoint.py | 85 +++---- tests/test_eval_lm.py | 5 +- tests/test_export_to_hf.py | 2 +- tests/test_tensorstore_serialization.py | 25 +++ tests/test_viz_lm.py | 2 +- 13 files changed, 330 insertions(+), 232 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 15f16a203..5c243946e 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -7,7 +7,7 @@ import urllib.parse from dataclasses import dataclass from datetime import timedelta -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Callable, List, Optional, Sequence, TypeVar, Union import equinox import fsspec @@ -28,8 +28,7 @@ PathLike = Union[str, pathlib.Path] -M = TypeVar("M") -S = TypeVar("S") +M = TypeVar("M", bound=PyTree) @dataclass(frozen=True) @@ -102,19 +101,16 @@ def __init__( def load_checkpoint( self, - model: M, - training_state: S, + state: M, path: Optional[PathLike] = None, *, discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, S, int]]: + ) -> Optional[M]: if path is None: path = self.base_path - return load_checkpoint( - model, training_state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh - ) + return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh) def load_model( self, @@ -124,16 +120,17 @@ def load_model( discover_latest: bool = True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[haliax.partitioning.Mesh] = None, - ) -> Optional[Tuple[M, int]]: - if path is None: - path = self.base_path - ckpt = load_checkpoint( - model, None, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh + ) -> Optional[M]: + """ + Convenience method/holdover from previous API for loading checkpoints. + Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint. + """ + ret_dict = self.load_checkpoint( + {"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh ) - if ckpt is None: + if ret_dict is None: return None - model, _, step = ckpt - return model, step + return ret_dict["model"] def on_step(self, info, force: bool = False): step = info.step @@ -219,10 +216,9 @@ def _rm_checkpoint(self, checkpoint): def save_checkpoint(self, info, destination: str): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") - model = equinox.filter(info.model, self.keep_params) + state = equinox.filter(info.state, info.state.is_trainable) save_checkpoint( - model=model, - training_state=(info.opt_state, info.next_key), + state, step=info.step, checkpoint_path=path, ) @@ -231,7 +227,7 @@ def save_checkpoint(self, info, destination: str): logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}") -def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike): +def save_checkpoint(tree: M, step: int, checkpoint_path: PathLike): """ Save a checkpoint to a given path using TensorStore. If exist_ok is True, the checkpoint will be saved even if a checkpoint already exists at the given path. @@ -249,10 +245,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike) fs, plain_path = _get_fs_and_plain_path(checkpoint_path) fs.makedirs(plain_path, exist_ok=True) - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model) - if training_state is not None: - tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "training_state"), training_state) - + tree_serialize_leaves_tensorstore(checkpoint_path, tree) save_metadata(checkpoint_path, fs, step) logger.info(f"Saved checkpoint for step {step}") @@ -271,22 +264,30 @@ def save_metadata(checkpoint_path, fs, step): def load_checkpoint( - model: M, - training_state: S, + tree: M, checkpoint_path: PathLike, *, + subpath: Optional[str] = None, discover_latest=True, axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None, mesh: Optional[jax.sharding.Mesh] = None, -) -> Optional[Tuple[M, S, int]]: +) -> M: """ - Load a checkpoint from a given path. - - Returns the loaded model state, training state, and step. If discover_latest is True, - the latest checkpoint in the given path will be loaded. Otherwise, the checkpoint at - the given path will be loaded. If no checkpoint is found, returns None + Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint + in a subdirectory of the given path will be loaded. If subpath is not None, then the checkpoint + loads only that subpath of the checkpoint. This is useful for loading, e.g., just the model and not + the entire training state. + + Args: + tree: an exemplar of the tree to load. Can be a PyTree[ShapeDTypeStruct] instead of a PyTree[Any] + checkpoint_path: the path to load the checkpoint from + subpath: the subpath to load from the checkpoint + discover_latest: whether to discover the latest checkpoint in the given path + axis_mapping: the axis mapping to use for loading the checkpoint + mesh: the mesh to use for loading the checkpoint + Returns: + the loaded checkpoint, with the same structure as the exemplar tree - If training_state is None, no training state will be loaded. """ fs: AbstractFileSystem fs, _ = _get_fs_and_plain_path(checkpoint_path) @@ -297,28 +298,52 @@ def load_checkpoint( checkpoint_path = discover_latest_checkpoint(checkpoint_path) # type: ignore if checkpoint_path is None or not fs.exists(checkpoint_path): - return None + raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}") logger.info(f"Loading checkpoint from {checkpoint_path}") metadata = load_metadata(checkpoint_path, fs) - model = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh - ) + if subpath: + checkpoint_path = os.path.join(checkpoint_path, subpath) - if training_state is None: - training_state = None - else: - training_state = tree_deserialize_leaves_tensorstore( - os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh - ) + try: + tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh) + return tree + except: # noqa + from levanter.trainer import TrainerState - return model, training_state, metadata["step"] + if not isinstance(tree, TrainerState): + raise + else: + logger.warning("Attempting to load old-style checkpoint") + model, training_state = tree.model, (tree.opt_state, tree.training_key) + + model = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh + ) + + if training_state is None: + opt_state = None + key = None + else: + training_state = tree_deserialize_leaves_tensorstore( + os.path.join(checkpoint_path, "training_state"), + training_state, + axis_mapping=axis_mapping, + mesh=mesh, + ) + opt_state, key = training_state + + # TODO: pretty sure this is right, but should verify + step = metadata["step"] + new_state = dataclasses.replace( + tree, _step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore + ) + return new_state def load_metadata(checkpoint_path, fs=None): if fs is None: - fs: AbstractFileSystem fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path)) with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in: metadata = json.load(metadata_in) @@ -381,13 +406,12 @@ class CheckpointerConfig: def expanded_path(self, run_id): return os.path.expanduser(os.path.join(self.base_path, run_id)) - def create(self, run_id, keep_params: PyTree[FilterSpec] = True) -> Checkpointer: + def create(self, run_id) -> Checkpointer: keeps = [CheckpointInterval(**k) for k in self.keep] return Checkpointer( base_path=self.expanded_path(run_id), save_interval=self.save_interval, step_policies=keeps, - keep_params=keep_params, ) def __post_init__(self): diff --git a/src/levanter/data/shard_cache.py b/src/levanter/data/shard_cache.py index f5faa9b36..c1d24c1a0 100644 --- a/src/levanter/data/shard_cache.py +++ b/src/levanter/data/shard_cache.py @@ -1231,17 +1231,6 @@ def priority_fn(shard_idx, chunk_idx): ray.get(reader_actor.add_work_group.remote(work_item)) - # reader = _alternating_shard_reader.remote( - # name, - # self_ref, - # writer, - # source, - # shard_group, - # priority_fn, - # processor_actor, - # processor.batch_size, - # rows_per_chunk, - # ) self._shard_readers.append(reader_actor) def new_chunk(self, shard_name: str, *chunks: ChunkMetadata): diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 9b056b950..c7976ad41 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -85,14 +85,10 @@ def compute_loss(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) - - assert ckpt is not None - model, _, _ = ckpt + model = load_checkpoint(model, config.checkpoint_path, subpath="model") model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) - # TODO: switch to throwing instead of returning None loss = callbacks.eval_loss_loop(compute_loss, model, eval_loader, max_batches=total) del model diff --git a/src/levanter/main/export_lm_to_hf.py b/src/levanter/main/export_lm_to_hf.py index 50a8e4b92..7fd4d073d 100644 --- a/src/levanter/main/export_lm_to_hf.py +++ b/src/levanter/main/export_lm_to_hf.py @@ -51,10 +51,9 @@ def main(config: ConvertLmConfig): model: LmHeadModel = eqx.filter_eval_shape(config.model.build, Vocab, key=key) trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(trainable, None, config.checkpoint_path) + trainable = load_checkpoint(trainable, config.checkpoint_path, subpath="model") - assert ckpt is not None - trainable, _, _ = ckpt + assert trainable is not None model = eqx.combine(trainable, non_trainable) if config.override_vocab_size: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 5120c9e22..6b845b516 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -93,7 +93,10 @@ def compute_loss(model, example: LmExample, key=None): state = trainer.initial_state(training_key, model=model) all_param_count = parameter_count(state.model) - just_lora_params = parameter_count(trainer.trainable_params_only(state.model)) + # TODO: remove this once we put this in trainer itself + just_lora_params = parameter_count( + levanter.trainer._partition_trainable_params(state.model, lora_param_filter) + ) levanter.tracker.log_summary( { @@ -140,7 +143,7 @@ def compute_loss(model, example: LmExample, key=None): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(iter_data) ## OK, actually run training! diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2dbd705d5..68f63b987 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -182,7 +182,7 @@ def compute_log_probs(model, example: LmExample): # TODO: implement iter_data.seek(resume_step +1) import tqdm - for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"): + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): next(train_loader) ## OK, actually run training! diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index b992cd3f5..ef16a7238 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -81,10 +81,9 @@ def compute_log_probs(model: LmHeadModel, example: LmExample): with use_cpu_device(): model = eqx.filter_eval_shape(config.model.build, Vocab, key=key) # TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model - ckpt = load_checkpoint(model, None, config.checkpoint_path) + model = load_checkpoint(model, config.checkpoint_path, subpath="model") - assert ckpt is not None - model, _, _ = ckpt + assert model is not None model = hax.shard_with_axis_mapping(model, parameter_axis_mapping) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 7d8661c91..41f5d04ab 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -1,5 +1,6 @@ import atexit import copy +import dataclasses import functools import logging as pylogging import os @@ -30,7 +31,6 @@ import jmp import numpy as np from draccus import field -from jax import ShapeDtypeStruct from jax.experimental import multihost_utils from jax.sharding import Mesh from jaxtyping import PRNGKeyArray, PyTree @@ -39,13 +39,13 @@ import haliax as hax from haliax import Axis from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit -from haliax.types import Scalar +from haliax.types import IntScalar, Scalar import levanter.logging import levanter.tracker import levanter.tracker.wandb from levanter import tracker -from levanter.checkpoint import CheckpointerConfig +from levanter.checkpoint import CheckpointerConfig, load_checkpoint from levanter.config import JsonAtom from levanter.data import Dataset, ReplicatedBatchLoader, ShardableDataset, ShardedBatchLoader from levanter.distributed import DistributedConfig, RayConfig @@ -54,7 +54,7 @@ from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils -from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.jax_utils import as_arrayish, is_inexact_arrayish from levanter.utils.tree_utils import inference_mode @@ -62,7 +62,6 @@ X = TypeVar("X") # Input M = TypeVar("M", bound=PyTree) -S = TypeVar("S", bound=PyTree) DEFAULT_JAX_CONFIG = { "jax_threefry_partitionable": True, @@ -74,14 +73,36 @@ # A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. -@dataclass -class TrainerState(Generic[M]): - step: int +class TrainerState(eqx.Module, Generic[M]): + """ + This is the state of the trainer. It contains the model, optimizer state, and random key. + It is an equinox Module becaues it is a PyTree that gets passed to the core `train_step` method + of the Trainer. This unfortunately means that `step` is an Array and not an int, hence the IntScalar. + + It's designed to be extended by subclasses. + """ + + _step: IntScalar = eqx.field(converter=lambda x: as_arrayish(x)) model: M opt_state: OptState training_key: PRNGKeyArray + is_trainable: PyTree[FilterSpec] # = eqx.field(static=True) + + @cached_property + def step(self) -> int: + return int(self._step) + + @property + def trainable_model(self) -> M: + return eqx.filter(self.model, self.is_trainable) +S = TypeVar("S", bound=TrainerState) + + +# A note on the semantics of "step" vs "next_step": +# The "step" of a TrainerState is the state after `step` steps have been taken. +# A "StepInfo"'s step is the step that was just completed. If you want the next step, use `next_step`. @dataclass class StepInfo(Generic[M]): state: TrainerState[M] @@ -90,7 +111,6 @@ class StepInfo(Generic[M]): model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) - next_key = property(lambda self: self.state.training_key) step = property(lambda self: self.state.step - 1) """ @@ -172,12 +192,12 @@ def loss_fn(self): Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute """ - @named_jit(in_axis_resources=self.parameter_axis_mapping, axis_resources=self.compute_axis_mapping) + @named_jit(axis_resources=self.compute_axis_mapping) @functools.wraps(self._raw_loss_function) def fn(model, *batch, **batch_kwargs): with hax.axis_mapping(self.compute_axis_mapping): model = self.mp.cast_to_compute(model) - return self._raw_loss_function(model, *batch, **batch_kwargs) + return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs)) return fn @@ -273,79 +293,81 @@ def initial_state( raise ValueError("one of model and model_init must be specified") if model is not None: - # we can't just use `lambda: model` because JAX jit can't see captures, but it can see partials - # We can't use plain partials because they aren't pytrees + # we can't just use `lambda: model` because JAX jit can't see captures, but it can see jax partials model_init = jax.tree_util.Partial(lambda m: m, model) + del model assert model_init is not None - model_shape, opt_state_shape = eqx.filter_eval_shape(self._init_model_and_opt_state, model_init) + # first try to load a full trainer state checkpoint + checkpoint_path = self.config.load_checkpoint_path + if checkpoint_path is None: + checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + + do_load_checkpoint = self.config.load_checkpoint + axis_mapping = self.parameter_axis_mapping + mesh = self.device_mesh + initial_model_path = self.config.initialize_from - # we only checkpoint the trainable parameters, so we need to filter out the non-trainable ones - trainable_model_shape = self.trainable_params_only(model_shape) + # we don't save the full trainer state, so we need to filter out the non-trainable parameters - ckpt = self.maybe_load_checkpoint( - trainable_model_shape, - (opt_state_shape, training_key), - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, + def init_state_and_model(model_init, training_key, is_trainable): + model = model_init() + state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return state + + trainer_state_shape = eqx.filter_eval_shape( + init_state_and_model, model_init, training_key, self.is_trainable_param ) + saveable_state_shape = _make_saveable_trainer_state(trainer_state_shape, self.is_trainable_param) - if ckpt is not None: - trainable_model, (opt_state, training_key), completed_step = ckpt - if model is not None: - model = eqx.combine(trainable_model, model) - else: - model = eqx.combine(trainable_model, model_shape) - - if any(isinstance(leaf, ShapeDtypeStruct) for leaf in jax.tree_leaves(model)): - # if we're resuming, we need to re-initialize the non-trainable parameters to their original values - non_trainable = named_jit(self._init_non_trainable_params, self.parameter_axis_mapping)(model_init) - model = eqx.combine(trainable_model, non_trainable) - - step = completed_step + 1 - elif self.config.initialize_from is not None: - # initialize from a levanter checkpoint - logger.info(f"Initializing model from checkpoint {self.config.initialize_from}") - match levanter.checkpoint.load_checkpoint( - model_shape, - None, - self.config.initialize_from, - axis_mapping=self.parameter_axis_mapping, - mesh=self.device_mesh, - ): - # new_model is probably only the trainable parameters, so we init the rest - case base_model, _, loaded_step: - logger.info(f"Initialized from step {loaded_step} of {self.config.initialize_from}") - old_model_init = model_init - - model_init = jax.tree_util.Partial(lambda m: eqx.combine(m, old_model_init()), base_model) - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)( - model_init - ) - - step = 0 - case None: - raise ValueError(f"Could not load model from checkpoint {self.config.initialize_from}") - else: - model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init) - step = 0 + if do_load_checkpoint is not False: + try: + state = load_checkpoint(saveable_state_shape, checkpoint_path, axis_mapping=axis_mapping, mesh=mesh) + except FileNotFoundError: + if do_load_checkpoint: + raise + else: + state = None + + # if that fails, try to load just a model from a checkpoint for initialization + if state is None and initial_model_path is not None: + logger.info(f"Initializing from {initial_model_path}") + # todo: we are potentially holding two models in memory at once here, if we pass in a model + # instead of a model_init and we use initialize_from. We could avoid this by deleting + # any to-be-loaded parameters from the model before loading, but that's a bit more complicated + loaded_model = load_checkpoint( + saveable_state_shape.model, + initial_model_path, + axis_mapping=axis_mapping, + mesh=mesh, + subpath="model", + ) + + # we don't necessarily load the full model, so we need to combine it with the model init + model_init = jax.tree_util.Partial(lambda m, f: eqx.combine(m, f()), loaded_model, model_init) - return TrainerState(step, model, opt_state, training_key) + # now we initialize a fresh trainer state, possibly just to finish any missing fields + @named_jit(axis_resources=axis_mapping, donate_args=(True, True, True, False)) + def init_state(partial_state, model_init, training_key, is_trainable): + model = model_init() + fresh_state = self._initialize_state_from_scratch(model, training_key, is_trainable) + return eqx.combine(partial_state, fresh_state) + + state = init_state(state, model_init, training_key, self.is_trainable_param) + + return state def train_step(self, state: TrainerState[M], *batch: X, **batch_kwargs) -> StepInfo[M]: """ Performs a single training step. """ with capture_time() as step_time: - key, new_key = jax.random.split(state.training_key) - loss, new_model, new_optstate = self._train_step_fn( - state.model, state.opt_state, *batch, **batch_kwargs, key=key - ) + loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) loss = loss.item() # type: ignore - return StepInfo(TrainerState(state.step + 1, new_model, new_optstate, new_key), loss, step_time()) + return StepInfo(new_state, loss, step_time()) def training_steps( self, state: TrainerState[M], train_loader, run_hooks: bool = True @@ -355,7 +377,7 @@ def training_steps( """ iter_data = iter(train_loader) - while state.step < self.config.num_train_steps: + while state.step < self.num_train_steps: with capture_time() as loading_time: example = next(iter_data) @@ -391,7 +413,7 @@ def _add_default_hooks(self): self.add_hook(callbacks.pbar_logger(total=self.config.num_train_steps), every=1) self.add_hook(callbacks.log_step_info, every=1) # engine.add_hook(callbacks.log_memory_usage(), every=1) - checkpointer = self.config.checkpointer.create(self.run_id, self.is_trainable_param) + checkpointer = self.config.checkpointer.create(self.run_id) self.add_hook(checkpointer.on_step, every=1) # checkpointer manages its own frequency def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @@ -440,34 +462,19 @@ def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> Shar return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) @cached_property - def _train_step_fn(self): - @named_jit( - axis_resources=self.parameter_axis_mapping, - out_axis_resources=self.parameter_axis_mapping, - donate_args=(True, True), - ) - def train_step(model, opt_state, *batch, **batch_kwargs): - model = inference_mode(model, False) + def _jit_train_step_fn(self): + return named_jit(self._train_step, axis_resources=self.parameter_axis_mapping, donate_args=(True,)) - # we do this so that we only take the gradients of the trainable parameters - trainable_model, rest_model = self.partition_trainable_params(model) + def _train_step(self, state: TrainerState, *batch, **batch_kwargs) -> tuple[Scalar, TrainerState]: + key, new_key = jax.random.split(state.training_key) + model = inference_mode(state.model, False) - def split_loss_fn(trainable_model, *batch, **batch_kwargs): - model = eqx.combine(trainable_model, rest_model) - return self.loss_fn(model, *batch, **batch_kwargs) + loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key) - loss, grads = self._compute_gradients_microbatched(split_loss_fn, trainable_model, batch, **batch_kwargs) + new_state = self._take_train_step(state, model, grads, *batch, **batch_kwargs, key=key) + new_state = dataclasses.replace(new_state, training_key=new_key) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model) - - partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) - - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) - model = eqx.apply_updates(model, updates) - - return loss, model, opt_state - - return train_step + return loss, new_state def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) @@ -480,22 +487,59 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, batch, **batch_kwar ) return grad_fn(model, *batch, **batch_kwargs) - def _init_model_and_opt_state(self, model_init): - model = model_init() - # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - trainable = self.mp.cast_to_param(trainable) - non_trainable = self.mp.cast_to_compute(non_trainable) - model = eqx.combine(trainable, non_trainable) - opt_state = self.optimizer.init(trainable) - return model, opt_state - - def _init_non_trainable_params(self, model_init): - model = model_init() + def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: + """ + Takes a training step. This is a separate method so that it can be overridden or used in a subclass. + """ + # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us + with hax.axis_mapping(self.parameter_axis_mapping): + train_grads = _partition_trainable_params(grads, state.is_trainable)[0] + trainable_model = _partition_trainable_params(model, state.is_trainable)[0] + updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) + + partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) + + updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) + model = eqx.apply_updates(model, updates) + + return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state) + + def _initialize_state_from_scratch(self, model, training_key, is_trainable): # only force trainable params to param precision. Other params are cast to compute precision - trainable, non_trainable = self.partition_trainable_params(model) - non_trainable = self.mp.cast_to_compute(non_trainable) - return non_trainable + model = cast_params_by_trainability(model, self.mp, is_trainable) + opt_state = init_optimizer_for_trainables(self.optimizer, model, is_trainable) + + return TrainerState(0, model, opt_state, training_key, is_trainable) + + +def init_optimizer_for_trainables(optimizer, model, is_trainable): + trainable, _ = _partition_trainable_params(model, is_trainable) + opt_state = optimizer.init(trainable) + return opt_state + + +def _make_saveable_trainer_state(trainer_state: S, is_trainable) -> S: + """ + Returns the shape of the trainer state that we save to a checkpoint. This is used to load a checkpoint. + You can override if you really need custom checkpointing logic. By default everything in the trainer state + is saved (except for non-trainable model parameters) + """ + saveable_model = eqx.filter(trainer_state.model, is_trainable) + saveable_state = dataclasses.replace(trainer_state, model=saveable_model) + return saveable_state + + +def cast_params_by_trainability(model, mp, is_trainable): + """ + Casts the parameters of a model to the appropriate precision based on the is_trainable filter spec. + Trainable parameters are cast to param precision, non-trainable parameters are cast to compute precision. + """ + + trainable, non_trainable = _partition_trainable_params(model, is_trainable) + trainable = mp.cast_to_param(trainable) + non_trainable = mp.cast_to_compute(non_trainable) + model = eqx.combine(trainable, non_trainable) + return model def trainable_params_only(self, model: M) -> M: """ @@ -784,3 +828,32 @@ def initialize(config: TrainerConfig | AllConfig): def _params_only(t): return eqx.filter(t, is_inexact_arrayish) + + +def _partition_trainable_params(model, filter): + """ + Partitions the model into trainable and non-trainable parameters. This is used internally + for the gradient calculation and checkpointing, but you can also use it to filter out params for logging + or something. + + Returns: + trainable, non-trainable + """ + + def trainable_and_diffable(pred): + if callable(pred): + return lambda x: pred(x) and is_inexact_arrayish(x) + elif pred is True: + return is_inexact_arrayish + else: + return pred + + combined_mask = jax.tree_util.tree_map(trainable_and_diffable, filter) + return eqx.partition(model, combined_mask) + + +def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: + if isinstance(x, hax.NamedArray): + return x.scalar() + else: + return x diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index c22525fd6..db54b2569 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import pathlib import tempfile @@ -26,10 +27,11 @@ def _dummy_step_info(step): return StepInfo( state=TrainerState( # + 1 b/c step here is next step - step=step + 1, + _step=step + 1, model=None, opt_state=(), training_key=(), + is_trainable=True, ), loss=0.0, step_duration=0.0, @@ -139,42 +141,41 @@ def advance_time(delta_seconds): assert _get_checkpoint_steps(tmpdir) == [2, 4, 6, 8, 10, 15, 20, 30, 40, 49] # 49 is last temporary checkpoint +def _make_state(step, key): + model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) + optim = optax.adam(1e-4) + opt_state = optim.init(arrays_only(model)) + + return TrainerState(step, model, opt_state, key, True) + + def test_checkpoint_simple(): key0 = jax.random.PRNGKey(0) key1 = jax.random.PRNGKey(1) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - optim = optax.adam(1e-4) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) - rep_model, rep_state, rep_key = make_state(key1) + initial_state = _make_state(10, key0) + rep_state = _make_state(2, key1) - assert_trees_not_close(initial_model, rep_model) + assert_trees_not_close(initial_state.model, rep_state.model) with tempfile.TemporaryDirectory() as tmpdir: save_checkpoint( - initial_model, - (initial_opt_state, initial_key), - step=10, + initial_state, + step=initial_state.step, checkpoint_path=tmpdir, ) - restored_model, (restored_optstate, rkey), step = load_checkpoint( - rep_model, - (rep_state, rep_key), + restored_state = load_checkpoint( + rep_state, checkpoint_path=tmpdir, discover_latest=False, ) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(initial_model)), + jax.tree_util.tree_leaves(arrays_only(restored_state.model)), + jax.tree_util.tree_leaves(arrays_only(initial_state.model)), ) - assert all(np.isclose(rkey, initial_key)) - assert step == 10 + assert all(np.isclose(restored_state.training_key, initial_state.training_key)) + assert restored_state.step == initial_state.step def test_checkpoint_steps(): @@ -183,13 +184,7 @@ def test_checkpoint_steps(): optim = optax.adam(1e-4) - def make_state(key): - model = MLP(in_size=2, out_size=1, width_size=2, depth=3, key=key) - opt_state = optim.init(arrays_only(model)) - - return model, opt_state, key - - initial_model, initial_opt_state, initial_key = make_state(key0) + initial_state = _make_state(10, key0) data = jax.random.uniform(key0, (2, 2)) @eqx.filter_grad @@ -197,41 +192,33 @@ def loss_fn(model, data): m = jax.vmap(model) return jnp.mean(jnp.square(m(data))) - model, state = initial_model, initial_opt_state + state = initial_state for i in range(3): - grad = loss_fn(model, data) - updates, state = optim.update(grad, state) - model = eqx.apply_updates(model, updates) + grad = loss_fn(state.model, data) + updates, new_state = optim.update(grad, state.opt_state) + model = eqx.apply_updates(state.model, updates) + state = dataclasses.replace(state, _step=state.step + 1, model=model, opt_state=new_state) - assert_trees_not_close(model, initial_model) - assert_trees_not_close(state, initial_opt_state) + assert_trees_not_close(state, initial_state) - rep_model, rep_state, rep_key = make_state(key1) - assert_trees_not_close(model, rep_model) + rep_state = _make_state(42, key1) assert_trees_not_close(state, rep_state) with tempfile.TemporaryDirectory() as tmpdir: - save_checkpoint(model, state, step=3, checkpoint_path=tmpdir) - restored_model, restored_optstate, step = load_checkpoint( - rep_model, rep_state, checkpoint_path=tmpdir, discover_latest=False - ) + save_checkpoint(state, step=3, checkpoint_path=tmpdir) + restored_state = load_checkpoint(rep_state, checkpoint_path=tmpdir, discover_latest=False) assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_model)), - jax.tree_util.tree_leaves(arrays_only(model)), - ) - assert_trees_all_close( - jax.tree_util.tree_leaves(arrays_only(restored_optstate)), + jax.tree_util.tree_leaves(arrays_only(restored_state)), jax.tree_util.tree_leaves(arrays_only(state)), ) - assert step == 3 def test_checkpoint_discovery(): with tempfile.TemporaryDirectory() as tempdir: - save_checkpoint(model=1, training_state=2, step=10, checkpoint_path=f"{tempdir}/step-10") - save_checkpoint(model=3, training_state=4, step=20, checkpoint_path=f"{tempdir}/step-20") - save_checkpoint(model=5, training_state=6, step=30, checkpoint_path=f"{tempdir}/step-30") + save_checkpoint(dict(model=1, training_state=2), step=10, checkpoint_path=f"{tempdir}/step-10") + save_checkpoint(dict(model=3, training_state=4), step=20, checkpoint_path=f"{tempdir}/step-20") + save_checkpoint(dict(model=5, training_state=6), step=30, checkpoint_path=f"{tempdir}/step-30") latest = discover_latest_checkpoint(tempdir) assert latest == f"{tempdir}/step-30" diff --git a/tests/test_eval_lm.py b/tests/test_eval_lm.py index 178069f26..a6bf3c8d9 100644 --- a/tests/test_eval_lm.py +++ b/tests/test_eval_lm.py @@ -13,6 +13,7 @@ from levanter.distributed import RayConfig from levanter.models.gpt2 import Gpt2LMHeadModel from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerState from levanter.utils.py_utils import logical_cpu_core_count @@ -43,7 +44,9 @@ def test_eval_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + state = TrainerState(0, model, model, jax.random.PRNGKey(0), True) + + save_checkpoint(state, 0, f"{f}/ckpt") config = eval_lm.EvalLmConfig( data=data_config, diff --git a/tests/test_export_to_hf.py b/tests/test_export_to_hf.py index 3ce092789..ed6a0d4c0 100644 --- a/tests/test_export_to_hf.py +++ b/tests/test_export_to_hf.py @@ -34,7 +34,7 @@ def test_export_lm_to_hf(): # in our trainer, we only export the trainable params trainable, non_trainable = eqx.partition(model, is_inexact_arrayish) - save_checkpoint(trainable, None, 0, f"{tmpdir}/ckpt") + save_checkpoint({"model": trainable}, 0, f"{tmpdir}/ckpt") try: config = export_lm_to_hf.ConvertLmConfig( diff --git a/tests/test_tensorstore_serialization.py b/tests/test_tensorstore_serialization.py index 398bad1f0..caf8a365d 100644 --- a/tests/test_tensorstore_serialization.py +++ b/tests/test_tensorstore_serialization.py @@ -1,10 +1,12 @@ from tempfile import TemporaryDirectory +from typing import Any import equinox as eqx import jax import jax.numpy as jnp import numpy as np import optax +import pytest from chex import assert_trees_all_close import haliax as hax @@ -127,3 +129,26 @@ def make_state(key): jax.tree_util.tree_leaves(arrays_only(restored_model)), jax.tree_util.tree_leaves(arrays_only(initial_model)), ) + + +def test_tensorstore_ok_with_nones(): + A = hax.Axis("A", 10) + + class MyModule(eqx.Module): + a: Any + b: Any + + m = MyModule(a=None, b=hax.zeros(A)) + m2 = MyModule(a=None, b=hax.ones(A)) + + with TemporaryDirectory() as tmpdir: + tree_serialize_leaves_tensorstore(tmpdir, m) + m3 = tree_deserialize_leaves_tensorstore(tmpdir, m2) + assert m3.a is None + assert hax.all(m3.b == hax.zeros(A)) + + m3 = MyModule(a=hax.zeros(A), b=hax.ones(A)) + with TemporaryDirectory() as tmpdir: + tree_serialize_leaves_tensorstore(tmpdir, m2) + with pytest.raises(ValueError): + tree_deserialize_leaves_tensorstore(tmpdir, m3) diff --git a/tests/test_viz_lm.py b/tests/test_viz_lm.py index 29d8f943c..71d117055 100644 --- a/tests/test_viz_lm.py +++ b/tests/test_viz_lm.py @@ -47,7 +47,7 @@ def test_viz_lm(): Vocab = haliax.Axis("vocab", len(tok)) model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0)) - save_checkpoint(model, None, 0, f"{f}/ckpt") + save_checkpoint({"model": model}, 0, f"{f}/ckpt") config = viz_logprobs.VizGpt2Config( data=data_config, From 0b8b6e9dd5fb867a8eea6325720b712357ba5559 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 12 Feb 2024 10:45:39 -0800 Subject: [PATCH 5/5] botched merge (#463) --- config/gpt2_1536_sophiah.yaml | 32 ++++++++++++++++++++++++++++++++ src/levanter/trainer.py | 7 ++++--- 2 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 config/gpt2_1536_sophiah.yaml diff --git a/config/gpt2_1536_sophiah.yaml b/config/gpt2_1536_sophiah.yaml new file mode 100644 index 000000000..0d1008106 --- /dev/null +++ b/config/gpt2_1536_sophiah.yaml @@ -0,0 +1,32 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext/" + tokenizer: "gpt2" +model: + type: gpt2 + hidden_dim: 1536 + num_heads: 24 + num_layers: 48 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "openwebtext", "gpt2"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: 2 + per_device_eval_parallelism: 8 +optimizer: + type: sophia-h + learning_rate: 2E-4 + weight_decay: 0.2 + min_lr_ratio: 0.1 + gamma: 0.01 + # sophia needs a minimum amount of warmup or it doesn't do well + warmup: 2000 diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 41f5d04ab..f85336df4 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -493,13 +493,14 @@ def _take_train_step(self, state: S, model, grads, *batch, **batch_kwargs) -> S: """ # only train on the trainable parameters. We're leaning on JAX to do dead code elimination for us with hax.axis_mapping(self.parameter_axis_mapping): + opt_state = state.opt_state train_grads = _partition_trainable_params(grads, state.is_trainable)[0] trainable_model = _partition_trainable_params(model, state.is_trainable)[0] - updates, opt_state = self.optimizer.update(train_grads, state.opt_state, params=trainable_model) - partial_fn = lambda model: self.loss_fn(model, *batch, **batch_kwargs) - updates, opt_state = self.optimizer.update(grads, opt_state, params=trainable_model, obj_fn=partial_fn) + updates, opt_state = self.optimizer.update( + train_grads, opt_state, params=trainable_model, obj_fn=partial_fn + ) model = eqx.apply_updates(model, updates) return dataclasses.replace(state, _step=state._step + 1, model=model, opt_state=opt_state)