Figure from the original paper.
This is an unofficial implementation of Diffusion-LM. For the official implementation, please refer here.
If you are using Poetry, run the following command to install the required dependencies:

```sh
poetry install
```

Next, activate your virtual environment:

```sh
poetry shell
```

You can find more details about the required packages in `pyproject.toml`.
After that, initialize an 🤗 Accelerate environment with:

```sh
accelerate config
```

Alternatively, you can set up a default Accelerate configuration without answering questions about your environment using:

```sh
accelerate config default
```
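For reference, `accelerate config default` writes a minimal configuration file (typically under `~/.cache/huggingface/accelerate/`). On a single-machine setup it looks roughly like the fragment below; the exact keys vary across Accelerate versions, so treat this as illustrative only:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
mixed_precision: 'no'
num_processes: 1
use_cpu: false
```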
This repository allows you to train Diffusion-LM on the E2E dataset using the following command:

```sh
accelerate launch scripts/train.py --expn foo
```

Available arguments:

- `-e`, `--expn`: The experiment name, used as the basename of the output directory. If this argument is not provided, the directory name is derived from the current time.
- `-w`, `--wandb`: Whether to use the Weights & Biases tracker.
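Conceptually, Diffusion-LM trains a denoiser on word embeddings corrupted by the closed-form forward process q(x_t | x_0) = N(√ᾱ_t · x_0, (1 − ᾱ_t) I). The toy sketch below is pure Python and shares no code with this repository; it just shows that corruption on a made-up 4-dimensional "embedding":

```python
import math

def noise_embedding(x0, alpha_bar, eps):
    """One draw from q(x_t | x_0): sqrt(abar)*x0 + sqrt(1 - abar)*eps, element-wise."""
    return [math.sqrt(alpha_bar) * x + math.sqrt(1 - alpha_bar) * e
            for x, e in zip(x0, eps)]

# Toy 4-dim "word embedding" and a fixed noise draw for reproducibility.
x0 = [0.5, -1.0, 0.25, 2.0]
eps = [0.1, -0.2, 0.0, 0.3]

# alpha_bar close to 1 -> lightly noised; close to 0 -> almost pure noise.
xt_early = noise_embedding(x0, alpha_bar=0.99, eps=eps)
xt_late = noise_embedding(x0, alpha_bar=0.01, eps=eps)
```

Early in the schedule the latent stays close to the embedding; late in the schedule it is dominated by the noise term.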
This repository also supports training a GPT-2 classifier for control by Semantic Content.

```sh
accelerate launch scripts/clf_train.py --output output/foo
```

Available arguments:

- `-o`, `--output`: The directory where the training results will be saved.
- `-mc`, `--model_ckpt` (default: `checkpoints/pytorch_model_1.bin`): Path to the Diffusion-LM checkpoint, relative to the directory specified by the `--output` argument.
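For intuition only: the control classifier operates on diffusion latents rather than clean text. The self-contained toy below shares no code with `scripts/clf_train.py`; it fits a 1-D logistic classifier on two noisy scalar "latent" clusters with plain gradient descent:

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Toy data: class 0 clusters around -1, class 1 around +1, with Gaussian
# noise, loosely mimicking a classifier fit on noised latents.
random.seed(0)
data = [(random.gauss(-1.0, 0.3), 0) for _ in range(50)] + \
       [(random.gauss(+1.0, 0.3), 1) for _ in range(50)]

# Plain batch gradient descent on the logistic loss.
w, b, lr = 0.0, 0.0, 0.5
for _ in range(200):
    gw = gb = 0.0
    for x, y in data:
        p = sigmoid(w * x + b)
        gw += (p - y) * x
        gb += (p - y)
    w -= lr * gw / len(data)
    b -= lr * gb / len(data)

accuracy = sum((sigmoid(w * x + b) > 0.5) == (y == 1) for x, y in data) / len(data)
```

The real classifier is a GPT-2 model over sequences of latents, but the training loop has the same shape: predict the label from a noisy representation, then follow the loss gradient.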
After training Diffusion-LM and the GPT-2 classifier, you can perform conditional sampling.

```sh
accelerate launch scripts/sample.py --output output/foo --control_label 'food : Japanese'
```

Available arguments:

- `-o`, `--output`: The directory where the training results were saved.
- `-n`, `--n_samples` (default: 16): The number of samples (used as the batch size).
- `-mc`, `--model_ckpt` (default: `checkpoints/pytorch_model_1.bin`): Path to the Diffusion-LM checkpoint, relative to the directory specified by the `--output` argument.
- `-ud`, `--use_ddpm` (default: False): Whether to use DDPM sampling (the default is DDIM).
- `-cc`, `--clf_ckpt` (default: `classifier/pytorch_model.bin`): Path to the classifier checkpoint, relative to the directory specified by the `--output` argument.
- `-cl`, `--control_label` (default: None): Label for plug-and-play control.
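For intuition about the `--use_ddpm` flag: DDIM takes deterministic reverse steps, while DDPM injects fresh noise at every step. The scalar toy below is a simplified sketch, not this repository's sampler; the DDPM branch in particular uses a made-up noise scale:

```python
import math

def reverse_step(xt, eps_pred, a_bar_t, a_bar_prev, ddpm_noise=None):
    """One reverse diffusion step on a scalar latent.

    With ddpm_noise=None this is a deterministic DDIM step (eta = 0);
    passing a Gaussian draw makes it a simplified DDPM-style stochastic step.
    """
    # Predict x_0 from the current latent and the model's noise estimate.
    x0_hat = (xt - math.sqrt(1 - a_bar_t) * eps_pred) / math.sqrt(a_bar_t)
    mean = math.sqrt(a_bar_prev) * x0_hat + math.sqrt(1 - a_bar_prev) * eps_pred
    if ddpm_noise is None:                     # DDIM: deterministic update
        return mean
    sigma = math.sqrt(1 - a_bar_prev) * 0.1    # toy noise scale, for illustration
    return mean + sigma * ddpm_noise           # DDPM: stochastic update

x_ddim = reverse_step(0.8, eps_pred=0.2, a_bar_t=0.5, a_bar_prev=0.9)
x_ddpm = reverse_step(0.8, eps_pred=0.2, a_bar_t=0.5, a_bar_prev=0.9, ddpm_noise=0.7)
```

Running the DDIM branch twice with the same inputs gives the same latent, which is why DDIM sampling is reproducible and usually faster to get good samples from with few steps.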
For unconditional sampling, you only need to train Diffusion-LM (there is no need to train the GPT-2 classifier).

```sh
accelerate launch scripts/sample.py --output output/foo
```
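The conditional command shown earlier differs from this unconditional one mainly in the plug-and-play update: at each reverse step, the latent is also nudged along the gradient of the classifier's log-probability for the desired label. The scalar toy below illustrates just that nudge, with a made-up quadratic log-probability standing in for the classifier; nothing here comes from the repository:

```python
def log_prob(x, target=1.0):
    """Toy stand-in for a classifier's log p(label | latent), peaked at `target`."""
    return -((x - target) ** 2)

def guide(x, step=0.1, h=1e-5):
    """Nudge the latent along the classifier gradient (central finite differences)."""
    grad = (log_prob(x + h) - log_prob(x - h)) / (2 * h)
    return x + step * grad

# Repeated nudges pull the latent toward the region the classifier prefers.
x = 0.0
for _ in range(30):
    x = guide(x)
```

In the real sampler the gradient is taken with autograd through the GPT-2 classifier, but the effect is the same: latents drift toward values the classifier scores highly for the requested `--control_label`.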
If you find this repository helpful, please consider giving a star ⭐!