From a58a75fa8c093d2ab33b2665e30b6a406e43ceee Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 4 Dec 2024 17:44:40 +0000 Subject: [PATCH] fixes failing smoke tests --- README.md | 2 +- src/maxdiffusion/configs/base_xl_lightning.yml | 1 + src/maxdiffusion/max_utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e4f6a02..54fe093 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/dif Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests can be found in `tests` and `src/maxdiffusion/tests`. -To run unit tests, simply run: +To run unit tests, you'll need to [install gcsfuse](https://cloud.google.com/storage/docs/cloud-storage-fuse/install) then simply run: ``` python -m pytest ``` diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 6dd6a4e..1e0eb76 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -110,6 +110,7 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: '' +dataset_type: 'tf' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index b14c5bb..4ac1739 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -121,7 +121,7 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics): """Writes metrics to gcs""" metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) running_metrics.append(metrics_dict_step) - if (step + 1) % config.log_period == 0 or step == config.steps - 1: + if (step + 1) % config.log_period == 0 or step == config.max_train_steps - 1: start_step = (step // config.log_period) * config.log_period metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: