diff --git a/README.md b/README.md index e4f6a02..8f21ff5 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/dif Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests can be found in `tests` and `src/maxdiffusion/tests`. -To run unit tests, simply run: +To run unit tests simply run: ``` python -m pytest ``` diff --git a/docs/getting_started/first_run.md b/docs/getting_started/first_run.md index cabc720..900c3cf 100644 --- a/docs/getting_started/first_run.md +++ b/docs/getting_started/first_run.md @@ -12,9 +12,7 @@ multiple hosts. 1. Clone MaxDiffusion in your TPU VM. 1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running: ```bash -pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -pip3 install -r requirements.txt -pip3 install . +bash setup.sh MODE=stable ``` ## Getting Starting: Multihost development diff --git a/setup.sh b/setup.sh index a8f32fd..b7d345e 100644 --- a/setup.sh +++ b/setup.sh @@ -108,3 +108,6 @@ fi # Install dependencies from requirements.txt pip3 install -U -r requirements.txt + +# Install maxdiffusion +pip3 install -U . \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 6dd6a4e..1e0eb76 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -110,6 +110,7 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: '' +dataset_type: 'tf' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index b14c5bb..4ac1739 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -121,7 +121,7 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics): """Writes metrics to gcs""" metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) running_metrics.append(metrics_dict_step) - if (step + 1) % config.log_period == 0 or step == config.steps - 1: + if (step + 1) % config.log_period == 0 or step == config.max_train_steps - 1: start_step = (step // config.log_period) * config.log_period metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: