diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py new file mode 100644 index 0000000..0bf02c2 --- /dev/null +++ b/end_to_end/tpu/eval_assert.py @@ -0,0 +1,59 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +# pylint: skip-file +"""Reads and asserts over target values""" +from absl import app +from typing import Sequence +import json + +def get_last_n_data(metrics_file, target, n=10): + last_n_data = [] + with open(metrics_file, 'r', encoding='utf8') as file: + lines = file.readlines() + for line in lines[::-1]: + metrics = json.loads(line) + if target in metrics: + last_n_data.append(metrics[target]) + if len(last_n_data) >= n: + break + return last_n_data + + +def test_final_loss(metrics_file, target_loss): + target_loss = float(target_loss) + with open(metrics_file, 'r', encoding='utf8') as metrics: + use_last_n_data = 10 + last_n_data = get_last_n_data(metrics_file, 'learning/loss', use_last_n_data) + avg_last_n_data = sum(last_n_data) / len(last_n_data) + print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}") + print(f"Target loss is {target_loss}") + assert avg_last_n_data < target_loss + print('Final loss test passed.') + + +def main(argv: Sequence[str]) -> None: + + _, test_scenario, *test_vars = argv + + if test_scenario == 'final_loss': + test_final_loss(*test_vars) + else: + raise ValueError(f"Unrecognized test_scenario {test_scenario}") + + +if __name__ == "__main__": + app.run(main) \ 
#!/bin/bash
set -ex

echo "Running test_sdxl_training_loss.sh"

# Set environment variables from KEY=VALUE command-line arguments.
for ARGUMENT in "$@"; do
  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
  export "$KEY"="$VALUE"
done

# Fix: run_name was passed twice (a hard-coded sdxl-fsdp-v5p-64-ddp and
# $RUN_NAME); keep only the caller-supplied RUN_NAME.
TRAIN_CMD="python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
  pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
  revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True \
  dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 \
  jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=$STEPS attention=flash enable_profiler=True \
  run_name=$RUN_NAME \
  output_dir=$OUTPUT_DIR "

# Train
export LIBTPU_INIT_ARGS=""
$TRAIN_CMD

# Assert training loss is smaller than input LOSS_THRESHOLD
python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt "$LOSS_THRESHOLD"
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS",""), writer) max_utils.add_config_to_summary_writer(self.config, writer) if jax.process_index() == 0: