From 12b80de1d3c967a33e2e157297167f9aaa8b1a34 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carl=20Thom=C3=A9?=
Date: Mon, 15 Jul 2024 15:07:19 +0200
Subject: [PATCH 1/4] Add basic pre-commit config with black and isort

---
 .pre-commit-config.yaml | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..61ddd38f
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,20 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort

From 9f4be0740eaba655c1f6750da95abea960061a65 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carl=20Thom=C3=A9?=
Date: Mon, 15 Jul 2024 15:07:36 +0200
Subject: [PATCH 2/4] Add GitHub Actions workflow for pre-commit checking

---
 .github/workflows/pre-commit.yaml | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 .github/workflows/pre-commit.yaml

diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 00000000..55c15892
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,16 @@
+name: pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
+        with:
+          cache: pip
+      - uses: pre-commit/action@v3.0.1

From 9c17df19e91bea5a2520cccdc3e92463cf1b69ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carl=20Thom=C3=A9?=
Date: Mon, 15 Jul 2024 15:09:00 +0200
Subject: [PATCH 3/4] Configure black and isort

---
 pyproject.toml | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7fd26b97..66805ee3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,9 @@
 [build-system]
 requires = ["setuptools"]
-build-backend = "setuptools.build_meta"
\ No newline at end of file
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 120
+
+[tool.isort]
+profile = "black"

From db6e369601978ac66231cc4a053831d458c3a277 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carl=20Thom=C3=A9?=
Date: Mon, 15 Jul 2024 15:10:20 +0200
Subject: [PATCH 4/4] `pre-commit run --all`

---
 .gitignore | 2 +-
 LICENSES/LICENSE_ADP.txt | 2 +-
 LICENSES/LICENSE_AURALOSS.txt | 2 +-
 LICENSES/LICENSE_DESCRIPT.txt | 2 +-
 LICENSES/LICENSE_META.txt | 2 +-
 LICENSES/LICENSE_NVIDIA.txt | 4 +-
 LICENSES/LICENSE_XTRANSFORMERS.txt | 2 +-
 README.md | 10 +-
 defaults.ini | 12 +-
 docs/autoencoders.md | 18 +-
 docs/conditioning.md | 8 +-
 docs/datasets.md | 8 +-
 docs/diffusion.md | 8 +-
 docs/pretransforms.md | 2 +-
 run_gradio.py | 38 +-
 scripts/ds_zero_to_pl_ckpt.py | 7 +-
 setup.py | 80 +-
 stable_audio_tools/__init__.py | 2 +-
 .../custom_metadata/custom_md_example.py | 2 +-
 .../local_training_example.json | 2 +-
 .../dataset_configs/s3_wds_example.json | 2 +-
 .../autoencoders/dac_2048_32_vae.json | 2 +-
 .../autoencoders/encodec_musicgen_rvq.json | 2 +-
 .../autoencoders/stable_audio_1_0_vae.json | 2 +-
 .../autoencoders/stable_audio_2_0_vae.json | 2 +-
 .../dance_diffusion/dance_diffusion_base.json | 2 +-
 .../dance_diffusion_base_16k.json | 2 +-
.../dance_diffusion_base_44k.json | 2 +- .../dance_diffusion_large.json | 2 +- .../txt2audio/stable_audio_1_0.json | 2 +- .../txt2audio/stable_audio_2_0.json | 2 +- stable_audio_tools/data/dataset.py | 230 +++--- stable_audio_tools/data/utils.py | 75 +- stable_audio_tools/inference/generation.py | 206 +++-- stable_audio_tools/inference/sampling.py | 153 ++-- stable_audio_tools/inference/utils.py | 8 +- stable_audio_tools/interface/gradio.py | 456 ++++++----- stable_audio_tools/models/__init__.py | 2 +- stable_audio_tools/models/adp.py | 257 +++---- stable_audio_tools/models/autoencoders.py | 408 +++++----- stable_audio_tools/models/blocks.py | 147 ++-- stable_audio_tools/models/bottleneck.py | 120 +-- .../models/codebook_patterns.py | 64 +- stable_audio_tools/models/conditioners.py | 311 ++++---- stable_audio_tools/models/diffusion.py | 471 ++++++------ stable_audio_tools/models/diffusion_prior.py | 27 +- stable_audio_tools/models/discriminators.py | 84 +- stable_audio_tools/models/dit.py | 176 +++-- stable_audio_tools/models/factory.py | 83 +- stable_audio_tools/models/lm.py | 295 +++---- stable_audio_tools/models/lm_backbone.py | 105 +-- stable_audio_tools/models/local_attention.py | 173 ++--- stable_audio_tools/models/pqmf.py | 188 ++--- stable_audio_tools/models/pretrained.py | 13 +- stable_audio_tools/models/pretransforms.py | 92 ++- stable_audio_tools/models/transformer.py | 448 ++++++----- stable_audio_tools/models/utils.py | 11 +- stable_audio_tools/models/wavelets.py | 40 +- stable_audio_tools/training/__init__.py | 5 +- stable_audio_tools/training/autoencoders.py | 346 +++++---- stable_audio_tools/training/diffusion.py | 727 +++++++++--------- stable_audio_tools/training/factory.py | 120 +-- stable_audio_tools/training/lm.py | 145 ++-- .../training/losses/__init__.py | 2 +- .../training/losses/auraloss.py | 62 +- stable_audio_tools/training/losses/losses.py | 32 +- stable_audio_tools/training/utils.py | 31 +- train.py | 77 +- unwrap_model.py | 105 +-- 69 files changed, 3506 insertions(+), 3024 deletions(-) diff --git a/.gitignore b/.gitignore index 3e6aee68..da4dd75e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,4 @@ cython_debug/ *.ckpt *.wav -wandb/* \ No newline at end of file +wandb/* diff --git a/LICENSES/LICENSE_ADP.txt b/LICENSES/LICENSE_ADP.txt index f418ac88..3fcd96f4 100644 --- a/LICENSES/LICENSE_ADP.txt +++ b/LICENSES/LICENSE_ADP.txt @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/LICENSES/LICENSE_AURALOSS.txt b/LICENSES/LICENSE_AURALOSS.txt index f49a4e16..261eeb9e 100644 --- a/LICENSES/LICENSE_AURALOSS.txt +++ b/LICENSES/LICENSE_AURALOSS.txt @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/LICENSES/LICENSE_DESCRIPT.txt b/LICENSES/LICENSE_DESCRIPT.txt index 2569ec0b..8356bd6f 100644 --- a/LICENSES/LICENSE_DESCRIPT.txt +++ b/LICENSES/LICENSE_DESCRIPT.txt @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/LICENSES/LICENSE_META.txt b/LICENSES/LICENSE_META.txt index a45a376f..b93be905 100644 --- a/LICENSES/LICENSE_META.txt +++ b/LICENSES/LICENSE_META.txt @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/LICENSES/LICENSE_NVIDIA.txt b/LICENSES/LICENSE_NVIDIA.txt index e9663595..0ac31a64 100644 --- a/LICENSES/LICENSE_NVIDIA.txt +++ b/LICENSES/LICENSE_NVIDIA.txt @@ -10,7 +10,7 @@ copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/LICENSES/LICENSE_XTRANSFORMERS.txt b/LICENSES/LICENSE_XTRANSFORMERS.txt index b7662801..cad43213 100644 --- a/LICENSES/LICENSE_XTRANSFORMERS.txt +++ b/LICENSES/LICENSE_XTRANSFORMERS.txt @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/README.md b/README.md index a8f5ce3f..afb096c8 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Development for the repo is done in Python 3.8.10 # Interface -A basic Gradio interface is provided to test out trained models. +A basic Gradio interface is provided to test out trained models. 
For example, to create an interface for the [`stable-audio-open-1.0`](https://huggingface.co/stabilityai/stable-audio-open-1.0) model, once you've accepted the terms for the model on Hugging Face, you can run: ```bash @@ -37,7 +37,7 @@ The `run_gradio.py` script accepts the following command line arguments: - Path to the model config file for a local model - `--ckpt-path` - Path to unwrapped model checkpoint file for a local model -- `--pretransform-ckpt-path` +- `--pretransform-ckpt-path` - Path to an unwrapped pretransform checkpoint, replaces the pretransform in the model, useful for testing out fine-tuned decoders - Optional - `--share` @@ -69,7 +69,7 @@ $ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /pa The `--name` parameter will set the project name for your Weights and Biases run. ## Training wrappers and model unwrapping -`stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training. +`stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training. When a model is being trained, it is wrapped in a "training wrapper", which is a `pl.LightningModule` that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states. @@ -88,7 +88,7 @@ Unwrapped model checkpoints are required for: - Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization) ## Fine-tuning -Fine-tuning a model involves continuning a training run from a pre-trained checkpoint. +Fine-tuning a model involves continuning a training run from a pre-trained checkpoint. To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to `train.py` with the `--ckpt-path` flag. @@ -154,4 +154,4 @@ The following properties are defined in the top level of the model configuration # Todo - [ ] Add troubleshooting section -- [ ] Add contribution guidelines +- [ ] Add contribution guidelines diff --git a/defaults.ini b/defaults.ini index 9f240a37..8b46db15 100644 --- a/defaults.ini +++ b/defaults.ini @@ -5,13 +5,13 @@ name = stable_audio_tools # the batch size -batch_size = 8 +batch_size = 8 # number of GPUs to use for training -num_gpus = 1 +num_gpus = 1 # number of nodes to use for training -num_nodes = 1 +num_nodes = 1 # Multi-GPU strategy for PyTorch Lightning strategy = "" @@ -29,8 +29,8 @@ seed = 42 accum_batches = 1 # Number of steps between checkpoints -checkpoint_every = 10000 - +checkpoint_every = 10000 + # trainer checkpoint file to restart training from ckpt_path = '' @@ -53,4 +53,4 @@ save_dir = '' gradient_clip_val = 0.0 # remove the weight norm from the pretransform model -remove_pretransform_weight_norm = '' \ No newline at end of file +remove_pretransform_weight_norm = '' diff --git a/docs/autoencoders.md b/docs/autoencoders.md index 7fb48216..6bdb8d34 100644 --- a/docs/autoencoders.md +++ b/docs/autoencoders.md @@ -1,5 +1,5 @@ # Autoencoders -At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*. +At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*. The *encoder* takes in an sequence (such as mono or stereo audio) and outputs a compressed representation of that sequence as a d-channel "latent sequence", usually heavily downsampled by a constant factor. 
@@ -41,7 +41,7 @@ The `training` config in the autoencoder model config file should have the follo - `learning_rate` - The learning rate to use during training - `use_ema` - - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. + - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. - Optional. Default: `false` - `warmup_steps` - The number of training steps before turning on adversarial losses @@ -62,11 +62,11 @@ There are few different types of losses that are used for autoencoder training, Hyperparameters fo these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config. ### Spectral losses -Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions. +Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions. -For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module. +For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module. -For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module. +For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module. #### Example config ```json @@ -130,7 +130,7 @@ The only property to set for autoencoder training demos is the `demo_every` prop ``` # Encoder and decoder types -Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably. +Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably. ## Oobleck Oobleck is Harmonai's in-house autoencoder architecture, implementing features from a variety of other autoencoder architectures. @@ -229,7 +229,7 @@ In our terminology, the "bottleneck" of an autoencoder is a module placed betwee Bottlenecks have a similar interface to the autoencoder with `encode()` and `decode()` functions defined. Some bottlenecks return extra information in addition to the output latent series, such as quantized token indices, or additional losses to be considered during training. 
-To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with the following +To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with the following ## VAE @@ -311,7 +311,7 @@ Residual vector quantization (RVQ) is currently the leading method for learning This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidrains/vector-quantize-pytorch/tree/master) from the `vector-quantize-pytorch` repo, which provides a lot of different quantizer options. The bottleneck config is passed through to the `ResidualVQ` [constructor](https://github.com/lucidrains/vector-quantize-pytorch/blob/0c6cea24ce68510b607f2c9997e766d9d55c085b/vector_quantize_pytorch/residual_vq.py#L26). -**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.** +**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.** ### Example config ```json @@ -327,7 +327,7 @@ This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidra ``` ## DAC RVQ -This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training. +This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training. The bottleneck config is passed directly into the `ResidualVectorQuantize` [constructor](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L97). diff --git a/docs/conditioning.md b/docs/conditioning.md index c690701f..b2ad4c2e 100644 --- a/docs/conditioning.md +++ b/docs/conditioning.md @@ -27,7 +27,7 @@ Input concatenation applies a spatial conditioning signal to the model that corr Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input. # Conditioners and conditioning configs -`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input. +`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input. Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask. 
@@ -35,7 +35,7 @@ The ConditionedDiffusionModelWrapper manages the translation between the user-pr To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models though the `diffusion_cond` model type. -The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals. +The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals. Each item in `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner. @@ -74,7 +74,7 @@ If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved wit T5 encodings are only compatible with cross attention conditioning. -#### Example config +#### Example config ```json { "id": "prompt", @@ -155,4 +155,4 @@ Number embeddings are compatible with global conditioning and cross attention co "max_val": 512 } } -``` \ No newline at end of file +``` diff --git a/docs/datasets.md b/docs/datasets.md index 931c1ab2..00d4faf6 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -11,7 +11,7 @@ To use a local directory of audio samples, set the `dataset_type` property in yo This will load all of the compatible audio files from the provided directory and all subdirectories. -### Example config +### Example config ```json { "dataset_type": "audio_dir", @@ -43,9 +43,9 @@ To load audio files and related metadata from .tar files in the WebDataset forma ``` # Custom metadata -To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary. +To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary. -For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files. +For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files. The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals. 
@@ -72,4 +72,4 @@ def get_custom_metadata(info, audio): # Pass in the relative path of the audio file as the prompt return {"prompt": info["relpath"]} -``` \ No newline at end of file +``` diff --git a/docs/diffusion.md b/docs/diffusion.md index 682031be..675282fa 100644 --- a/docs/diffusion.md +++ b/docs/diffusion.md @@ -48,7 +48,7 @@ The `training` config in the diffusion model config file should have the followi - The learning rate to use during training - Defaults to constant learning rate, can be overridden with `optimizer_configs` - `use_ema` - - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. + - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights. - Optional. Default: `true` - `log_loss_info` - If true, additional diffusion loss info will be gathered across all GPUs and displayed during training @@ -138,11 +138,11 @@ This is our custom implementation of a transformer model, based on the `x-transf This model type uses the `ContinuousTransformerWrapper` class from the https://github.com/lucidrains/x-transformers repository as the diffusion transformer backbone. `x-transformers` is a great baseline transformer implementation with lots of options for various experimental settings. -It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning. +It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning. ## Diffusion U-Net -U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution. +U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution. ### audio-diffusion-pytorch U-Net (ADP) @@ -150,4 +150,4 @@ This model type uses a modified implementation of the `UNetCFG1D` class from ver ### Dance Diffusion U-Net -This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models. \ No newline at end of file +This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models. 
diff --git a/docs/pretransforms.md b/docs/pretransforms.md index e5ea8a5b..e349d9e8 100644 --- a/docs/pretransforms.md +++ b/docs/pretransforms.md @@ -40,4 +40,4 @@ Wavelet pretransforms take the following properties: - The specific wavelet from [PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html) to use, currently limited to `"bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"` ## Future work -We hope to add more filters and transforms to this list, including PQMF and STFT transforms. \ No newline at end of file +We hope to add more filters and transforms to this list, including PQMF and STFT transforms. diff --git a/run_gradio.py b/run_gradio.py index a3032317..f6fbc629 100644 --- a/run_gradio.py +++ b/run_gradio.py @@ -1,32 +1,38 @@ +import json + +import torch + from stable_audio_tools import get_pretrained_model from stable_audio_tools.interface.gradio import create_ui -import json -import torch def main(args): torch.manual_seed(42) interface = create_ui( - model_config_path = args.model_config, - ckpt_path=args.ckpt_path, - pretrained_name=args.pretrained_name, + model_config_path=args.model_config, + ckpt_path=args.ckpt_path, + pretrained_name=args.pretrained_name, pretransform_ckpt_path=args.pretransform_ckpt_path, - model_half=args.model_half + model_half=args.model_half, ) interface.queue() interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None) + if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Run gradio interface') - parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) - parser.add_argument('--model-config', type=str, help='Path to model config', required=False) - parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) - parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False) - parser.add_argument('--share', action='store_true', help='Create a publicly shareable link', required=False) - parser.add_argument('--username', type=str, help='Gradio username', required=False) - parser.add_argument('--password', type=str, help='Gradio password', required=False) - parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False) + + parser = argparse.ArgumentParser(description="Run gradio interface") + parser.add_argument("--pretrained-name", type=str, help="Name of pretrained model", required=False) + parser.add_argument("--model-config", type=str, help="Path to model config", required=False) + parser.add_argument("--ckpt-path", type=str, help="Path to model checkpoint", required=False) + parser.add_argument( + "--pretransform-ckpt-path", type=str, help="Optional to model pretransform checkpoint", required=False + ) + parser.add_argument("--share", action="store_true", help="Create a publicly shareable link", required=False) + parser.add_argument("--username", type=str, help="Gradio username", required=False) + parser.add_argument("--password", type=str, help="Gradio password", required=False) + parser.add_argument("--model-half", action="store_true", help="Whether to use half precision", required=False) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/scripts/ds_zero_to_pl_ckpt.py b/scripts/ds_zero_to_pl_ckpt.py index 528a5160..7c41ce0d 100644 --- a/scripts/ds_zero_to_pl_ckpt.py +++ b/scripts/ds_zero_to_pl_ckpt.py 
@@ -1,5 +1,8 @@ import argparse -from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict + +from lightning.pytorch.utilities.deepspeed import ( + convert_zero_checkpoint_to_fp32_state_dict, +) if __name__ == "__main__": @@ -11,4 +14,4 @@ # lightning deepspeed has saved a directory instead of a file save_path = args.save_path output_path = args.output_path - convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) \ No newline at end of file + convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) diff --git a/setup.py b/setup.py index 7e7470d3..821b367c 100644 --- a/setup.py +++ b/setup.py @@ -1,44 +1,44 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='stable-audio-tools', - version='0.0.16', - url='https://github.com/Stability-AI/stable-audio-tools.git', - author='Stability AI', - description='Training and inference tools for generative audio models from Stability AI', - packages=find_packages(), + name="stable-audio-tools", + version="0.0.16", + url="https://github.com/Stability-AI/stable-audio-tools.git", + author="Stability AI", + description="Training and inference tools for generative audio models from Stability AI", + packages=find_packages(), install_requires=[ - 'aeiou==0.0.20', - 'alias-free-torch==0.0.6', - 'auraloss==0.4.0', - 'descript-audio-codec==1.0.0', - 'einops==0.7.0', - 'einops-exts==0.0.4', - 'ema-pytorch==0.2.3', - 'encodec==0.1.1', - 'gradio>=3.42.0', - 'huggingface_hub', - 'importlib-resources==5.12.0', - 'k-diffusion==0.1.1', - 'laion-clap==1.1.4', - 'local-attention==1.8.6', - 'pandas==2.0.2', - 'pedalboard==0.7.4', - 'prefigure==0.0.9', - 'pytorch_lightning==2.1.0', - 'PyWavelets==1.4.1', - 'safetensors', - 'sentencepiece==0.1.99', - 's3fs', - 'torch>=2.0.1', - 'torchaudio>=2.0.2', - 'torchmetrics==0.11.4', - 'tqdm', - 'transformers', - 'v-diffusion-pytorch==0.0.2', - 'vector-quantize-pytorch==1.9.14', - 'wandb==0.15.4', - 'webdataset==0.2.48', - 'x-transformers<1.27.0' + "aeiou==0.0.20", + "alias-free-torch==0.0.6", + "auraloss==0.4.0", + "descript-audio-codec==1.0.0", + "einops==0.7.0", + "einops-exts==0.0.4", + "ema-pytorch==0.2.3", + "encodec==0.1.1", + "gradio>=3.42.0", + "huggingface_hub", + "importlib-resources==5.12.0", + "k-diffusion==0.1.1", + "laion-clap==1.1.4", + "local-attention==1.8.6", + "pandas==2.0.2", + "pedalboard==0.7.4", + "prefigure==0.0.9", + "pytorch_lightning==2.1.0", + "PyWavelets==1.4.1", + "safetensors", + "sentencepiece==0.1.99", + "s3fs", + "torch>=2.0.1", + "torchaudio>=2.0.2", + "torchmetrics==0.11.4", + "tqdm", + "transformers", + "v-diffusion-pytorch==0.0.2", + "vector-quantize-pytorch==1.9.14", + "wandb==0.15.4", + "webdataset==0.2.48", + "x-transformers<1.27.0", ], -) \ No newline at end of file +) diff --git a/stable_audio_tools/__init__.py b/stable_audio_tools/__init__.py index 22446be5..09d70383 100644 --- a/stable_audio_tools/__init__.py +++ b/stable_audio_tools/__init__.py @@ -1,2 +1,2 @@ from .models.factory import create_model_from_config, create_model_from_config_path -from .models.pretrained import get_pretrained_model \ No newline at end of file +from .models.pretrained import get_pretrained_model diff --git a/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py b/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py index d7ca14ae..11da333d 100644 --- a/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py +++ 
b/stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py @@ -1,4 +1,4 @@ def get_custom_metadata(info, audio): # Use relative path as the prompt - return {"prompt": info["relpath"]} \ No newline at end of file + return {"prompt": info["relpath"]} diff --git a/stable_audio_tools/configs/dataset_configs/local_training_example.json b/stable_audio_tools/configs/dataset_configs/local_training_example.json index 94668680..75c35970 100644 --- a/stable_audio_tools/configs/dataset_configs/local_training_example.json +++ b/stable_audio_tools/configs/dataset_configs/local_training_example.json @@ -8,4 +8,4 @@ } ], "random_crop": true -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/dataset_configs/s3_wds_example.json b/stable_audio_tools/configs/dataset_configs/s3_wds_example.json index 71e3a8b9..fab6264a 100644 --- a/stable_audio_tools/configs/dataset_configs/s3_wds_example.json +++ b/stable_audio_tools/configs/dataset_configs/s3_wds_example.json @@ -7,4 +7,4 @@ } ], "random_crop": true -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json index d0f3eba7..c3c02f36 100644 --- a/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json +++ b/stable_audio_tools/configs/model_configs/autoencoders/dac_2048_32_vae.json @@ -68,4 +68,4 @@ "demo_every": 2000 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json b/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json index e76bd3d9..b79522a6 100644 --- a/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json +++ b/stable_audio_tools/configs/model_configs/autoencoders/encodec_musicgen_rvq.json @@ -85,4 +85,4 @@ "demo_every": 2000 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json index 26dcb25f..2ebabf40 100644 --- a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json +++ b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json @@ -108,4 +108,4 @@ "demo_every": 2000 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json index 3aa762f2..d83328b7 100644 --- a/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json +++ b/stable_audio_tools/configs/model_configs/autoencoders/stable_audio_2_0_vae.json @@ -119,4 +119,4 @@ "demo_every": 2000 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json index a57f9e4a..e6d4fe74 100644 --- a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base.json @@ -15,4 +15,4 @@ "demo_steps": 250 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json index 4319a567..c6715db7 
100644 --- a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json @@ -15,4 +15,4 @@ "demo_steps": 250 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json index fedb83fa..decd1765 100644 --- a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json @@ -15,4 +15,4 @@ "demo_steps": 250 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json index f9f96a45..9dd741bd 100644 --- a/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json +++ b/stable_audio_tools/configs/model_configs/dance_diffusion/dance_diffusion_large.json @@ -15,4 +15,4 @@ "demo_steps": 250 } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json index 22db891d..4c1f419c 100644 --- a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json +++ b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_1_0.json @@ -104,4 +104,4 @@ "demo_cfg_scales": [3, 6, 9] } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json index bf8d5742..f43b74a0 100644 --- a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json +++ b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_2_0.json @@ -124,4 +124,4 @@ "demo_cfg_scales": [3, 6, 9] } } -} \ No newline at end of file +} diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py index 4bc535a1..d58adef0 100644 --- a/stable_audio_tools/data/dataset.py +++ b/stable_audio_tools/data/dataset.py @@ -1,5 +1,4 @@ import importlib -import numpy as np import io import os import posixpath @@ -7,33 +6,35 @@ import re import subprocess import time +from os import path +from typing import Callable, List, Optional + +import numpy as np import torch import torchaudio import webdataset as wds - from aeiou.core import is_silence -from os import path from pedalboard.io import AudioFile from torchaudio import transforms as T -from typing import Optional, Callable, List -from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T +from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + def fast_scandir( - dir:str, # top-level directory at which to begin scanning - ext:list, # list of allowed file extensions, - #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB - ): + dir: str, # top-level directory at which to begin scanning + ext: list, # list of allowed file extensions, + # max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB +): "very fast `glob` alternative. 
from https://stackoverflow.com/a/59803793/4259243" subfolders, files = [], [] - ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed - try: # hope to avoid 'permission denied' by this try + ext = ["." + x if x[0] != "." else x for x in ext] # add starting period to extensions if needed + try: # hope to avoid 'permission denied' by this try for f in os.scandir(dir): - try: # 'hope to avoid too many levels of symbolic links' error + try: # 'hope to avoid too many levels of symbolic links' error if f.is_dir(): subfolders.append(f.path) elif f.is_file(): @@ -43,7 +44,7 @@ def fast_scandir( if file_ext in ext and not is_hidden: files.append(f.path) except: - pass + pass except: pass @@ -53,6 +54,7 @@ def fast_scandir( files.extend(f) return subfolders, files + def keyword_scandir( dir: str, # top-level directory at which to begin scanning ext: list, # list of allowed file extensions @@ -63,7 +65,7 @@ def keyword_scandir( # make keywords case insensitive keywords = [keyword.lower() for keyword in keywords] # add starting period to extensions if needed - ext = ['.'+x if x[0] != '.' else x for x in ext] + ext = ["." + x if x[0] != "." else x for x in ext] banned_words = ["paxheader", "__macosx"] try: # hope to avoid 'permission denied' by this try for f in os.scandir(dir): @@ -71,14 +73,18 @@ def keyword_scandir( if f.is_dir(): subfolders.append(f.path) elif f.is_file(): - is_hidden = f.name.split("/")[-1][0] == '.' + is_hidden = f.name.split("/")[-1][0] == "." has_ext = os.path.splitext(f.name)[1].lower() in ext name_lower = f.name.lower() - has_keyword = any( - [keyword in name_lower for keyword in keywords]) - has_banned = any( - [banned_word in name_lower for banned_word in banned_words]) - if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): + has_keyword = any([keyword in name_lower for keyword in keywords]) + has_banned = any([banned_word in name_lower for banned_word in banned_words]) + if ( + has_ext + and has_keyword + and not has_banned + and not is_hidden + and not os.path.basename(f.path).startswith("._") + ): files.append(f.path) except: pass @@ -91,16 +97,17 @@ def keyword_scandir( files.extend(f) return subfolders, files + def get_audio_filenames( paths: list, # directories in which to search keywords=None, - exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] + exts=[".wav", ".mp3", ".flac", ".ogg", ".aif", ".opus"], ): "recursively get a list of audio filenames" filenames = [] if type(paths) is str: paths = [paths] - for path in paths: # get a list of relevant filenames + for path in paths: # get a list of relevant filenames if keywords is not None: subfolders, files = keyword_scandir(path, exts, keywords) else: @@ -108,26 +115,17 @@ def get_audio_filenames( filenames.extend(files) return filenames + class LocalDatasetConfig: - def __init__( - self, - id: str, - path: str, - custom_metadata_fn: Optional[Callable[[str], str]] = None - ): + def __init__(self, id: str, path: str, custom_metadata_fn: Optional[Callable[[str], str]] = None): self.id = id self.path = path self.custom_metadata_fn = custom_metadata_fn + class SampleDataset(torch.utils.data.Dataset): def __init__( - self, - configs, - sample_size=65536, - sample_rate=48000, - keywords=None, - random_crop=True, - force_channels="stereo" + self, configs, sample_size=65536, sample_rate=48000, keywords=None, random_crop=True, force_channels="stereo" ): super().__init__() self.filenames = [] @@ -157,7 +155,7 @@ def 
__init__( if config.custom_metadata_fn is not None: self.custom_metadata_fns[config.path] = config.custom_metadata_fn - print(f'Found {len(self.filenames)} files') + print(f"Found {len(self.filenames)} files") def load_file(self, filename): ext = filename.split(".")[-1] @@ -225,9 +223,10 @@ def __getitem__(self, idx): return (audio, info) except Exception as e: - print(f'Couldn\'t load file {audio_filename}: {e}') + print(f"Couldn't load file {audio_filename}: {e}") return self[random.randrange(len(self))] + def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): """Return function over iterator that groups key, value pairs into samples. :param keys: function that splits the key into key and extension (base_plus_ext) @@ -259,50 +258,49 @@ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixe if wds.tariterators.valid_sample(current_sample): yield current_sample + wds.tariterators.group_by_keys = group_by_keys # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py -def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): + +def get_s3_contents(dataset_path, s3_url_prefix=None, filter="", recursive=True, debug=False, profile=None): """ Returns a list of full S3 paths to files in a given S3 bucket and directory path. """ # Ensure dataset_path ends with a trailing slash - if dataset_path != '' and not dataset_path.endswith('/'): - dataset_path += '/' + if dataset_path != "" and not dataset_path.endswith("/"): + dataset_path += "/" # Use posixpath to construct the S3 URL path - bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + bucket_path = posixpath.join(s3_url_prefix or "", dataset_path) # Construct the `aws s3 ls` command - cmd = ['aws', 's3', 'ls', bucket_path] + cmd = ["aws", "s3", "ls", bucket_path] if profile is not None: - cmd.extend(['--profile', profile]) + cmd.extend(["--profile", profile]) if recursive: # Add the --recursive flag if requested - cmd.append('--recursive') - + cmd.append("--recursive") + # Run the `aws s3 ls` command and capture the output run_ls = subprocess.run(cmd, capture_output=True, check=True) # Split the output into lines and strip whitespace from each line - contents = run_ls.stdout.decode('utf-8').split('\n') + contents = run_ls.stdout.decode("utf-8").split("\n") contents = [x.strip() for x in contents if x] # Remove the timestamp from lines that begin with a timestamp - contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) - if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] + contents = [re.sub(r"^\S+\s+\S+\s+\d+\s+", "", x) if re.match(r"^\S+\s+\S+\s+\d+\s+", x) else x for x in contents] # Construct a full S3 path for each file in the contents list - contents = [posixpath.join(s3_url_prefix or '', x) - for x in contents if not x.endswith('/')] + contents = [posixpath.join(s3_url_prefix or "", x) for x in contents if not x.endswith("/")] # Apply the filter, if specified if filter: contents = [x for x in contents if filter in x] # Remove redundant directory names in the S3 URL if recursive: # Get the main directory name from the S3 URL - main_dir = "/".join(bucket_path.split('/')[3:]) + main_dir = "/".join(bucket_path.split("/")[3:]) # Remove the redundant directory names from each file path - contents = [x.replace(f'{main_dir}', '').replace( - '//', '/') for x in contents] + contents = 
[x.replace(f"{main_dir}", "").replace("//", "/") for x in contents] # Print debugging information, if requested if debug: print("contents = \n", contents) @@ -311,15 +309,15 @@ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, def get_all_s3_urls( - names=[], # list of all valid [LAION AudioDataset] dataset names + names=[], # list of all valid [LAION AudioDataset] dataset names # list of subsets you want from those datasets, e.g. ['train','valid'] - subsets=[''], + subsets=[""], s3_url_prefix=None, # prefix for those dataset names - recursive=True, # recursively list all tar files in all subdirs - filter_str='tar', # only grab files with this substring + recursive=True, # recursively list all tar files in all subdirs + filter_str="tar", # only grab files with this substring # print debugging info -- note: info displayed likely to change at dev's whims debug=False, - profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} + profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} ): "get urls of shards (tar files) for multiple datasets in one s3 bucket" urls = [] @@ -339,11 +337,11 @@ def get_all_s3_urls( # Get the list of tar files in the current subset directory profile = profiles.get(name, None) tar_list = get_s3_contents( - subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) + subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile + ) for tar in tar_list: # Escape spaces and parentheses in the tar filename for use in the shell command - tar = tar.replace(" ", "\ ").replace( - "(", "\(").replace(")", "\)") + tar = tar.replace(" ", "\ ").replace("(", "\(").replace(")", "\)") # Construct the S3 path to the current tar file s3_path = posixpath.join(name, subset, tar) + " -" # Construct the AWS CLI command to download the current tar file @@ -374,6 +372,7 @@ def is_valid_sample(sample): return has_json and has_audio and not is_silent and not is_rejected + class S3DatasetConfig: def __init__( self, @@ -398,6 +397,7 @@ def load_data_urls(self): return self.urls + class LocalWebDatasetConfig: def __init__( self, @@ -417,6 +417,7 @@ def load_data_urls(self): return self.urls + def audio_decoder(key, value): # Get file extension from key ext = key.split(".")[-1] @@ -426,22 +427,24 @@ def audio_decoder(key, value): else: return None + def collation_fn(samples): - batched = list(zip(*samples)) - result = [] - for b in batched: - if isinstance(b[0], (int, float)): - b = np.array(b) - elif isinstance(b[0], torch.Tensor): - b = torch.stack(b) - elif isinstance(b[0], np.ndarray): - b = np.array(b) - else: - b = b - result.append(b) - return result - -class WebDatasetDataLoader(): + batched = list(zip(*samples)) + result = [] + for b in batched: + if isinstance(b[0], (int, float)): + b = np.array(b) + elif isinstance(b[0], torch.Tensor): + b = torch.stack(b) + elif isinstance(b[0], np.ndarray): + b = np.array(b) + else: + b = b + result.append(b) + return result + + +class WebDatasetDataLoader: def __init__( self, datasets: List[S3DatasetConfig], @@ -453,7 +456,7 @@ def __init__( random_crop=True, force_channels="stereo", augment_phase=True, - **data_loader_kwargs + **data_loader_kwargs, ): self.datasets = datasets @@ -479,24 +482,24 @@ def __init__( wds.map(self.wds_preprocess, handler=log_and_continue), wds.select(is_valid_sample), wds.to_tuple("audio", 
"json", handler=log_and_continue), - #wds.shuffle(bufsize=1000, initial=5000), + # wds.shuffle(bufsize=1000, initial=5000), wds.batched(batch_size, partial=False, collation_fn=collation_fn), - ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) + ).with_epoch(epoch_steps // num_workers if num_workers > 0 else epoch_steps) self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) def wds_preprocess(self, sample): - found_key, rewrite_key = '', '' + found_key, rewrite_key = "", "" for k, v in sample.items(): # print the all entries in dict for akey in AUDIO_KEYS: if k.endswith(akey): # to rename long/weird key with its simpler counterpart found_key, rewrite_key = k, akey break - if '' != found_key: + if "" != found_key: break - if '' == found_key: # got no audio! + if "" == found_key: # got no audio! return None # try returning None to tell WebDataset to skip this one audio, in_sr = sample[found_key] @@ -506,10 +509,8 @@ def wds_preprocess(self, sample): if self.sample_size is not None: # Pad/crop and get the relative timestamp - pad_crop = PadCrop_Normalized_T( - self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) - audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( - audio) + pad_crop = PadCrop_Normalized_T(self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(audio) sample["json"]["seconds_start"] = seconds_start sample["json"]["seconds_total"] = seconds_total sample["json"]["padding_mask"] = padding_mask @@ -524,7 +525,7 @@ def wds_preprocess(self, sample): augs = torch.nn.Sequential( Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), - PhaseFlipper() if self.augment_phase else torch.nn.Identity() + PhaseFlipper() if self.augment_phase else torch.nn.Identity(), ) audio = augs(audio) @@ -538,22 +539,25 @@ def wds_preprocess(self, sample): for dataset in self.datasets: if dataset.custom_metadata_fn is None: continue - + if dataset.path in sample["__url__"]: custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) sample["json"].update(custom_metadata) - if found_key != rewrite_key: # rename long/weird key with its simpler counterpart + if found_key != rewrite_key: # rename long/weird key with its simpler counterpart del sample[found_key] sample["audio"] = audio # Add audio to the metadata as well for conditioning sample["json"]["audio"] = audio - + return sample -def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): + +def create_dataloader_from_config( + dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4 +): dataset_type = dataset_config.get("dataset_type", None) @@ -568,7 +572,7 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl audio_dir_configs = dataset_config.get("datasets", None) - assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + assert audio_dir_configs is not None, 'Directory configuration must be specified in datasets["dataset"]' configs = [] @@ -582,15 +586,13 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl if custom_metadata_module_path is not None: spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) 
metadata_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(metadata_module) + spec.loader.exec_module(metadata_module) custom_metadata_fn = metadata_module.get_custom_metadata configs.append( LocalDatasetConfig( - id=audio_dir_config["id"], - path=audio_dir_path, - custom_metadata_fn=custom_metadata_fn + id=audio_dir_config["id"], path=audio_dir_path, custom_metadata_fn=custom_metadata_fn ) ) @@ -599,13 +601,21 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl sample_rate=sample_rate, sample_size=sample_size, random_crop=dataset_config.get("random_crop", True), - force_channels=force_channels + force_channels=force_channels, ) - return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, - num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + return torch.utils.data.DataLoader( + train_set, + batch_size, + shuffle=True, + num_workers=num_workers, + persistent_workers=True, + pin_memory=True, + drop_last=True, + collate_fn=collation_fn, + ) - elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility + elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility wds_configs = [] for wds_config in dataset_config["datasets"]: @@ -616,7 +626,7 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl if custom_metadata_module_path is not None: spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) metadata_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(metadata_module) + spec.loader.exec_module(metadata_module) custom_metadata_fn = metadata_module.get_custom_metadata @@ -630,16 +640,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl profile=wds_config.get("profile", None), ) ) - + elif "path" in wds_config: - - wds_configs.append( - LocalWebDatasetConfig( - id=wds_config["id"], - path=wds_config["path"], - custom_metadata_fn=custom_metadata_fn - ) + + wds_configs.append( + LocalWebDatasetConfig( + id=wds_config["id"], path=wds_config["path"], custom_metadata_fn=custom_metadata_fn ) + ) return WebDatasetDataLoader( wds_configs, @@ -650,5 +658,5 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl num_workers=num_workers, persistent_workers=True, force_channels=force_channels, - epoch_steps=dataset_config.get("epoch_steps", 2000) - ).data_loader \ No newline at end of file + epoch_steps=dataset_config.get("epoch_steps", 2000), + ).data_loader diff --git a/stable_audio_tools/data/utils.py b/stable_audio_tools/data/utils.py index 848012e4..da82e9b7 100644 --- a/stable_audio_tools/data/utils.py +++ b/stable_audio_tools/data/utils.py @@ -1,9 +1,10 @@ import math import random -import torch +from typing import Tuple +import torch from torch import nn -from typing import Tuple + class PadCrop(nn.Module): def __init__(self, n_samples, randomize=True): @@ -16,29 +17,30 @@ def __call__(self, signal): start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() end = start + self.n_samples output = signal.new_zeros([n, self.n_samples]) - output[:, :min(s, self.n_samples)] = signal[:, start:end] + output[:, : min(s, self.n_samples)] = signal[:, start:end] return output + class PadCrop_Normalized_T(nn.Module): - + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): - + super().__init__() - + self.n_samples = n_samples 
self.sample_rate = sample_rate self.randomize = randomize def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: - + n_channels, n_samples = source.shape - + # If the audio is shorter than the desired length, pad it upper_bound = max(0, n_samples - self.n_samples) - + # If randomize is False, always start at the beginning of the audio offset = 0 - if(self.randomize and n_samples > self.n_samples): + if self.randomize and n_samples > self.n_samples: offset = random.randint(0, upper_bound) # Calculate the start and end times of the chunk @@ -49,48 +51,45 @@ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, in chunk = source.new_zeros([n_channels, self.n_samples]) # Copy the audio into the chunk - chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] - + chunk[:, : min(n_samples, self.n_samples)] = source[:, offset : offset + self.n_samples] + # Calculate the start and end times of the chunk in seconds seconds_start = math.floor(offset / self.sample_rate) seconds_total = math.ceil(n_samples / self.sample_rate) # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't padding_mask = torch.zeros([self.n_samples]) - padding_mask[:min(n_samples, self.n_samples)] = 1 - - - return ( - chunk, - t_start, - t_end, - seconds_start, - seconds_total, - padding_mask - ) + padding_mask[: min(n_samples, self.n_samples)] = 1 + + return (chunk, t_start, t_end, seconds_start, seconds_total, padding_mask) + class PhaseFlipper(nn.Module): "Randomly invert the phase of a signal" + def __init__(self, p=0.5): super().__init__() self.p = p + def __call__(self, signal): return -signal if (random.random() < self.p) else signal - + + class Mono(nn.Module): - def __call__(self, signal): - return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + def __call__(self, signal): + return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + class Stereo(nn.Module): - def __call__(self, signal): - signal_shape = signal.shape - # Check if it's mono - if len(signal_shape) == 1: # s -> 2, s - signal = signal.unsqueeze(0).repeat(2, 1) - elif len(signal_shape) == 2: - if signal_shape[0] == 1: #1, s -> 2, s - signal = signal.repeat(2, 1) - elif signal_shape[0] > 2: #?, s -> 2,s - signal = signal[:2, :] - - return signal + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> 2, s + signal = signal.unsqueeze(0).repeat(2, 1) + elif len(signal_shape) == 2: + if signal_shape[0] == 1: # 1, s -> 2, s + signal = signal.repeat(2, 1) + elif signal_shape[0] > 2: # ?, s -> 2,s + signal = signal[:2, :] + + return signal diff --git a/stable_audio_tools/inference/generation.py b/stable_audio_tools/inference/generation.py index 843ab4b7..95bd5366 100644 --- a/stable_audio_tools/inference/generation.py +++ b/stable_audio_tools/inference/generation.py @@ -1,33 +1,35 @@ -import numpy as np -import torch +import math import typing as tp -import math + +import numpy as np +import torch from torchaudio import transforms as T -from .utils import prepare_audio -from .sampling import sample, sample_k, sample_rf from ..data.utils import PadCrop +from .sampling import sample, sample_k, sample_rf +from .utils import prepare_audio + def generate_diffusion_uncond( - model, - steps: int = 250, - batch_size: int = 1, - sample_size: int = 2097152, - seed: int = -1, - device: str = "cuda", - init_audio: 
tp.Optional[tp.Tuple[int, torch.Tensor]] = None, - init_noise_level: float = 1.0, - return_latents = False, - **sampler_kwargs - ) -> torch.Tensor: - - # The length of the output in audio samples + model, + steps: int = 250, + batch_size: int = 1, + sample_size: int = 2097152, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + return_latents=False, + **sampler_kwargs +) -> torch.Tensor: + + # The length of the output in audio samples audio_sample_size = sample_size # If this is latent diffusion, change sample_size instead to the downsampled latent size if model.pretransform is not None: sample_size = sample_size // model.pretransform.downsampling_ratio - + # Seed # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) @@ -47,7 +49,14 @@ def generate_diffusion_uncond( io_channels = model.pretransform.io_channels # Prepare the initial audio for use by the model - init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + init_audio = prepare_audio( + init_audio, + in_sr=in_sr, + target_sr=model.sample_rate, + target_length=audio_sample_size, + target_channels=io_channels, + device=device, + ) # For latent models, encode the initial audio into latents if model.pretransform is not None: @@ -55,16 +64,16 @@ def generate_diffusion_uncond( init_audio = init_audio.repeat(batch_size, 1, 1) else: - # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. init_audio = None init_noise_level = None # Inpainting mask - + if init_audio is not None: # variations sampler_kwargs["sigma_max"] = init_noise_level - mask = None + mask = None else: mask = None @@ -72,13 +81,13 @@ def generate_diffusion_uncond( diff_objective = model.diffusion_objective - if diff_objective == "v": + if diff_objective == "v": # k-diffusion denoising process go! sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) elif diff_objective == "rectified_flow": sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) - # Denoising process done. + # Denoising process done. 
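A call sketch for generate_diffusion_uncond as reformatted here, assuming a pretrained unconditional diffusion model (the model name is a placeholder; sample_size is read from the model config, as load_model does later in this patch):

from stable_audio_tools.inference.generation import generate_diffusion_uncond
from stable_audio_tools.models.pretrained import get_pretrained_model

model, model_config = get_pretrained_model("user/model-name")  # placeholder name
model = model.to("cuda")

audio = generate_diffusion_uncond(
    model,
    steps=100,
    batch_size=1,
    sample_size=model_config["sample_size"],
    seed=-1,  # -1 -> random seed, as handled above
    device="cuda",
)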
# If this is latent diffusion, decode latents back into audio if model.pretransform is not None and not return_latents: sampled = model.pretransform.decode(sampled) @@ -88,31 +97,31 @@ def generate_diffusion_uncond( def generate_diffusion_cond( - model, - steps: int = 250, - cfg_scale=6, - conditioning: dict = None, - conditioning_tensors: tp.Optional[dict] = None, - negative_conditioning: dict = None, - negative_conditioning_tensors: tp.Optional[dict] = None, - batch_size: int = 1, - sample_size: int = 2097152, - sample_rate: int = 48000, - seed: int = -1, - device: str = "cuda", - init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, - init_noise_level: float = 1.0, - mask_args: dict = None, - return_latents = False, - **sampler_kwargs - ) -> torch.Tensor: + model, + steps: int = 250, + cfg_scale=6, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + negative_conditioning: dict = None, + negative_conditioning_tensors: tp.Optional[dict] = None, + batch_size: int = 1, + sample_size: int = 2097152, + sample_rate: int = 48000, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + mask_args: dict = None, + return_latents=False, + **sampler_kwargs +) -> torch.Tensor: """ Generate audio from a prompt using a diffusion model. - + Args: model: The diffusion model to use for generation. steps: The number of diffusion steps to use. - cfg_scale: Classifier-free guidance scale + cfg_scale: Classifier-free guidance scale conditioning: A dictionary of conditioning parameters to use for generation. conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. batch_size: The batch size to use for generation. @@ -123,16 +132,16 @@ def generate_diffusion_cond( init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. init_noise_level: The noise level to use when generating from an initial audio sample. return_latents: Whether to return the latents used for generation instead of the decoded audio. - **sampler_kwargs: Additional keyword arguments to pass to the sampler. + **sampler_kwargs: Additional keyword arguments to pass to the sampler. """ - # The length of the output in audio samples + # The length of the output in audio samples audio_sample_size = sample_size # If this is latent diffusion, change sample_size instead to the downsampled latent size if model.pretransform is not None: sample_size = sample_size // model.pretransform.downsampling_ratio - + # Seed # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 
seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) @@ -147,16 +156,18 @@ def generate_diffusion_cond( torch.backends.cudnn.benchmark = False # Conditioning - assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" + assert ( + conditioning is not None or conditioning_tensors is not None + ), "Must provide either conditioning or conditioning_tensors" if conditioning_tensors is None: conditioning_tensors = model.conditioner(conditioning, device) conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) if negative_conditioning is not None or negative_conditioning_tensors is not None: - + if negative_conditioning_tensors is None: negative_conditioning_tensors = model.conditioner(negative_conditioning, device) - + negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) else: negative_conditioning_tensors = {} @@ -172,7 +183,14 @@ def generate_diffusion_cond( io_channels = model.pretransform.io_channels # Prepare the initial audio for use by the model - init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + init_audio = prepare_audio( + init_audio, + in_sr=in_sr, + target_sr=model.sample_rate, + target_length=audio_sample_size, + target_channels=io_channels, + device=device, + ) # For latent models, encode the initial audio into latents if model.pretransform is not None: @@ -180,7 +198,7 @@ def generate_diffusion_cond( init_audio = init_audio.repeat(batch_size, 1, 1) else: - # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 
init_audio = None init_noise_level = None mask_args = None @@ -189,18 +207,18 @@ def generate_diffusion_cond( if init_audio is not None and mask_args is not None: # Cut and paste init_audio according to cropfrom, pastefrom, pasteto # This is helpful for forward and reverse outpainting - cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) - pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) - pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) + cropfrom = math.floor(mask_args["cropfrom"] / 100.0 * sample_size) + pastefrom = math.floor(mask_args["pastefrom"] / 100.0 * sample_size) + pasteto = math.ceil(mask_args["pasteto"] / 100.0 * sample_size) assert pastefrom < pasteto, "Paste From should be less than Paste To" croplen = pasteto - pastefrom if cropfrom + croplen > sample_size: - croplen = sample_size - cropfrom + croplen = sample_size - cropfrom cropto = cropfrom + croplen pasteto = pastefrom + croplen cutpaste = init_audio.new_zeros(init_audio.shape) - cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] - #print(cropfrom, cropto, pastefrom, pasteto) + cutpaste[:, :, pastefrom:pasteto] = init_audio[:, :, cropfrom:cropto] + # print(cropfrom, cropto, pastefrom, pasteto) init_audio = cutpaste # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args mask = build_mask(sample_size, mask_args) @@ -208,7 +226,7 @@ def generate_diffusion_cond( elif init_audio is not None and mask_args is None: # variations sampler_kwargs["sigma_max"] = init_noise_level - mask = None + mask = None else: mask = None @@ -220,9 +238,22 @@ def generate_diffusion_cond( diff_objective = model.diffusion_objective - if diff_objective == "v": + if diff_objective == "v": # k-diffusion denoising process go! - sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + sampled = sample_k( + model.model, + noise, + init_audio, + mask, + steps, + **sampler_kwargs, + **conditioning_inputs, + **negative_conditioning_tensors, + cfg_scale=cfg_scale, + batch_cfg=True, + rescale_cfg=True, + device=device + ) elif diff_objective == "rectified_flow": if "sigma_min" in sampler_kwargs: @@ -231,44 +262,57 @@ def generate_diffusion_cond( if "sampler_type" in sampler_kwargs: del sampler_kwargs["sampler_type"] - sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) - - # v-diffusion: - #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale) + sampled = sample_rf( + model.model, + noise, + init_data=init_audio, + steps=steps, + **sampler_kwargs, + **conditioning_inputs, + **negative_conditioning_tensors, + cfg_scale=cfg_scale, + batch_cfg=True, + rescale_cfg=True, + device=device + ) + + # v-diffusion: + # sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale) del noise del conditioning_tensors del conditioning_inputs torch.cuda.empty_cache() - # Denoising process done. + # Denoising process done. 
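A call sketch for generate_diffusion_cond, mirroring the defaults the Gradio interface passes later in this patch (illustrative: model/model_config are assumed loaded as in the earlier sketch, the prompt text is a placeholder, the timing keys only apply if the model's conditioning config defines them, and sampler_type/sigma_min/sigma_max are forwarded through **sampler_kwargs):

conditioning = [{"prompt": "warm analog synth arpeggio", "seconds_start": 0, "seconds_total": 30}]

audio = generate_diffusion_cond(
    model,
    conditioning=conditioning,
    steps=100,
    cfg_scale=7.0,
    batch_size=1,
    sample_size=model_config["sample_size"],
    sample_rate=model_config["sample_rate"],
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    device="cuda",
)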
# If this is latent diffusion, decode latents back into audio if model.pretransform is not None and not return_latents: - #cast sampled latents to pretransform dtype + # cast sampled latents to pretransform dtype sampled = sampled.to(next(model.pretransform.parameters()).dtype) sampled = model.pretransform.decode(sampled) # Return audio return sampled + # builds a softmask given the parameters -# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, +# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, # and anything between is a mixture of old/new # ideally 0.5 is half/half mixture but i haven't figured this out yet def build_mask(sample_size, mask_args): - maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) - maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) - softnessL = round(mask_args["softnessL"]/100.0 * sample_size) - softnessR = round(mask_args["softnessR"]/100.0 * sample_size) + maskstart = math.floor(mask_args["maskstart"] / 100.0 * sample_size) + maskend = math.ceil(mask_args["maskend"] / 100.0 * sample_size) + softnessL = round(mask_args["softnessL"] / 100.0 * sample_size) + softnessR = round(mask_args["softnessR"] / 100.0 * sample_size) marination = mask_args["marination"] # use hann windows for softening the transition (i don't know if this is correct) - hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] - hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] - # build the mask. + hannL = torch.hann_window(softnessL * 2, periodic=False)[:softnessL] + hannR = torch.hann_window(softnessR * 2, periodic=False)[softnessR:] + # build the mask. mask = torch.zeros((sample_size)) mask[maskstart:maskend] = 1 - mask[maskstart:maskstart+softnessL] = hannL - mask[maskend-softnessR:maskend] = hannR + mask[maskstart : maskstart + softnessL] = hannL + mask[maskend - softnessR : maskend] = hannR # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds - if marination > 0: - mask = mask * (1-marination) - #print(mask) + if marination > 0: + mask = mask * (1 - marination) + # print(mask) return mask diff --git a/stable_audio_tools/inference/sampling.py b/stable_audio_tools/inference/sampling.py index 2229e508..7aefa1ab 100644 --- a/stable_audio_tools/inference/sampling.py +++ b/stable_audio_tools/inference/sampling.py @@ -1,8 +1,9 @@ -import torch import math -from tqdm import trange, tqdm import k_diffusion as K +import torch +from tqdm import tqdm, trange + # Define the noise schedule and sampling loop def get_alphas_sigmas(t): @@ -10,11 +11,13 @@ def get_alphas_sigmas(t): noise (sigma), given a timestep.""" return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + def alpha_sigma_to_t(alpha, sigma): """Returns a timestep, given the scaling factors for the clean image and for the noise.""" return torch.atan2(sigma, alpha) / math.pi * 2 + def t_to_alpha_sigma(t): """Returns the scaling factors for the clean image and for the noise, given a timestep.""" @@ -31,19 +34,18 @@ def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): # Create the noise schedule t = torch.linspace(sigma_max, 0, steps + 1) - #alphas, sigmas = 1-t, t + # alphas, sigmas = 1-t, t for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): - # Broadcast the current timestep to the correct shape - t_curr_tensor = t_curr * torch.ones( - (x.shape[0],), dtype=x.dtype, 
device=x.device - ) - dt = t_prev - t_curr # we solve backwards in our formulation - x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) + # Broadcast the current timestep to the correct shape + t_curr_tensor = t_curr * torch.ones((x.shape[0],), dtype=x.dtype, device=x.device) + dt = t_prev - t_curr # we solve backwards in our formulation + x = x + dt * model(x, t_curr_tensor, **extra_args) # .denoise(x, denoiser, t_curr_tensor, cond, uc) # If we are on the last timestep, output the denoised image return x + @torch.no_grad() def sample(model, x, steps, eta, **extra_args): """Draws samples from a model given starting noise. v-diffusion""" @@ -70,9 +72,10 @@ def sample(model, x, steps, eta, **extra_args): if i < steps - 1: # If eta > 0, adjust the scaling factor for the predicted noise # downward according to the amount of additional noise to add - ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ - (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() - adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + ddim_sigma = ( + eta * (sigmas[i + 1] ** 2 / sigmas[i] ** 2).sqrt() * (1 - alphas[i] ** 2 / alphas[i + 1] ** 2).sqrt() + ) + adjusted_sigma = (sigmas[i + 1] ** 2 - ddim_sigma**2).sqrt() # Recombine the predicted noise and predicted denoised image in the # correct proportions for the next step @@ -85,14 +88,16 @@ def sample(model, x, steps, eta, **extra_args): # If we are on the last timestep, output the denoised image return pred + # Soft mask inpainting is just shrinking hard (binary) mask inpainting # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step def get_bmask(i, steps, mask): - strength = (i+1)/(steps) + strength = (i + 1) / (steps) # convert to binary mask - bmask = torch.where(mask<=strength,1,0) + bmask = torch.where(mask <= strength, 1, 0) return bmask + def make_cond_model_fn(model, cond_fn): def cond_model_fn(x, sigma, **kwargs): with torch.enable_grad(): @@ -101,27 +106,30 @@ def cond_model_fn(x, sigma, **kwargs): cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) return cond_denoised + return cond_model_fn + # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion # init_data is init_audio as latents (if this is latent diffusion) # For sampling, set both init_data and mask to None -# For variations, set init_data -# For inpainting, set both init_data & mask +# For variations, set init_data +# For inpainting, set both init_data & mask def sample_k( - model_fn, - noise, - init_data=None, - mask=None, - steps=100, - sampler_type="dpmpp-2m-sde", - sigma_min=0.5, - sigma_max=50, - rho=1.0, device="cuda", - callback=None, - cond_fn=None, - **extra_args - ): + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, + device="cuda", + callback=None, + cond_fn=None, + **extra_args +): denoiser = K.external.VDenoiser(model_fn) @@ -130,7 +138,7 @@ def sample_k( # Make the list of sigmas. 
Sigma values are scalars related to the amount of noise each denoising step has sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) - # Scale the initial noise by sigma + # Scale the initial noise by sigma noise = noise * sigmas[0] wrapped_callback = callback @@ -138,14 +146,15 @@ def sample_k( if mask is None and init_data is not None: # VARIATION (no inpainting) # set the initial latent to the init_data, and noise it with initial sigma - x = init_data + noise + x = init_data + noise elif mask is not None and init_data is not None: # INPAINTING bmask = get_bmask(0, steps, mask) # initial noising input_noised = init_data + noise # set the initial latent to a mix of init_data and noise, based on step 0's binary mask - x = input_noised * bmask + noise * (1-bmask) + x = input_noised * bmask + noise * (1 - bmask) + # define the inpainting callback function (Note: side effects, it mutates x) # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) @@ -154,17 +163,18 @@ def inpainting_callback(args): i = args["i"] x = args["x"] sigma = args["sigma"] - #denoised = args["denoised"] + # denoised = args["denoised"] # noise the init_data input with this step's appropriate amount of noise input_noised = init_data + torch.randn_like(init_data) * sigma # shrinking hard mask bmask = get_bmask(i, steps, mask) # mix input_noise with x, using binary mask - new_x = input_noised * bmask + x * (1-bmask) + new_x = input_noised * bmask + x * (1 - bmask) # mutate x - x[:,:,:] = new_x[:,:,:] - # wrap together the inpainting callback and the user-submitted callback. - if callback is None: + x[:, :, :] = new_x[:, :, :] + + # wrap together the inpainting callback and the user-submitted callback. 
+ if callback is None: wrapped_callback = inpainting_callback else: wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) @@ -173,41 +183,64 @@ def inpainting_callback(args): # set the initial latent to noise x = noise - with torch.cuda.amp.autocast(): if sampler_type == "k-heun": - return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_heun( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) elif sampler_type == "k-lms": - return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_lms( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) elif sampler_type == "k-dpmpp-2s-ancestral": - return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpmpp_2s_ancestral( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) elif sampler_type == "k-dpm-2": - return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpm_2( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) elif sampler_type == "k-dpm-fast": - return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpm_fast( + denoiser, + x, + sigma_min, + sigma_max, + steps, + disable=False, + callback=wrapped_callback, + extra_args=extra_args, + ) elif sampler_type == "k-dpm-adaptive": - return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpm_adaptive( + denoiser, + x, + sigma_min, + sigma_max, + rtol=0.01, + atol=0.01, + disable=False, + callback=wrapped_callback, + extra_args=extra_args, + ) elif sampler_type == "dpmpp-2m-sde": - return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpmpp_2m_sde( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) elif sampler_type == "dpmpp-3m-sde": - return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + return K.sampling.sample_dpmpp_3m_sde( + denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args + ) + # Uses discrete Euler sampling for rectified flow models # init_data is init_audio as latents (if this is latent diffusion) # For sampling, set both init_data and mask to None -# For variations, set init_data -# For inpainting, set both init_data & mask +# For variations, set init_data +# For inpainting, set both init_data & mask def sample_rf( - model_fn, - noise, - init_data=None, - steps=100, - sigma_max=1, - device="cuda", - callback=None, - cond_fn=None, - **extra_args - ): + model_fn, noise, init_data=None, steps=100, sigma_max=1, device="cuda", callback=None, cond_fn=None, **extra_args +): if sigma_max > 1: sigma_max = 1 @@ -228,5 +261,5 @@ def sample_rf( with torch.cuda.amp.autocast(): # TODO: Add callback support - #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, 
**extra_args) - return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) \ No newline at end of file + # return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) + return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) diff --git a/stable_audio_tools/inference/utils.py b/stable_audio_tools/inference/utils.py index 6a6c0a57..2b12df24 100644 --- a/stable_audio_tools/inference/utils.py +++ b/stable_audio_tools/inference/utils.py @@ -1,6 +1,7 @@ +from torchaudio import transforms as T + from ..data.utils import PadCrop -from torchaudio import transforms as T def set_audio_channels(audio, target_channels): if target_channels == 1: @@ -14,8 +15,9 @@ def set_audio_channels(audio, target_channels): audio = audio[:, :2, :] return audio + def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): - + audio = audio.to(device) if in_sr != target_sr: @@ -32,4 +34,4 @@ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, devic audio = set_audio_channels(audio, target_channels) - return audio \ No newline at end of file + return audio diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py index f38468bc..158014ce 100644 --- a/stable_audio_tools/interface/gradio.py +++ b/stable_audio_tools/interface/gradio.py @@ -1,12 +1,11 @@ import gc +import json import platform -import numpy as np import gradio as gr -import json +import numpy as np import torch import torchaudio - from aeiou.viz import audio_spectrogram_image from einops import rearrange from safetensors.torch import load_file @@ -14,19 +13,27 @@ from torchaudio import transforms as T from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..inference.utils import prepare_audio from ..models.factory import create_model_from_config from ..models.pretrained import get_pretrained_model from ..models.utils import load_ckpt_state_dict -from ..inference.utils import prepare_audio from ..training.utils import copy_state_dict model = None sample_rate = 32000 sample_size = 1920000 -def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + +def load_model( + model_config=None, + model_ckpt_path=None, + pretrained_name=None, + pretransform_ckpt_path=None, + device="cuda", + model_half=False, +): global model, sample_rate, sample_size - + if pretrained_name is not None: print(f"Loading pretrained model {pretrained_name}") model, model_config = get_pretrained_model(pretrained_name) @@ -38,7 +45,7 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr print(f"Loading model checkpoint from {model_ckpt_path}") # Load checkpoint copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) - #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + # model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) sample_rate = model_config["sample_rate"] sample_size = model_config["sample_size"] @@ -52,37 +59,38 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr if model_half: model.to(torch.float16) - + print(f"Done loading model") return model, model_config + def generate_cond( - prompt, - negative_prompt=None, - seconds_start=0, - seconds_total=30, - cfg_scale=6.0, - steps=250, - preview_every=None, - seed=-1, - sampler_type="dpmpp-3m-sde", - sigma_min=0.03, - sigma_max=1000, - cfg_rescale=0.0, - use_init=False, - 
init_audio=None, - init_noise_level=1.0, - mask_cropfrom=None, - mask_pastefrom=None, - mask_pasteto=None, - mask_maskstart=None, - mask_maskend=None, - mask_softnessL=None, - mask_softnessR=None, - mask_marination=None, - batch_size=1 - ): + prompt, + negative_prompt=None, + seconds_start=0, + seconds_total=30, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + cfg_rescale=0.0, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1, +): if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -99,29 +107,31 @@ def generate_cond( conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size if negative_prompt: - negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size + negative_conditioning = [ + {"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total} + ] * batch_size else: negative_conditioning = None - - #Get the device from the model + + # Get the device from the model device = next(model.parameters()).device seed = int(seed) if not use_init: init_audio = None - + input_sample_size = sample_size if init_audio is not None: in_sr, init_audio = init_audio # Turn into torch tensor, converting from int16 to float32 init_audio = torch.from_numpy(init_audio).float().div(32767) - + if init_audio.dim() == 1: - init_audio = init_audio.unsqueeze(0) # [1, n] + init_audio = init_audio.unsqueeze(0) # [1, n] elif init_audio.dim() == 2: - init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] if in_sr != sample_rate: resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) @@ -131,7 +141,10 @@ def generate_cond( if audio_length > sample_size: - input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + input_sample_size = ( + audio_length + + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + ) init_audio = (sample_rate, init_audio) @@ -151,7 +164,7 @@ def progress_callback(callback_info): # If inpainting, send mask args # This will definitely change in the future - if mask_cropfrom is not None: + if mask_cropfrom is not None: mask_args = { "cropfrom": mask_cropfrom, "pastefrom": mask_pastefrom, @@ -163,11 +176,11 @@ def progress_callback(callback_info): "marination": mask_marination, } else: - mask_args = None + mask_args = None # Do the audio generation audio = generate_diffusion_cond( - model, + model, conditioning=conditioning, negative_conditioning=negative_conditioning, steps=steps, @@ -182,9 +195,9 @@ def progress_callback(callback_info): sigma_max=sigma_max, init_audio=init_audio, init_noise_level=init_noise_level, - mask_args = mask_args, - callback = progress_callback if preview_every is not None else None, - scale_phi = cfg_rescale + mask_args=mask_args, + callback=progress_callback if preview_every is not None else None, + scale_phi=cfg_rescale, ) # Convert to WAV file @@ -197,18 +210,19 @@ def progress_callback(callback_info): return ("output.wav", [audio_spectrogram, *preview_images]) + def generate_uncond( - steps=250, - seed=-1, - sampler_type="dpmpp-3m-sde", - 
sigma_min=0.03, - sigma_max=1000, - use_init=False, - init_audio=None, - init_noise_level=1.0, - batch_size=1, - preview_every=None - ): + steps=250, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + use_init=False, + init_audio=None, + init_noise_level=1.0, + batch_size=1, + preview_every=None, +): global preview_images @@ -218,25 +232,25 @@ def generate_uncond( torch.cuda.empty_cache() gc.collect() - #Get the device from the model + # Get the device from the model device = next(model.parameters()).device seed = int(seed) if not use_init: init_audio = None - + input_sample_size = sample_size if init_audio is not None: in_sr, init_audio = init_audio # Turn into torch tensor, converting from int16 to float32 init_audio = torch.from_numpy(init_audio).float().div(32767) - + if init_audio.dim() == 1: - init_audio = init_audio.unsqueeze(0) # [1, n] + init_audio = init_audio.unsqueeze(0) # [1, n] elif init_audio.dim() == 2: - init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] if in_sr != sample_rate: resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) @@ -246,7 +260,10 @@ def generate_uncond( if audio_length > sample_size: - input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + input_sample_size = ( + audio_length + + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + ) init_audio = (sample_rate, init_audio) @@ -270,7 +287,7 @@ def progress_callback(callback_info): preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) audio = generate_diffusion_uncond( - model, + model, steps=steps, batch_size=batch_size, sample_size=input_sample_size, @@ -281,7 +298,7 @@ def progress_callback(callback_info): sigma_max=sigma_max, init_audio=init_audio, init_noise_level=init_noise_level, - callback = progress_callback if preview_every is not None else None + callback=progress_callback if preview_every is not None else None, ) audio = rearrange(audio, "b d n -> d (b n)") @@ -294,28 +311,29 @@ def progress_callback(callback_info): return ("output.wav", [audio_spectrogram, *preview_images]) + def generate_lm( - temperature=1.0, - top_p=0.95, - top_k=0, - batch_size=1, - ): - + temperature=1.0, + top_p=0.95, + top_k=0, + batch_size=1, +): + if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - #Get the device from the model + # Get the device from the model device = next(model.parameters()).device audio = model.generate_audio( batch_size=batch_size, - max_gen_len = sample_size//model.pretransform.downsampling_ratio, + max_gen_len=sample_size // model.pretransform.downsampling_ratio, conditioning=None, temp=temperature, top_p=top_p, top_k=top_k, - use_cache=True + use_cache=True, ) audio = rearrange(audio, "b d n -> d (b n)") @@ -329,61 +347,75 @@ def generate_lm( return ("output.wav", [audio_spectrogram]) -def create_uncond_sampling_ui(model_config): - generate_button = gr.Button("Generate", variant='primary', scale=1) - +def create_uncond_sampling_ui(model_config): + generate_button = gr.Button("Generate", variant="primary", scale=1) + with gr.Row(equal_height=False): - with gr.Column(): + with gr.Column(): with gr.Row(): # Steps slider steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") with gr.Accordion("Sampler params", open=False): - + # Seed seed_textbox = gr.Textbox(label="Seed (set to -1 for random 
seed)", value="-1") - # Sampler params + # Sampler params with gr.Row(): - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sampler_type_dropdown = gr.Dropdown( + [ + "dpmpp-2m-sde", + "dpmpp-3m-sde", + "k-heun", + "k-lms", + "k-dpmpp-2s-ancestral", + "k-dpm-2", + "k-dpm-fast", + ], + label="Sampler type", + value="dpmpp-3m-sde", + ) sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") with gr.Accordion("Init audio", open=False): init_audio_checkbox = gr.Checkbox(label="Use init audio") init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + init_noise_level_slider = gr.Slider( + minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level" + ) with gr.Column(): audio_output = gr.Audio(label="Output audio", interactive=False) audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) send_to_init_button = gr.Button("Send to init audio", scale=1) send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) - - generate_button.click(fn=generate_uncond, + + generate_button.click( + fn=generate_uncond, inputs=[ - steps_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, + steps_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, sigma_max_slider, init_audio_checkbox, init_audio_input, init_noise_level_slider, - ], - outputs=[ - audio_output, - audio_spectrogram_output - ], - api_name="generate") + ], + outputs=[audio_output, audio_spectrogram_output], + api_name="generate", + ) + def create_sampling_ui(model_config, inpainting=False): with gr.Row(): with gr.Column(scale=6): prompt = gr.Textbox(show_label=False, placeholder="Prompt") negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt") - generate_button = gr.Button("Generate", variant='primary', scale=1) - + generate_button = gr.Button("Generate", variant="primary", scale=1) + model_conditioning_config = model_config["model"].get("conditioning", None) has_seconds_start = False @@ -398,11 +430,20 @@ def create_sampling_ui(model_config, inpainting=False): with gr.Row(equal_height=False): with gr.Column(): - with gr.Row(visible = has_seconds_start or has_seconds_total): + with gr.Row(visible=has_seconds_start or has_seconds_total): # Timing controls - seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) - seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) - + seconds_start_slider = gr.Slider( + minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start + ) + seconds_total_slider = gr.Slider( + minimum=0, + maximum=512, + step=1, + value=sample_size // sample_rate, + label="Seconds total", + visible=has_seconds_total, + ) + with gr.Row(): # Steps slider steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") @@ -410,50 +451,77 @@ def create_sampling_ui(model_config, inpainting=False): # Preview Every slider preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") - # CFG scale 
+ # CFG scale cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale") with gr.Accordion("Sampler params", open=False): - + # Seed seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") # Sampler params with gr.Row(): - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sampler_type_dropdown = gr.Dropdown( + [ + "dpmpp-2m-sde", + "dpmpp-3m-sde", + "k-heun", + "k-lms", + "k-dpmpp-2s-ancestral", + "k-dpm-2", + "k-dpm-fast", + ], + label="Sampler type", + value="dpmpp-3m-sde", + ) sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") - cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount") + cfg_rescale_slider = gr.Slider( + minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount" + ) - if inpainting: + if inpainting: # Inpainting Tab with gr.Accordion("Inpainting", open=False): - sigma_max_slider.maximum=1000 - + sigma_max_slider.maximum = 1000 + init_audio_checkbox = gr.Checkbox(label="Do inpainting") init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this + init_noise_level_slider = gr.Slider( + minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False + ) # hide this mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %") - mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %") + mask_pastefrom_slider = gr.Slider( + minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %" + ) mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %") - mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %") + mask_maskstart_slider = gr.Slider( + minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %" + ) mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %") - mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %") - mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %") - mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this - - inputs = [prompt, + mask_softnessL_slider = gr.Slider( + minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %" + ) + mask_softnessR_slider = gr.Slider( + minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %" + ) + mask_marination_slider = gr.Slider( + minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False + ) # still working on the usefulness of this + + inputs = [ + prompt, negative_prompt, - seconds_start_slider, - seconds_total_slider, - cfg_scale_slider, - steps_slider, - preview_every_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + 
steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, sigma_max_slider, cfg_rescale_slider, init_audio_checkbox, @@ -466,30 +534,33 @@ def create_sampling_ui(model_config, inpainting=False): mask_maskend_slider, mask_softnessL_slider, mask_softnessR_slider, - mask_marination_slider + mask_marination_slider, ] else: # Default generation tab with gr.Accordion("Init audio", open=False): init_audio_checkbox = gr.Checkbox(label="Use init audio") init_audio_input = gr.Audio(label="Init audio") - init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + init_noise_level_slider = gr.Slider( + minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level" + ) - inputs = [prompt, + inputs = [ + prompt, negative_prompt, - seconds_start_slider, - seconds_total_slider, - cfg_scale_slider, - steps_slider, - preview_every_slider, - seed_textbox, - sampler_type_dropdown, - sigma_min_slider, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, sigma_max_slider, cfg_rescale_slider, init_audio_checkbox, init_audio_input, - init_noise_level_slider + init_noise_level_slider, ] with gr.Column(): @@ -497,36 +568,34 @@ def create_sampling_ui(model_config, inpainting=False): audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) send_to_init_button = gr.Button("Send to init audio", scale=1) send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) - - generate_button.click(fn=generate_cond, - inputs=inputs, - outputs=[ - audio_output, - audio_spectrogram_output - ], - api_name="generate") + + generate_button.click( + fn=generate_cond, inputs=inputs, outputs=[audio_output, audio_spectrogram_output], api_name="generate" + ) def create_txt2audio_ui(model_config): with gr.Blocks() as ui: with gr.Tab("Generation"): - create_sampling_ui(model_config) + create_sampling_ui(model_config) with gr.Tab("Inpainting"): - create_sampling_ui(model_config, inpainting=True) + create_sampling_ui(model_config, inpainting=True) return ui + def create_diffusion_uncond_ui(model_config): with gr.Blocks() as ui: create_uncond_sampling_ui(model_config) - + return ui + def autoencoder_process(audio, latent_noise, n_quantizers): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - #Get the device from the model + # Get the device from the model device = next(model.parameters()).device in_sr, audio = audio @@ -539,10 +608,10 @@ def autoencoder_process(audio, latent_noise, n_quantizers): audio = audio.transpose(0, 1) audio = model.preprocess_audio_for_encoder(audio, in_sr) - # Note: If you need to do chunked encoding, to reduce VRAM, + # Note: If you need to do chunked encoding, to reduce VRAM, # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128 # To turn it off, do chunked=False - # Optimal overlap and chunk_size values will depend on the model. + # Optimal overlap and chunk_size values will depend on the model. 
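The chunked-encoding note above corresponds to calls along these lines (illustrative sketch: audio is the preprocessed input tensor from this function, and the overlap/chunk_size values are the ones suggested in that comment and need tuning per model):

# Chunked autoencoder round-trip to reduce VRAM use, per the comment above
latents = model.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
reconstructed = model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)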
# See encode_audio & decode_audio in autoencoders.py for more info # Get dtype of model dtype = next(model.parameters()).dtype @@ -567,9 +636,14 @@ def autoencoder_process(audio, latent_noise, n_quantizers): return "output.wav" + def create_autoencoder_ui(model_config): - is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"] + is_dac_rvq = ( + "model" in model_config + and "bottleneck" in model_config["model"] + and model_config["model"]["bottleneck"]["type"] in ["dac_rvq", "dac_rvq_vae"] + ) if is_dac_rvq: n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] @@ -579,34 +653,47 @@ def create_autoencoder_ui(model_config): with gr.Blocks() as ui: input_audio = gr.Audio(label="Input audio") output_audio = gr.Audio(label="Output audio", interactive=False) - n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq) + n_quantizers_slider = gr.Slider( + minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq + ) latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise") - process_button = gr.Button("Process", variant='primary', scale=1) - process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process") + process_button = gr.Button("Process", variant="primary", scale=1) + process_button.click( + fn=autoencoder_process, + inputs=[input_audio, latent_noise_slider, n_quantizers_slider], + outputs=output_audio, + api_name="process", + ) return ui + def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - #Get the device from the model + # Get the device from the model device = next(model.parameters()).device in_sr, audio = audio audio = torch.from_numpy(audio).float().div(32767).to(device) - + if audio.dim() == 1: - audio = audio.unsqueeze(0) # [1, n] + audio = audio.unsqueeze(0) # [1, n] elif audio.dim() == 2: - audio = audio.transpose(0, 1) # [n, 2] -> [2, n] + audio = audio.transpose(0, 1) # [n, 2] -> [2, n] audio = audio.unsqueeze(0) - audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}) + audio = model.stereoize( + audio, + in_sr, + steps, + sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}, + ) audio = rearrange(audio, "b d n -> d (b n)") @@ -616,6 +703,7 @@ def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): return "output.wav" + def create_diffusion_prior_ui(model_config): with gr.Blocks() as ui: input_audio = gr.Audio(label="Input audio") @@ -623,14 +711,24 @@ def create_diffusion_prior_ui(model_config): # Sampler params with gr.Row(): steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") - sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sampler_type_dropdown = gr.Dropdown( + ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], + label="Sampler type", + value="dpmpp-3m-sde", + ) sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, 
label="Sigma min") sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") - process_button = gr.Button("Process", variant='primary', scale=1) - process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process") + process_button = gr.Button("Process", variant="primary", scale=1) + process_button.click( + fn=diffusion_prior_process, + inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], + outputs=output_audio, + api_name="process", + ) return ui + def create_lm_ui(model_config): with gr.Blocks() as ui: output_audio = gr.Audio(label="Output audio", interactive=False) @@ -642,23 +740,24 @@ def create_lm_ui(model_config): top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p") top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k") - generate_button = gr.Button("Generate", variant='primary', scale=1) + generate_button = gr.Button("Generate", variant="primary", scale=1) generate_button.click( - fn=generate_lm, - inputs=[ - temperature_slider, - top_p_slider, - top_k_slider - ], + fn=generate_lm, + inputs=[temperature_slider, top_p_slider, top_k_slider], outputs=[output_audio, audio_spectrogram_output], - api_name="generate" + api_name="generate", ) return ui -def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): - assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" +def create_ui( + model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False +): + + assert (pretrained_name is not None) ^ ( + model_config_path is not None and ckpt_path is not None + ), "Must specify either pretrained name or provide a model config and checkpoint, but not both" if model_config_path is not None: # Load config from json file @@ -682,8 +781,15 @@ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pret print("Using device:", device) - _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) - + _, model_config = load_model( + model_config, + ckpt_path, + pretrained_name=pretrained_name, + pretransform_ckpt_path=pretransform_ckpt_path, + model_half=model_half, + device=device, + ) + model_type = model_config["model_type"] if model_type == "diffusion_cond": @@ -696,5 +802,5 @@ def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pret ui = create_diffusion_prior_ui(model_config) elif model_type == "lm": ui = create_lm_ui(model_config) - + return ui diff --git a/stable_audio_tools/models/__init__.py b/stable_audio_tools/models/__init__.py index 7e27bbcb..ebcd45ba 100644 --- a/stable_audio_tools/models/__init__.py +++ b/stable_audio_tools/models/__init__.py @@ -1 +1 @@ -from .factory import create_model_from_config, create_model_from_config_path \ No newline at end of file +from .factory import create_model_from_config, create_model_from_config_path diff --git a/stable_audio_tools/models/adp.py b/stable_audio_tools/models/adp.py index 49eb526a..ca568fdf 100644 --- a/stable_audio_tools/models/adp.py +++ b/stable_audio_tools/models/adp.py @@ -3,19 +3,19 @@ 
import math from inspect import isfunction -from math import ceil, floor, log, pi, log2 +from math import ceil, floor, log, log2, pi from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union -from packaging import version import torch import torch.nn as nn +from dac.nn.layers import Snake1d from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange from einops_exts import rearrange_many +from packaging import version from torch import Tensor, einsum from torch.backends.cuda import sdp_kernel from torch.nn import functional as F -from dac.nn.layers import Snake1d """ Utils @@ -32,22 +32,27 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None): x = module(x, mapping) return x + T = TypeVar("T") + def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: if exists(val): return val return d() if isfunction(d) else d + def exists(val: Optional[T]) -> T: return val is not None + def closest_power_2(x: float) -> int: exponent = log2(x) - distance_fn = lambda z: abs(x - 2 ** z) # noqa + distance_fn = lambda z: abs(x - 2**z) # noqa exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) return 2 ** int(exponent_closest) + def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: return_dicts: Tuple[Dict, Dict] = ({}, {}) for key in d.keys(): @@ -55,6 +60,7 @@ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: return_dicts[no_prefix][key] = d[key] return return_dicts + def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) if keep_prefix: @@ -62,6 +68,7 @@ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} return kwargs_no_prefix, kwargs + """ Convolutional Blocks """ @@ -70,8 +77,8 @@ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License # License available in LICENSES/LICENSE_META.txt -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) -> int: """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 @@ -94,14 +101,14 @@ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total return F.pad(x, (0, extra_padding)) -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0): """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. 
""" length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': + if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: @@ -120,13 +127,13 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right - return x[..., padding_left: end] + return x[..., padding_left:end] class Conv1d(nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def forward(self, x: Tensor, causal=False) -> Tensor: kernel_size = self.kernel_size[0] stride = self.stride[0] @@ -143,7 +150,8 @@ def forward(self, x: Tensor, causal=False) -> Tensor: padding_left = padding_total - padding_right x = pad1d(x, (padding_left, padding_right + extra_padding)) return super().forward(x) - + + class ConvTranspose1d(nn.ConvTranspose1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -169,45 +177,29 @@ def forward(self, x: Tensor, causal=False) -> Tensor: padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) return y - -def Downsample1d( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: + +def Downsample1d(in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2) -> nn.Module: assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" return Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * kernel_multiplier + 1, - stride=factor + in_channels=in_channels, out_channels=out_channels, kernel_size=factor * kernel_multiplier + 1, stride=factor ) -def Upsample1d( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: +def Upsample1d(in_channels: int, out_channels: int, factor: int, use_nearest: bool = False) -> nn.Module: if factor == 1: - return Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3 - ) + return Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3) if use_nearest: return nn.Sequential( nn.Upsample(scale_factor=factor, mode="nearest"), - Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3 - ), + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3), ) else: return ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * 2, - stride=factor + in_channels=in_channels, out_channels=out_channels, kernel_size=factor * 2, stride=factor ) @@ -222,20 +214,16 @@ def __init__( dilation: int = 1, num_groups: int = 8, use_norm: bool = True, - use_snake: bool = False + use_snake: bool = False, ) -> None: super().__init__() - self.groupnorm = ( - nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) - if use_norm - else nn.Identity() - ) + self.groupnorm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) if use_norm else nn.Identity() if use_snake: self.activation = Snake1d(in_channels) else: - self.activation = nn.SiLU() + self.activation = nn.SiLU() self.project = Conv1d( in_channels=in_channels, @@ -245,9 +233,7 @@ def __init__( dilation=dilation, ) - def forward( - self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False - ) -> Tensor: + def forward(self, x: Tensor, scale_shift: Optional[Tuple[Tensor, 
Tensor]] = None, causal=False) -> Tensor: x = self.groupnorm(x) if exists(scale_shift): scale, shift = scale_shift @@ -302,21 +288,19 @@ def __init__( dilation=dilation, use_norm=use_norm, num_groups=num_groups, - use_snake=use_snake + use_snake=use_snake, ) if self.use_mapping: assert exists(context_mapping_features) - self.to_scale_shift = MappingToScaleShift( - features=context_mapping_features, channels=out_channels - ) + self.to_scale_shift = MappingToScaleShift(features=context_mapping_features, channels=out_channels) self.block2 = ConvBlock1d( in_channels=out_channels, out_channels=out_channels, use_norm=use_norm, num_groups=num_groups, - use_snake=use_snake + use_snake=use_snake, ) self.to_out = ( @@ -359,7 +343,7 @@ def __init__( out_channels=out_channels // patch_size, num_groups=1, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: @@ -375,7 +359,7 @@ def __init__( out_channels: int, patch_size: int, context_mapping_features: Optional[int] = None, - use_snake: bool = False + use_snake: bool = False, ): super().__init__() assert_message = f"in_channels must be divisible by patch_size ({patch_size})" @@ -387,7 +371,7 @@ def __init__( out_channels=out_channels, num_groups=1, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: @@ -399,6 +383,8 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> """ Attention Components """ + + def FeedForward(features: int, multiplier: int) -> nn.Module: mid_features = features * multiplier return nn.Sequential( @@ -407,6 +393,7 @@ def FeedForward(features: int, multiplier: int) -> nn.Module: nn.Linear(in_features=mid_features, out_features=features), ) + def add_mask(sim: Tensor, mask: Tensor) -> Tensor: b, ndim = sim.shape[0], mask.ndim if ndim == 3: @@ -417,12 +404,14 @@ def add_mask(sim: Tensor, mask: Tensor) -> Tensor: sim = sim.masked_fill(~mask, max_neg_value) return sim + def causal_mask(q: Tensor, k: Tensor) -> Tensor: b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) mask = repeat(mask, "n m -> b n m", b=b) return mask + class AttentionBase(nn.Module): def __init__( self, @@ -438,16 +427,14 @@ def __init__( mid_features = head_features * num_heads out_features = default(out_features, features) - self.to_out = nn.Linear( - in_features=mid_features, out_features=out_features - ) + self.to_out = nn.Linear(in_features=mid_features, out_features=out_features) - self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse("2.0.0") if not self.use_flash: return - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) if device_properties.major == 8 and device_properties.minor == 0: # Use flash attention for A100 GPUs @@ -483,6 +470,7 @@ def forward( out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) + class Attention(nn.Module): def __init__( self, @@ -502,12 +490,8 @@ def __init__( self.norm = nn.LayerNorm(features) self.norm_context = nn.LayerNorm(context_features) - self.to_q = nn.Linear( - 
in_features=features, out_features=mid_features, bias=False - ) - self.to_kv = nn.Linear( - in_features=context_features, out_features=mid_features * 2, bias=False - ) + self.to_q = nn.Linear(in_features=features, out_features=mid_features, bias=False) + self.to_kv = nn.Linear(in_features=context_features, out_features=mid_features * 2, bias=False) self.attention = AttentionBase( features, num_heads=num_heads, @@ -517,8 +501,8 @@ def __init__( def forward( self, - x: Tensor, # [b, n, c] - context: Optional[Tensor] = None, # [b, m, d] + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] context_mask: Optional[Tensor] = None, # [b, m], false is masked, causal: Optional[bool] = False, ) -> Tensor: @@ -548,6 +532,7 @@ def FeedForward(features: int, multiplier: int) -> nn.Module: nn.Linear(in_features=mid_features, out_features=features), ) + """ Transformer Blocks """ @@ -566,23 +551,23 @@ def __init__( self.use_cross_attention = exists(context_features) and context_features > 0 - self.attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features - ) + self.attention = Attention(features=features, num_heads=num_heads, head_features=head_features) if self.use_cross_attention: self.cross_attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features, - context_features=context_features + features=features, num_heads=num_heads, head_features=head_features, context_features=context_features ) self.feed_forward = FeedForward(features=features, multiplier=multiplier) - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + def forward( + self, + x: Tensor, + *, + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: x = self.attention(x, causal=causal) + x if self.use_cross_attention: x = self.cross_attention(x, context=context, context_mask=context_mask) + x @@ -639,7 +624,9 @@ def __init__( ), ) - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + def forward( + self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False + ) -> Tensor: x = self.to_in(x) for block in self.blocks: x = block(x, context=context, context_mask=context_mask, causal=causal) @@ -739,17 +726,14 @@ def __init__( out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) for i in range(num_layers) ] ) if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) + assert (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads @@ -763,7 +747,7 @@ def __init__( num_heads=attention_heads, head_features=attention_features, multiplier=attention_multiplier, - context_features=context_embedding_features + context_features=context_embedding_features, ) if self.use_extract: @@ -772,7 +756,7 @@ def __init__( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, - use_snake=use_snake + use_snake=use_snake, ) def forward( @@ -783,7 +767,7 @@ def forward( channels: Optional[Tensor] = None, embedding: Optional[Tensor] = None, 
embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False + causal: Optional[bool] = False, ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: if self.use_pre_downsample: @@ -840,7 +824,7 @@ def __init__( self.use_pre_upsample = use_pre_upsample self.use_transformer = num_transformer_blocks > 0 self.use_skip = use_skip - self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + self.skip_scale = 2**-0.5 if use_skip_scale else 1.0 channels = out_channels if use_pre_upsample else in_channels @@ -851,17 +835,14 @@ def __init__( out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) for _ in range(num_layers) ] ) if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) + assert (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads @@ -891,7 +872,7 @@ def __init__( in_channels=out_channels, out_channels=extract_channels, num_groups=num_extract_groups, - use_snake=use_snake + use_snake=use_snake, ) def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: @@ -905,7 +886,7 @@ def forward( mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False + causal: Optional[bool] = False, ) -> Union[Tuple[Tensor, Tensor], Tensor]: if self.use_pre_upsample: @@ -950,14 +931,11 @@ def __init__( out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) + assert (exists(attention_heads) or exists(attention_features)) and exists(attention_multiplier) if attention_features is None and attention_heads is not None: attention_features = channels // attention_heads @@ -979,7 +957,7 @@ def __init__( out_channels=channels, num_groups=num_groups, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) def forward( @@ -989,7 +967,7 @@ def forward( mapping: Optional[Tensor] = None, embedding: Optional[Tensor] = None, embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False + causal: Optional[bool] = False, ) -> Tensor: x = self.pre_block(x, mapping=mapping, causal=causal) if self.use_transformer: @@ -1056,11 +1034,7 @@ def __init__( self.has_context = has_context self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] - assert ( - len(factors) == num_layers - and len(attentions) >= num_layers - and len(num_blocks) == num_layers - ) + assert len(factors) == num_layers and len(attentions) >= num_layers and len(num_blocks) == num_layers if use_context_time or use_context_features: context_mapping_features = channels * context_features_multiplier @@ -1075,18 +1049,14 @@ def __init__( if use_context_time: assert exists(context_mapping_features) self.to_time = nn.Sequential( - TimePositionalEmbedding( - dim=channels, out_features=context_mapping_features - ), + TimePositionalEmbedding(dim=channels, out_features=context_mapping_features), nn.GELU(), ) if use_context_features: assert exists(context_features) and exists(context_mapping_features) self.to_features = nn.Sequential( - nn.Linear( - in_features=context_features, 
out_features=context_mapping_features - ), + nn.Linear(in_features=context_features, out_features=context_mapping_features), nn.GELU(), ) @@ -1107,7 +1077,7 @@ def __init__( out_channels=channels * multipliers[0], patch_size=patch_size, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) self.downsamples = nn.ModuleList( @@ -1170,12 +1140,10 @@ def __init__( out_channels=out_channels, patch_size=patch_size, context_mapping_features=context_mapping_features, - use_snake=use_snake + use_snake=use_snake, ) - def get_channels( - self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 - ) -> Optional[Tensor]: + def get_channels(self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0) -> Optional[Tensor]: """Gets context channels at `layer` and checks that shape is correct""" use_context_channels = self.use_context_channels and self.has_context[layer] if not use_context_channels: @@ -1195,9 +1163,7 @@ def get_channels( channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa return channels - def get_mapping( - self, time: Optional[Tensor] = None, features: Optional[Tensor] = None - ) -> Optional[Tensor]: + def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]: """Combines context time features and features into mapping""" items, mapping = [], None # Compute time features @@ -1248,7 +1214,9 @@ def forward( for i, upsample in enumerate(self.upsamples): skips = skips_list.pop() - x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + x = upsample( + x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) x += skips_list.pop() x = self.to_out(x, mapping, causal=causal) @@ -1286,7 +1254,6 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: class UNetCFG1d(UNet1d): - """UNet1d with Classifier-Free Guidance""" def __init__( @@ -1296,22 +1263,18 @@ def __init__( use_xattn_time: bool = False, **kwargs, ): - super().__init__( - context_embedding_features=context_embedding_features, **kwargs - ) + super().__init__(context_embedding_features=context_embedding_features, **kwargs) self.use_xattn_time = use_xattn_time if use_xattn_time: assert exists(context_embedding_features) self.to_time_embedding = nn.Sequential( - TimePositionalEmbedding( - dim=kwargs["channels"], out_features=context_embedding_features - ), + TimePositionalEmbedding(dim=kwargs["channels"], out_features=context_embedding_features), nn.GELU(), ) - context_embedding_max_length += 1 # Add one for time embedding + context_embedding_max_length += 1 # Add one for time embedding self.fixed_embedding = FixedEmbedding( max_length=context_embedding_max_length, features=context_embedding_features @@ -1345,9 +1308,7 @@ def forward( # type: ignore if embedding_mask_proba > 0.0: # Randomly mask embedding - batch_mask = rand_bool( - shape=(b, 1, 1), proba=embedding_mask_proba, device=device - ) + batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device) embedding = torch.where(batch_mask, fixed_embedding, embedding) if embedding_scale != 1.0: @@ -1360,7 +1321,7 @@ def forward( # type: ignore negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) - + batch_embed = torch.cat([embedding, negative_embedding], dim=0) 
else: @@ -1383,13 +1344,23 @@ def forward( # type: ignore batch_channels += [torch.cat([channels, channels], dim=0)] # Compute both normal and fixed embedding outputs - batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + batch_out = super().forward( + batch_x, + batch_time, + embedding=batch_embed, + embedding_mask=batch_mask, + features=batch_features, + channels_list=batch_channels, + **kwargs, + ) out, out_masked = batch_out.chunk(2, dim=0) - + else: # Compute both normal and fixed embedding outputs out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) - out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward( + x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs + ) out_cfg = out_masked + (out - out_masked) * embedding_scale @@ -1398,18 +1369,17 @@ def forward( # type: ignore out_std = out.std(dim=1, keepdim=True) out_cfg_std = out_cfg.std(dim=1, keepdim=True) - return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + return scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg else: return out_cfg - + else: return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) class UNetNCCA1d(UNet1d): - """UNet1d with Noise Channel Conditioning Augmentation""" def __init__(self, context_features: int, **kwargs): @@ -1426,12 +1396,8 @@ def forward( # type: ignore time: Tensor, *, channels_list: Sequence[Tensor], - channels_augmentation: Union[ - bool, Sequence[bool], Sequence[Sequence[bool]], Tensor - ] = False, - channels_scale: Union[ - float, Sequence[float], Sequence[Sequence[float]], Tensor - ] = 0, + channels_augmentation: Union[bool, Sequence[bool], Sequence[Sequence[bool]], Tensor] = False, + channels_scale: Union[float, Sequence[float], Sequence[Sequence[float]], Tensor] = 0, **kwargs, ) -> Tensor: b, n = x.shape[0], len(channels_list) @@ -1478,6 +1444,7 @@ def XUNet1d(type: str = "base", **kwargs) -> UNet1d: else: raise ValueError(f"Unknown XUNet1d type: {type}") + class NumberEmbedder(nn.Module): def __init__( self, @@ -1574,9 +1541,7 @@ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: return rearrange(wave, "(b c) t -> b c t", b=b) - def encode1d( - self, wave: Tensor, stacked: bool = True - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def encode1d(self, wave: Tensor, stacked: bool = True) -> Union[Tensor, Tuple[Tensor, Tensor]]: stft_a, stft_b = self.encode(wave) stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) diff --git a/stable_audio_tools/models/autoencoders.py b/stable_audio_tools/models/autoencoders.py index 7c4bdbdc..fd5e96f9 100644 --- a/stable_audio_tools/models/autoencoders.py +++ b/stable_audio_tools/models/autoencoders.py @@ -1,26 +1,33 @@ -import torch import math -import numpy as np +from typing import Any, Dict, Literal +import numpy as np +import torch +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d from torch import nn from torch.nn import functional as F from torchaudio import transforms as T -from alias_free_torch import Activation1d -from dac.nn.layers import WNConv1d, WNConvTranspose1d -from typing import Literal, Dict, Any from ..inference.sampling import sample from 
..inference.utils import prepare_audio from .blocks import SnakeBeta from .bottleneck import Bottleneck, DiscreteBottleneck -from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper -from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .diffusion import ( + ConditionedDiffusionModel, + DAU1DCondWrapper, + DiTWrapper, + UNet1DCondWrapper, +) +from .factory import create_bottleneck_from_config, create_pretransform_from_config from .pretransforms import Pretransform + def checkpoint(function, *args, **kwargs): kwargs.setdefault("use_reentrant", False) return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: if activation == "elu": act = nn.ELU() @@ -30,115 +37,134 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, act = nn.Identity() else: raise ValueError(f"Unknown activation {activation}") - + if antialias: act = Activation1d(act) - + return act + class ResidualUnit(nn.Module): def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): super().__init__() - + self.dilation = dilation - padding = (dilation * (7-1)) // 2 + padding = (dilation * (7 - 1)) // 2 self.layers = nn.Sequential( get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), - WNConv1d(in_channels=in_channels, out_channels=out_channels, - kernel_size=7, dilation=dilation, padding=padding), + WNConv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=7, dilation=dilation, padding=padding + ), get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), - WNConv1d(in_channels=out_channels, out_channels=out_channels, - kernel_size=1) + WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1), ) def forward(self, x): res = x - - #x = checkpoint(self.layers, x) + + # x = checkpoint(self.layers, x) x = self.layers(x) return x + res + class EncoderBlock(nn.Module): def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): super().__init__() self.layers = nn.Sequential( - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=1, use_snake=use_snake), - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=3, use_snake=use_snake), - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=9, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=9, use_snake=use_snake), get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), - WNConv1d(in_channels=in_channels, out_channels=out_channels, - kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), ) def forward(self, x): return self.layers(x) + class DecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + def __init__( + self, in_channels, out_channels, 
stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False + ): super().__init__() if use_nearest_upsample: upsample_layer = nn.Sequential( nn.Upsample(scale_factor=stride, mode="nearest"), - WNConv1d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=2*stride, - stride=1, - bias=False, - padding='same') + WNConv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=1, + bias=False, + padding="same", + ), ) else: - upsample_layer = WNConvTranspose1d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + upsample_layer = WNConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ) self.layers = nn.Sequential( get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), upsample_layer, - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=1, use_snake=use_snake), - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=3, use_snake=use_snake), - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=9, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=9, use_snake=use_snake), ) def forward(self, x): return self.layers(x) + class OobleckEncoder(nn.Module): - def __init__(self, - in_channels=2, - channels=128, - latent_dim=32, - c_mults = [1, 2, 4, 8], - strides = [2, 4, 8, 8], - use_snake=False, - antialias_activation=False - ): + def __init__( + self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults=[1, 2, 4, 8], + strides=[2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + ): super().__init__() - + c_mults = [1] + c_mults self.depth = len(c_mults) - layers = [ - WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) - ] - - for i in range(self.depth-1): - layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + layers = [WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)] + + for i in range(self.depth - 1): + layers += [ + EncoderBlock( + in_channels=c_mults[i] * channels, + out_channels=c_mults[i + 1] * channels, + stride=strides[i], + use_snake=use_snake, + ) + ] layers += [ - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), - WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + get_activation( + "snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels + ), + WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1), ] self.layers = nn.Sequential(*layers) @@ -148,41 +174,48 @@ def forward(self, x): class OobleckDecoder(nn.Module): - def __init__(self, - out_channels=2, - channels=128, - latent_dim=32, - c_mults = [1, 2, 4, 8], - strides = [2, 4, 8, 8], - use_snake=False, - antialias_activation=False, - use_nearest_upsample=False, - final_tanh=True): + def __init__( + self, + out_channels=2, + channels=128, + latent_dim=32, + 
c_mults=[1, 2, 4, 8], + strides=[2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True, + ): super().__init__() c_mults = [1] + c_mults - + self.depth = len(c_mults) layers = [ - WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3), ] - - for i in range(self.depth-1, 0, -1): - layers += [DecoderBlock( - in_channels=c_mults[i]*channels, - out_channels=c_mults[i-1]*channels, - stride=strides[i-1], - use_snake=use_snake, - antialias_activation=antialias_activation, - use_nearest_upsample=use_nearest_upsample + + for i in range(self.depth - 1, 0, -1): + layers += [ + DecoderBlock( + in_channels=c_mults[i] * channels, + out_channels=c_mults[i - 1] * channels, + stride=strides[i - 1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample, ) ] layers += [ - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), - WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), - nn.Tanh() if final_tanh else nn.Identity() + get_activation( + "snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels + ), + WNConv1d( + in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False + ), + nn.Tanh() if final_tanh else nn.Identity(), ] self.layers = nn.Sequential(*layers) @@ -204,7 +237,9 @@ def __init__(self, in_channels=1, **kwargs): self.latent_dim = latent_dim # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility - self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + self.proj_out = ( + nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + ) if in_channels != 1: self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) @@ -214,19 +249,21 @@ def forward(self, x): x = self.proj_out(x) return x + class DACDecoderWrapper(nn.Module): def __init__(self, latent_dim, out_channels=1, **kwargs): super().__init__() from dac.model.dac import Decoder as DACDecoder - self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels) self.latent_dim = latent_dim def forward(self, x): return self.decoder(x) + class AudioAutoencoder(nn.Module): def __init__( self, @@ -238,9 +275,9 @@ def __init__( io_channels=2, bottleneck: Bottleneck = None, pretransform: Pretransform = None, - in_channels = None, - out_channels = None, - soft_clip = False + in_channels=None, + out_channels=None, + soft_clip=False, ): super().__init__() @@ -269,7 +306,7 @@ def __init__( self.pretransform = pretransform self.soft_clip = soft_clip - + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): @@ -281,7 +318,7 @@ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batc if iterate_batch: audios = [] for i in range(audio.shape[0]): - audios.append(self.pretransform.encode(audio[i:i+1])) + 
audios.append(self.pretransform.encode(audio[i : i + 1])) audio = torch.cat(audios, dim=0) else: audio = self.pretransform.encode(audio) @@ -290,7 +327,7 @@ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batc if iterate_batch: audios = [] for i in range(audio.shape[0]): - audios.append(self.pretransform.encode(audio[i:i+1])) + audios.append(self.pretransform.encode(audio[i : i + 1])) audio = torch.cat(audios, dim=0) else: audio = self.pretransform.encode(audio) @@ -299,7 +336,7 @@ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batc if iterate_batch: latents = [] for i in range(audio.shape[0]): - latents.append(self.encoder(audio[i:i+1])) + latents.append(self.encoder(audio[i : i + 1])) latents = torch.cat(latents, dim=0) else: latents = self.encoder(audio) @@ -311,7 +348,7 @@ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batc latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) info.update(bottleneck_info) - + if return_info: return latents, info @@ -323,7 +360,7 @@ def decode(self, latents, iterate_batch=False, **kwargs): if iterate_batch: decoded = [] for i in range(latents.shape[0]): - decoded.append(self.bottleneck.decode(latents[i:i+1])) + decoded.append(self.bottleneck.decode(latents[i : i + 1])) latents = torch.cat(decoded, dim=0) else: latents = self.bottleneck.decode(latents) @@ -331,7 +368,7 @@ def decode(self, latents, iterate_batch=False, **kwargs): if iterate_batch: decoded = [] for i in range(latents.shape[0]): - decoded.append(self.decoder(latents[i:i+1])) + decoded.append(self.decoder(latents[i : i + 1])) decoded = torch.cat(decoded, dim=0) else: decoded = self.decoder(latents, **kwargs) @@ -341,7 +378,7 @@ def decode(self, latents, iterate_batch=False, **kwargs): if iterate_batch: decodeds = [] for i in range(decoded.shape[0]): - decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decodeds.append(self.pretransform.decode(decoded[i : i + 1])) decoded = torch.cat(decodeds, dim=0) else: decoded = self.pretransform.decode(decoded) @@ -350,52 +387,51 @@ def decode(self, latents, iterate_batch=False, **kwargs): if iterate_batch: decodeds = [] for i in range(latents.shape[0]): - decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decodeds.append(self.pretransform.decode(decoded[i : i + 1])) decoded = torch.cat(decodeds, dim=0) else: decoded = self.pretransform.decode(decoded) if self.soft_clip: decoded = torch.tanh(decoded) - + return decoded - + def decode_tokens(self, tokens, **kwargs): - ''' + """ Decode discrete tokens to audio Only works with discrete autoencoders - ''' + """ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" latents = self.bottleneck.decode_tokens(tokens, **kwargs) return self.decode(latents, **kwargs) - - + def preprocess_audio_for_encoder(self, audio, in_sr): - ''' + """ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. If the model is mono, stereo audio will be converted to mono. Audio will be silence-padded to be a multiple of the model's downsampling ratio. - Audio will be resampled to the model's sample rate. + Audio will be resampled to the model's sample rate. 
The output will have batch size 1 and be shape (1 x Channels x Length) - ''' + """ return self.preprocess_audio_list_for_encoder([audio], [in_sr]) def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): - ''' - Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. - The audio in that list can be of different lengths and channels. + """ + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. - All audio will be resampled to the model's sample rate. - Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. - If the model is mono, all audio will be converted to mono. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. The output will be a tensor of shape (Batch x Channels x Length) - ''' + """ batch_size = len(audio_list) if isinstance(in_sr_list, int): - in_sr_list = [in_sr_list]*batch_size + in_sr_list = [in_sr_list] * batch_size assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" new_audio = [] max_length = 0 @@ -409,7 +445,7 @@ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): elif len(audio.shape) == 1: # Mono signal, channel dimension is missing, unsqueeze it in audio = audio.unsqueeze(0) - assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension" # Resample audio if in_sr != self.sample_rate: resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) @@ -421,25 +457,31 @@ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length for i in range(batch_size): # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model - new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, - target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) - # convert to tensor - return torch.stack(new_audio) + new_audio[i] = prepare_audio( + new_audio[i], + in_sr=in_sr, + target_sr=in_sr, + target_length=padded_audio_length, + target_channels=self.in_channels, + device=new_audio[i].device, + ).squeeze(0) + # convert to tensor + return torch.stack(new_audio) def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): - ''' + """ Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. - Overlap and chunk_size params are both measured in number of latents (not audio samples) - # and therefore you likely could use the same values with decode_audio. - A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. 
+ Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. Every autoencoder will have a different receptive field size, and thus ideal overlap. You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. Smaller chunk_size uses less memory, but more compute. The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks - ''' + """ if not chunked: # default behavior. Encode the entire audio in parallel return self.encode(audio, **kwargs) @@ -447,18 +489,18 @@ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwarg # CHUNKED ENCODING # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) samples_per_latent = self.downsampling_ratio - total_size = audio.shape[2] # in samples + total_size = audio.shape[2] # in samples batch_size = audio.shape[0] - chunk_size *= samples_per_latent # converting metric in latents to samples - overlap *= samples_per_latent # converting metric in latents to samples + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples hop_size = chunk_size - overlap chunks = [] for i in range(0, total_size - chunk_size + 1, hop_size): - chunk = audio[:,:,i:i+chunk_size] + chunk = audio[:, :, i : i + chunk_size] chunks.append(chunk) - if i+chunk_size != total_size: + if i + chunk_size != total_size: # Final chunk - chunk = audio[:,:,-chunk_size:] + chunk = audio[:, :, -chunk_size:] chunks.append(chunk) chunks = torch.stack(chunks) num_chunks = chunks.shape[0] @@ -467,13 +509,13 @@ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwarg # However, the audio should've been padded to a multiple of samples_per_latent by now. 
y_size = total_size // samples_per_latent # Create an empty latent, we will populate it with chunks as we encode them - y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(audio.device) for i in range(num_chunks): - x_chunk = chunks[i,:] + x_chunk = chunks[i, :] # encode the chunk y_chunk = self.encode(x_chunk) # figure out where to put the audio along the time domain - if i == num_chunks-1: + if i == num_chunks - 1: # final chunk always goes at the end t_end = y_size t_start = t_end - y_chunk.shape[2] @@ -481,33 +523,33 @@ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwarg t_start = i * hop_size // samples_per_latent t_end = t_start + chunk_size // samples_per_latent # remove the edges of the overlaps - ol = overlap//samples_per_latent//2 + ol = overlap // samples_per_latent // 2 chunk_start = 0 chunk_end = y_chunk.shape[2] if i > 0: # no overlap for the start of the first chunk t_start += ol chunk_start += ol - if i < num_chunks-1: + if i < num_chunks - 1: # no overlap for the end of the last chunk t_end -= ol chunk_end -= ol # paste the chunked audio into our y_final output audio - y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end] return y_final - + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): - ''' - Decode latents to audio. - If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. - A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + """ + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. Every autoencoder will have a different receptive field size, and thus ideal overlap. You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. Smaller chunk_size uses less memory, but more compute. The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks - ''' + """ if not chunked: # default behavior. 
Decode the entire latent in parallel return self.decode(latents, **kwargs) @@ -518,11 +560,11 @@ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwa batch_size = latents.shape[0] chunks = [] for i in range(0, total_size - chunk_size + 1, hop_size): - chunk = latents[:,:,i:i+chunk_size] + chunk = latents[:, :, i : i + chunk_size] chunks.append(chunk) - if i+chunk_size != total_size: + if i + chunk_size != total_size: # Final chunk - chunk = latents[:,:,-chunk_size:] + chunk = latents[:, :, -chunk_size:] chunks.append(chunk) chunks = torch.stack(chunks) num_chunks = chunks.shape[0] @@ -530,13 +572,13 @@ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwa samples_per_latent = self.downsampling_ratio # Create an empty waveform, we will populate it with chunks as decode them y_size = total_size * samples_per_latent - y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + y_final = torch.zeros((batch_size, self.out_channels, y_size)).to(latents.device) for i in range(num_chunks): - x_chunk = chunks[i,:] + x_chunk = chunks[i, :] # decode the chunk y_chunk = self.decode(x_chunk) # figure out where to put the audio along the time domain - if i == num_chunks-1: + if i == num_chunks - 1: # final chunk always goes at the end t_end = y_size t_start = t_end - y_chunk.shape[2] @@ -544,30 +586,24 @@ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwa t_start = i * hop_size * samples_per_latent t_end = t_start + chunk_size * samples_per_latent # remove the edges of the overlaps - ol = (overlap//2) * samples_per_latent + ol = (overlap // 2) * samples_per_latent chunk_start = 0 chunk_end = y_chunk.shape[2] if i > 0: # no overlap for the start of the first chunk t_start += ol chunk_start += ol - if i < num_chunks-1: + if i < num_chunks - 1: # no overlap for the end of the last chunk t_end -= ol chunk_end -= ol # paste the chunked audio into our y_final output audio - y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end] return y_final - + class DiffusionAutoencoder(AudioAutoencoder): - def __init__( - self, - diffusion: ConditionedDiffusionModel, - diffusion_downsampling_ratio, - *args, - **kwargs - ): + def __init__(self, diffusion: ConditionedDiffusionModel, diffusion_downsampling_ratio, *args, **kwargs): super().__init__(*args, **kwargs) self.diffusion = diffusion @@ -589,10 +625,10 @@ def decode(self, latents, steps=100): if self.decoder is not None: latents = self.decode(latents) - + # Upsample latents to match diffusion length if latents.shape[2] != upsampled_length: - latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + latents = F.interpolate(latents, size=upsampled_length, mode="nearest") noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) @@ -605,27 +641,26 @@ def decode(self, latents, steps=100): decoded = self.pretransform.decode(decoded) return decoded - + + # AE factories + def create_encoder_from_config(encoder_config: Dict[str, Any]): encoder_type = encoder_config.get("type", None) assert encoder_type is not None, "Encoder type must be specified" if encoder_type == "oobleck": - encoder = OobleckEncoder( - **encoder_config["config"] - ) - + encoder = OobleckEncoder(**encoder_config["config"]) + elif encoder_type == "seanet": from encodec.modules import 
SEANetEncoder + seanet_encoder_config = encoder_config["config"] - #SEANet encoder expects strides in reverse order + # SEANet encoder expects strides in reverse order seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) - encoder = SEANetEncoder( - **seanet_encoder_config - ) + encoder = SEANetEncoder(**seanet_encoder_config) elif encoder_type == "dac": dac_config = encoder_config["config"] @@ -635,12 +670,10 @@ def create_encoder_from_config(encoder_config: Dict[str, Any]): local_attn_config = encoder_config["config"] - encoder = TransformerEncoder1D( - **local_attn_config - ) + encoder = TransformerEncoder1D(**local_attn_config) else: raise ValueError(f"Unknown encoder type {encoder_type}") - + requires_grad = encoder_config.get("requires_grad", True) if not requires_grad: for param in encoder.parameters(): @@ -648,20 +681,17 @@ def create_encoder_from_config(encoder_config: Dict[str, Any]): return encoder + def create_decoder_from_config(decoder_config: Dict[str, Any]): decoder_type = decoder_config.get("type", None) assert decoder_type is not None, "Decoder type must be specified" if decoder_type == "oobleck": - decoder = OobleckDecoder( - **decoder_config["config"] - ) + decoder = OobleckDecoder(**decoder_config["config"]) elif decoder_type == "seanet": from encodec.modules import SEANetDecoder - decoder = SEANetDecoder( - **decoder_config["config"] - ) + decoder = SEANetDecoder(**decoder_config["config"]) elif decoder_type == "dac": dac_config = decoder_config["config"] @@ -671,12 +701,10 @@ def create_decoder_from_config(decoder_config: Dict[str, Any]): local_attn_config = decoder_config["config"] - decoder = TransformerDecoder1D( - **local_attn_config - ) + decoder = TransformerDecoder1D(**local_attn_config) else: raise ValueError(f"Unknown decoder type {decoder_type}") - + requires_grad = decoder_config.get("requires_grad", True) if not requires_grad: for param in decoder.parameters(): @@ -684,8 +712,9 @@ def create_decoder_from_config(decoder_config: Dict[str, Any]): return decoder + def create_autoencoder_from_config(config: Dict[str, Any]): - + ae_config = config["model"] encoder = create_encoder_from_config(ae_config["encoder"]) @@ -726,11 +755,12 @@ def create_autoencoder_from_config(config: Dict[str, Any]): pretransform=pretransform, in_channels=in_channels, out_channels=out_channels, - soft_clip=soft_clip + soft_clip=soft_clip, ) + def create_diffAE_from_config(config: Dict[str, Any]): - + diffae_config = config["model"] if "encoder" in diffae_config: @@ -771,7 +801,7 @@ def create_diffAE_from_config(config: Dict[str, Any]): if bottleneck is not None: bottleneck = create_bottleneck_from_config(bottleneck) - diffusion_downsampling_ratio = None, + diffusion_downsampling_ratio = (None,) if diffusion_model_type == "DAU1d": diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) @@ -790,5 +820,5 @@ def create_diffAE_from_config(config: Dict[str, Any]): downsampling_ratio=downsampling_ratio, diffusion_downsampling_ratio=diffusion_downsampling_ratio, bottleneck=bottleneck, - pretransform=pretransform + pretransform=pretransform, ) diff --git a/stable_audio_tools/models/blocks.py b/stable_audio_tools/models/blocks.py index 3c827fd2..a11536e3 100644 --- a/stable_audio_tools/models/blocks.py +++ b/stable_audio_tools/models/blocks.py @@ -1,14 +1,14 @@ -from functools import reduce import math +from functools import reduce + import numpy as np import torch +from dac.nn.layers import Snake1d +from 
packaging import version from torch import nn -from torch.nn import functional as F - from torch.backends.cuda import sdp_kernel -from packaging import version +from torch.nn import functional as F -from dac.nn.layers import Snake1d class ResidualBlock(nn.Module): def __init__(self, main, skip=None): @@ -19,20 +19,25 @@ def __init__(self, main, skip=None): def forward(self, input): return self.main(input) + self.skip(input) + class ResConvBlock(ResidualBlock): def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) - super().__init__([ - nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), - nn.GroupNorm(1, c_mid), - Snake1d(c_mid) if use_snake else nn.GELU(), - nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), - nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), - (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), - ], skip) + super().__init__( + [ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size // 2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size // 2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], + skip, + ) + class SelfAttention1d(nn.Module): - def __init__(self, c_in, n_head=1, dropout_rate=0.): + def __init__(self, c_in, n_head=1, dropout_rate=0.0): super().__init__() assert c_in % n_head == 0 self.norm = nn.GroupNorm(1, c_in) @@ -41,12 +46,12 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.): self.out_proj = nn.Conv1d(c_in, c_in, 1) self.dropout = nn.Dropout(dropout_rate, inplace=True) - self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse("2.0.0") if not self.use_flash: return - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) if device_properties.major == 8 and device_properties.minor == 0: # Use flash attention for A100 GPUs @@ -58,10 +63,9 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.): def forward(self, input): n, c, s = input.shape qkv = self.qkv_proj(self.norm(input)) - qkv = qkv.view( - [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + qkv = qkv.view([n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) q, k, v = qkv.chunk(3, dim=1) - scale = k.shape[3]**-0.25 + scale = k.shape[3] ** -0.25 if self.use_flash: with sdp_kernel(*self.sdp_kernel_config): @@ -70,9 +74,9 @@ def forward(self, input): att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) - return input + self.dropout(self.out_proj(y)) + class SkipBlock(nn.Module): def __init__(self, *main): super().__init__() @@ -81,42 +85,51 @@ def __init__(self, *main): def forward(self, input): return torch.cat([self.main(input), input], dim=1) + class FourierFeatures(nn.Module): - def __init__(self, in_features, out_features, std=1.): + def __init__(self, in_features, out_features, std=1.0): super().__init__() assert out_features % 2 == 0 - self.weight = nn.Parameter(torch.randn( - [out_features // 2, in_features]) * std) + self.weight = 
nn.Parameter(torch.randn([out_features // 2, in_features]) * std) def forward(self, input): f = 2 * math.pi * input @ self.weight.T return torch.cat([f.cos(), f.sin()], dim=-1) + def expand_to_planes(input, shape): return input[..., None].repeat([1, 1, shape[2]]) + _kernels = { - 'linear': - [1 / 8, 3 / 8, 3 / 8, 1 / 8], - 'cubic': - [-0.01171875, -0.03515625, 0.11328125, 0.43359375, - 0.43359375, 0.11328125, -0.03515625, -0.01171875], - 'lanczos3': - [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, - -0.066637322306633, 0.13550527393817902, 0.44638532400131226, - 0.44638532400131226, 0.13550527393817902, -0.066637322306633, - -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], } + class Downsample1d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor(_kernels[kernel]) self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer('kernel', kernel_1d) + self.register_buffer("kernel", kernel_1d) self.channels_last = channels_last - + def forward(self, x): if self.channels_last: x = x.permute(0, 2, 1) @@ -131,14 +144,14 @@ def forward(self, x): class Upsample1d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False): super().__init__() self.pad_mode = pad_mode kernel_1d = torch.tensor(_kernels[kernel]) * 2 self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer('kernel', kernel_1d) + self.register_buffer("kernel", kernel_1d) self.channels_last = channels_last - + def forward(self, x): if self.channels_last: x = x.permute(0, 2, 1) @@ -150,10 +163,9 @@ def forward(self, x): if self.channels_last: x = x.permute(0, 2, 1) return x - -def Downsample1d_2( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: + + +def Downsample1d_2(in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2) -> nn.Module: assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" return nn.Conv1d( @@ -165,14 +177,10 @@ def Downsample1d_2( ) -def Upsample1d_2( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: +def Upsample1d_2(in_channels: int, out_channels: int, factor: int, use_nearest: bool = False) -> nn.Module: if factor == 1: - return nn.Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 - ) + return nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) if use_nearest: return nn.Sequential( @@ -194,38 +202,44 @@ def Upsample1d_2( output_padding=factor % 2, ) + def zero_init(layer): nn.init.zeros_(layer.weight) if layer.bias is not None: nn.init.zeros_(layer.bias) return layer + def rms_norm(x, scale, eps): dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) - mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + mean_sq 
= torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True) scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) return x * scale.to(x.dtype) -#rms_norm = torch.compile(rms_norm) + +# rms_norm = torch.compile(rms_norm) + class AdaRMSNorm(nn.Module): def __init__(self, features, cond_features, eps=1e-6): super().__init__() self.eps = eps self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) - + def extra_repr(self): return f"eps={self.eps}," def forward(self, x, cond): return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) - + + def normalize(x, eps=1e-4): dim = list(range(1, x.ndim)) n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) alpha = np.sqrt(n.numel() / x.numel()) return x / torch.add(eps, n, alpha=alpha) + class ForcedWNConv1d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1): super().__init__() @@ -235,17 +249,19 @@ def forward(self, x): if self.training: with torch.no_grad(): self.weight.copy_(normalize(self.weight)) - + fan_in = self.weight[0].numel() w = normalize(self.weight) / math.sqrt(fan_in) - return F.conv1d(x, w, padding='same') - + return F.conv1d(x, w, padding="same") + + # Kernels use_compile = True + def compile(function, *args, **kwargs): if not use_compile: return function @@ -267,12 +283,14 @@ def linear_geglu(x, weight, bias=None): @compile def rms_norm(x, scale, eps): dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) - mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True) scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) return x * scale.to(x.dtype) + # Layers + class LinearGEGLU(nn.Linear): def __init__(self, in_features, out_features, bias=True): super().__init__(in_features, out_features * 2, bias=bias) @@ -283,7 +301,7 @@ def forward(self, x): class RMSNorm(nn.Module): - def __init__(self, shape, fix_scale = False, eps=1e-6): + def __init__(self, shape, fix_scale=False, eps=1e-6): super().__init__() self.eps = eps @@ -296,16 +314,19 @@ def extra_repr(self): return f"shape={tuple(self.scale.shape)}, eps={self.eps}" def forward(self, x): - return rms_norm(x, self.scale, self.eps) + return rms_norm(x, self.scale, self.eps) + def snake_beta(x, alpha, beta): return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + # try: # snake_beta = torch.compile(snake_beta) # except RuntimeError: # pass + # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license # License available in LICENSES/LICENSE_NVIDIA.txt class SnakeBeta(nn.Module): @@ -316,10 +337,10 @@ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale= # initialize alpha self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros + if self.alpha_logscale: # log scale alphas initialized to zeros self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) self.beta = nn.Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones + else: # linear scale alphas initialized to ones self.alpha = nn.Parameter(torch.ones(in_features) * alpha) self.beta = nn.Parameter(torch.ones(in_features) * alpha) @@ -329,11 +350,11 @@ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale= self.no_div_by_zero = 0.000000001 def forward(self, x): - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] beta = 
self.beta.unsqueeze(0).unsqueeze(-1) if self.alpha_logscale: alpha = torch.exp(alpha) beta = torch.exp(beta) x = snake_beta(x, alpha, beta) - return x \ No newline at end of file + return x diff --git a/stable_audio_tools/models/bottleneck.py b/stable_audio_tools/models/bottleneck.py index 5e81cab4..2ac09cc8 100644 --- a/stable_audio_tools/models/bottleneck.py +++ b/stable_audio_tools/models/bottleneck.py @@ -1,11 +1,11 @@ import numpy as np import torch +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ +from einops import rearrange from torch import nn from torch.nn import functional as F +from vector_quantize_pytorch import FSQ, ResidualVQ -from einops import rearrange -from vector_quantize_pytorch import ResidualVQ, FSQ -from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ class Bottleneck(nn.Module): def __init__(self, is_discrete: bool = False): @@ -19,6 +19,7 @@ def encode(self, x, return_info=False, **kwargs): def decode(self, x): raise NotImplementedError + class DiscreteBottleneck(Bottleneck): def __init__(self, num_quantizers, codebook_size, tokens_id): super().__init__(is_discrete=True) @@ -29,7 +30,8 @@ def __init__(self, num_quantizers, codebook_size, tokens_id): def decode_tokens(self, codes, **kwargs): raise NotImplementedError - + + class TanhBottleneck(Bottleneck): def __init__(self): super().__init__(is_discrete=False) @@ -48,15 +50,17 @@ def encode(self, x, return_info=False): def decode(self, x): return x + def vae_sample(mean, scale): - stdev = nn.functional.softplus(scale) + 1e-4 - var = stdev * stdev - logvar = torch.log(var) - latents = torch.randn_like(mean) * stdev + mean + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean - kl = (mean * mean + var - logvar - 1).sum(1).mean() + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl - return latents, kl class VAEBottleneck(Bottleneck): def __init__(self): @@ -79,9 +83,11 @@ def encode(self, x, return_info=False, **kwargs): def decode(self, x): return x + def compute_mean_kernel(x, y): - kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] - return torch.exp(-kernel_input).mean() + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + def compute_mmd(latents): latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) @@ -90,17 +96,18 @@ def compute_mmd(latents): latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) noise_kernel = compute_mean_kernel(noise, noise) latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) - + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel return mmd.mean() + class WassersteinBottleneck(Bottleneck): def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): super().__init__(is_discrete=False) self.noise_augment_dim = noise_augment_dim self.bypass_mmd = bypass_mmd - + def encode(self, x, return_info=False): info = {} @@ -109,27 +116,27 @@ def encode(self, x, return_info=False): mmd = torch.tensor(0.0) else: mmd = compute_mmd(x) - + info["mmd"] = mmd - + if return_info: return x, info - + return x def decode(self, x): if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) + noise = torch.randn(x.shape[0], self.noise_augment_dim, x.shape[-1]).type_as(x) x = torch.cat([x, noise], dim=1) return x + class L2Bottleneck(Bottleneck): def 
__init__(self): super().__init__(is_discrete=False) - + def encode(self, x, return_info=False): info = {} @@ -139,13 +146,18 @@ def encode(self, x, return_info=False): return x, info else: return x - + def decode(self, x): return F.normalize(x, dim=1) - + + class RVQBottleneck(DiscreteBottleneck): def __init__(self, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + super().__init__( + num_quantizers=quantizer_kwargs["num_quantizers"], + codebook_size=quantizer_kwargs["codebook_size"], + tokens_id="quantizer_indices", + ) self.quantizer = ResidualVQ(**quantizer_kwargs) self.num_quantizers = quantizer_kwargs["num_quantizers"] @@ -163,18 +175,23 @@ def encode(self, x, return_info=False, **kwargs): return x, info else: return x - + def decode(self, x): return x - + def decode_tokens(self, codes, **kwargs): latents = self.quantizer.get_outputs_from_indices(codes) return self.decode(latents, **kwargs) - + + class RVQVAEBottleneck(DiscreteBottleneck): def __init__(self, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + super().__init__( + num_quantizers=quantizer_kwargs["num_quantizers"], + codebook_size=quantizer_kwargs["codebook_size"], + tokens_id="quantizer_indices", + ) self.quantizer = ResidualVQ(**quantizer_kwargs) self.num_quantizers = quantizer_kwargs["num_quantizers"] @@ -196,18 +213,23 @@ def encode(self, x, return_info=False): return x, info else: return x - + def decode(self, x): return x - + def decode_tokens(self, codes, **kwargs): latents = self.quantizer.get_outputs_from_indices(codes) return self.decode(latents, **kwargs) + class DACRVQBottleneck(DiscreteBottleneck): def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + super().__init__( + num_quantizers=quantizer_kwargs["n_codebooks"], + codebook_size=quantizer_kwargs["codebook_size"], + tokens_id="codes", + ) self.quantizer = DACResidualVQ(**quantizer_kwargs) self.num_quantizers = quantizer_kwargs["n_codebooks"] self.quantize_on_decode = quantize_on_decode @@ -238,29 +260,33 @@ def encode(self, x, return_info=False, **kwargs): if return_info: return output["z"], info - + return output["z"] - + def decode(self, x): if self.quantize_on_decode: x = self.quantizer(x)[0] if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) + noise = torch.randn(x.shape[0], self.noise_augment_dim, x.shape[-1]).type_as(x) x = torch.cat([x, noise], dim=1) return x - + def decode_tokens(self, codes, **kwargs): latents, _, _ = self.quantizer.from_codes(codes) return self.decode(latents, **kwargs) + class DACRVQVAEBottleneck(DiscreteBottleneck): def __init__(self, quantize_on_decode=False, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + super().__init__( + num_quantizers=quantizer_kwargs["n_codebooks"], + codebook_size=quantizer_kwargs["codebook_size"], + tokens_id="codes", + ) self.quantizer = DACResidualVQ(**quantizer_kwargs) self.num_quantizers = quantizer_kwargs["n_codebooks"] self.quantize_on_decode = quantize_on_decode @@ -295,9 +321,9 @@ def 
encode(self, x, return_info=False, n_quantizers: int = None): if return_info: return output["z"], info - + return output["z"] - + def decode(self, x): if self.quantize_on_decode: @@ -309,10 +335,15 @@ def decode_tokens(self, codes, **kwargs): latents, _, _ = self.quantizer.from_codes(codes) return self.decode(latents, **kwargs) - + + class FSQBottleneck(DiscreteBottleneck): def __init__(self, noise_augment_dim=0, **kwargs): - super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") + super().__init__( + num_quantizers=kwargs.get("num_codebooks", 1), + codebook_size=np.prod(kwargs["levels"]), + tokens_id="quantizer_indices", + ) self.noise_augment_dim = noise_augment_dim @@ -339,17 +370,16 @@ def encode(self, x, return_info=False): return x, info else: return x - + def decode(self, x): if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) + noise = torch.randn(x.shape[0], self.noise_augment_dim, x.shape[-1]).type_as(x) x = torch.cat([x, noise], dim=1) return x - + def decode_tokens(self, tokens, **kwargs): latents = self.quantizer.indices_to_codes(tokens) - return self.decode(latents, **kwargs) \ No newline at end of file + return self.decode(latents, **kwargs) diff --git a/stable_audio_tools/models/codebook_patterns.py b/stable_audio_tools/models/codebook_patterns.py index f9bd2a9b..13efe799 100644 --- a/stable_audio_tools/models/codebook_patterns.py +++ b/stable_audio_tools/models/codebook_patterns.py @@ -1,16 +1,16 @@ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License # License available in LICENSES/LICENSE_META.txt +import logging +import typing as tp +from abc import ABC, abstractmethod from collections import namedtuple from dataclasses import dataclass from functools import lru_cache -import logging -import typing as tp -from abc import ABC, abstractmethod import torch -LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +LayoutCoord = namedtuple("LayoutCoord", ["t", "q"]) # (timestep, codebook index) PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates logger = logging.getLogger(__name__) @@ -36,6 +36,7 @@ class Pattern: to fill and specify invalid positions if needed. See the dedicated methods for more details. """ + # Pattern layout, for each sequence step, we have a list of coordinates # corresponding to the original codebook timestep and position. 
# The first list is always an empty list in order to properly insert @@ -65,12 +66,12 @@ def _validate_layout(self): for coord in seq_coords: qs.add(coord.q) last_q_timestep = q_timesteps[coord.q] - assert coord.t >= last_q_timestep, \ - f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + assert ( + coord.t >= last_q_timestep + ), f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" q_timesteps[coord.q] = coord.t # each sequence step contains at max 1 coordinate per codebook - assert len(qs) == len(seq_coords), \ - f"Multiple entries for a same codebook are found at step {s}" + assert len(qs) == len(seq_coords), f"Multiple entries for a same codebook are found at step {s}" @property def num_sequence_steps(self): @@ -114,8 +115,9 @@ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> t steps_with_timesteps = self.get_steps_with_timestep(t, q) return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None - def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, - device: tp.Union[torch.device, str] = 'cpu'): + def _build_pattern_sequence_scatter_indexes( + self, timesteps: int, n_q: int, keep_only_valid_steps: bool, device: tp.Union[torch.device, str] = "cpu" + ): """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. Args: @@ -175,10 +177,14 @@ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_ values = values.view(B, K, indexes.shape[-1]) return values, indexes, mask - def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, - keep_only_valid_steps: bool = False, - is_model_output: bool = False, - device: tp.Union[torch.device, str] = 'cpu'): + def _build_reverted_sequence_scatter_indexes( + self, + sequence_steps: int, + n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = "cpu", + ): """Builds scatter indexes required to retrieve the original multi-codebook sequence from interleaving pattern. @@ -197,8 +203,9 @@ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int # TODO(jade): Do we want to further truncate to only valid timesteps here as well? timesteps = self.timesteps assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" - assert sequence_steps <= len(ref_layout), \ - f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + assert sequence_steps <= len( + ref_layout + ), f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" # ensure we take the appropriate indexes to keep the model output from the first special token as well if is_model_output and self.starts_with_special_token(): @@ -284,6 +291,7 @@ class CodebooksPatternProvider(ABC): cached (bool): if True, patterns for a given length are cached. In general that should be true for efficiency reason to avoid synchronization points. """ + def __init__(self, n_q: int, cached: bool = True): assert n_q > 0 self.n_q = n_q @@ -322,8 +330,10 @@ class DelayedPatternProvider(CodebooksPatternProvider): flatten_first (int): Flatten the first N timesteps. empty_initial (int): Prepend with N empty list of coordinates. 
""" - def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, - flatten_first: int = 0, empty_initial: int = 0): + + def __init__( + self, n_q: int, delays: tp.Optional[tp.List[int]] = None, flatten_first: int = 0, empty_initial: int = 0 + ): super().__init__(n_q) if delays is None: delays = list(range(n_q)) @@ -362,6 +372,7 @@ class ParallelPatternProvider(DelayedPatternProvider): n_q (int): Number of codebooks. empty_initial (int): Prepend with N empty list of coordinates. """ + def __init__(self, n_q: int, empty_initial: int = 0): super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) @@ -415,10 +426,12 @@ class UnrolledPatternProvider(CodebooksPatternProvider): Note that two codebooks that will be flattened to the same inner step should have the same delay, otherwise the pattern is considered as invalid. """ - FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) - def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, - delays: tp.Optional[tp.List[int]] = None): + FlattenedCodebook = namedtuple("FlattenedCodebook", ["codebooks", "delay"]) + + def __init__( + self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, delays: tp.Optional[tp.List[int]] = None + ): super().__init__(n_q) if flattening is None: flattening = list(range(n_q)) @@ -444,7 +457,7 @@ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[i flat_codebook = flattened_codebooks[inner_step] assert flat_codebook.delay == delay, ( "Delay and flattening between codebooks is inconsistent: ", - "two codebooks flattened to the same position should have the same delay." + "two codebooks flattened to the same position should have the same delay.", ) flat_codebook.codebooks.append(q) flattened_codebooks[inner_step] = flat_codebook @@ -452,8 +465,7 @@ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[i @property def _num_inner_steps(self): - """Number of inner steps to unroll between timesteps in order to flatten the codebooks. - """ + """Number of inner steps to unroll between timesteps in order to flatten the codebooks.""" return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 def num_virtual_steps(self, timesteps: int) -> int: @@ -501,6 +513,7 @@ class CoarseFirstPattern(CodebooksPatternProvider): delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): super().__init__(n_q) if delays is None: @@ -532,6 +545,7 @@ class MusicLMPattern(CodebooksPatternProvider): n_q (int): Number of codebooks. group_by (int): Number of codebooks to group together. 
""" + def __init__(self, n_q: int, group_by: int = 2): super().__init__(n_q) self.group_by = group_by @@ -542,4 +556,4 @@ def get_pattern(self, timesteps: int) -> Pattern: for t in range(timesteps): for q in range(offset, offset + self.group_by): out.append([LayoutCoord(t, q)]) - return Pattern(out, n_q=self.n_q, timesteps=timesteps) \ No newline at end of file + return Pattern(out, n_q=self.n_q, timesteps=timesteps) diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py index e998ab10..a266b81a 100644 --- a/stable_audio_tools/models/conditioners.py +++ b/stable_audio_tools/models/conditioners.py @@ -1,28 +1,25 @@ -#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py +# Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py -import torch -import logging, warnings +import gc +import logging import string import typing as tp -import gc +import warnings + +import torch +from torch import nn -from .adp import NumberEmbedder from ..inference.utils import set_audio_channels +from ..training.utils import copy_state_dict +from .adp import NumberEmbedder from .factory import create_pretransform_from_config from .pretransforms import Pretransform -from ..training.utils import copy_state_dict from .utils import load_ckpt_state_dict -from torch import nn class Conditioner(nn.Module): - def __init__( - self, - dim: int, - output_dim: int, - project_out: bool = False - ): - + def __init__(self, dim: int, output_dim: int, project_out: bool = False): + super().__init__() self.dim = dim @@ -31,13 +28,10 @@ def __init__( def forward(self, x: tp.Any) -> tp.Any: raise NotImplementedError() - + + class IntConditioner(Conditioner): - def __init__(self, - output_dim: int, - min_val: int=0, - max_val: int=512 - ): + def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512): super().__init__(output_dim, output_dim) self.min_val = min_val @@ -45,25 +39,23 @@ def __init__(self, self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) def forward(self, ints: tp.List[int], device=None) -> tp.Any: - - #self.int_embedder.to(device) - - ints = torch.tensor(ints).to(device) - ints = ints.clamp(self.min_val, self.max_val) - - int_embeds = self.int_embedder(ints).unsqueeze(1) - - return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + + # self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + class NumberConditioner(Conditioner): - ''' - Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings - ''' - def __init__(self, - output_dim: int, - min_val: float=0, - max_val: float=1 - ): + """ + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + """ + + def __init__(self, output_dim: int, min_val: float = 0, max_val: float = 1): super().__init__(output_dim, output_dim) self.min_val = min_val @@ -72,34 +64,37 @@ def __init__(self, self.embedder = NumberEmbedder(features=output_dim) def forward(self, floats: tp.List[float], device=None) -> tp.Any: - - # Cast the inputs to floats - floats = [float(x) for x in floats] - floats = torch.tensor(floats).to(device) + # Cast the inputs to floats + floats = [float(x) for x 
in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) - floats = floats.clamp(self.min_val, self.max_val) - - normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) - # Cast floats to same type as embedder - embedder_dtype = next(self.embedder.parameters()).dtype - normalized_floats = normalized_floats.to(embedder_dtype) + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] - float_embeds = self.embedder(normalized_floats).unsqueeze(1) - - return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] class CLAPTextConditioner(Conditioner): - def __init__(self, - output_dim: int, - clap_ckpt_path, - use_text_features = False, - feature_layer_ix: int = -1, - audio_model_type="HTSAT-base", - enable_fusion=True, - project_out: bool = False, - finetune: bool = False): + def __init__( + self, + output_dim: int, + clap_ckpt_path, + use_text_features=False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False, + ): super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) self.use_text_features = use_text_features @@ -113,13 +108,15 @@ def __init__(self, warnings.simplefilter("ignore") try: import laion_clap - from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict - - model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + from laion_clap.clap_module.factory import ( + load_state_dict as clap_load_state_dict, + ) + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu") if self.finetune: self.model = model - else: + else: self.__dict__["model"] = model state_dict = clap_load_state_dict(clap_ckpt_path) @@ -146,7 +143,7 @@ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): prompt_features = self.model.model.text_branch( input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), attention_mask=attention_mask, - output_hidden_states=True + output_hidden_states=True, )["hidden_states"][layer_ix] return prompt_features, attention_mask @@ -156,11 +153,15 @@ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: if self.use_text_features: if len(texts) == 1: - text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features, text_attention_mask = self.get_clap_features( + [texts[0], ""], layer_ix=self.feature_layer_ix, device=device + ) text_features = text_features[:1, ...] text_attention_mask = text_attention_mask[:1, ...] 
else: - text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + text_features, text_attention_mask = self.get_clap_features( + texts, layer_ix=self.feature_layer_ix, device=device + ) return [self.proj_out(text_features), text_attention_mask] # Fix for CLAP bug when only one text is passed @@ -173,16 +174,19 @@ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + class CLAPAudioConditioner(Conditioner): - def __init__(self, - output_dim: int, - clap_ckpt_path, - audio_model_type="HTSAT-base", - enable_fusion=True, - project_out: bool = False): + def __init__( + self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + ): super().__init__(512, output_dim, project_out=project_out) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Suppress logging from transformers previous_level = logging.root.manager.disable @@ -191,13 +195,15 @@ def __init__(self, warnings.simplefilter("ignore") try: import laion_clap - from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict - - model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + from laion_clap.clap_module.factory import ( + load_state_dict as clap_load_state_dict, + ) + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu") if self.finetune: self.model = model - else: + else: self.__dict__["model"] = model state_dict = clap_load_state_dict(clap_ckpt_path) @@ -218,7 +224,9 @@ def __init__(self, gc.collect() torch.cuda.empty_cache() - def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + def forward( + self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Any = "cuda" + ) -> tp.Any: self.model.to(device) @@ -235,12 +243,22 @@ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + class T5Conditioner(Conditioner): - T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", - "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", - "google/flan-t5-xl", "google/flan-t5-xxl"] - + T5_MODELS = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + "google/flan-t5-small", + "google/flan-t5-base", + "google/flan-t5-large", + "google/flan-t5-xl", + "google/flan-t5-xxl", + ] + T5_MODEL_DIMS = { "t5-small": 512, "t5-base": 768, @@ -259,17 +277,17 @@ class T5Conditioner(Conditioner): } def __init__( - self, - output_dim: int, - t5_model_name: str = "t5-base", - max_length: str = 128, - enable_grad: bool = False, - project_out: bool = False + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 128, + enable_grad: bool = False, + project_out: bool = False, ): assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) - - from transformers import T5EncoderModel, AutoTokenizer + + from transformers import AutoTokenizer, T5EncoderModel self.max_length = max_length self.enable_grad = enable_grad @@ -283,18 
+301,22 @@ def __init__( # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) - model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + model = ( + T5EncoderModel.from_pretrained(t5_model_name) + .train(enable_grad) + .requires_grad_(enable_grad) + .to(torch.float16) + ) finally: logging.disable(previous_level) - + if self.enable_grad: self.model = model - else: + else: self.__dict__["model"] = model - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - + self.model.to(device) self.proj_out.to(device) @@ -310,18 +332,17 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t attention_mask = encoded["attention_mask"].to(device).to(torch.bool) self.model.eval() - + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): - embeddings = self.model( - input_ids=input_ids, attention_mask=attention_mask - )["last_hidden_state"] - + embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] + embeddings = self.proj_out(embeddings.float()) embeddings = embeddings * attention_mask.unsqueeze(-1).float() return embeddings, attention_mask - + + class PhonemeConditioner(Conditioner): """ A conditioner that turns text into phonemes and embeds them using a lookup table @@ -334,13 +355,13 @@ class PhonemeConditioner(Conditioner): """ def __init__( - self, - output_dim: int, - max_length: int = 1024, - project_out: bool = False, + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, ): super().__init__(output_dim, output_dim, project_out=project_out) - + from g2p_en import G2p self.max_length = max_length @@ -351,33 +372,36 @@ def __init__( self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - + self.phoneme_embedder.to(device) self.proj_out.to(device) - batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] - + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + phoneme_ignore = [" ", *string.punctuation] # Remove ignored phonemes and cut to max length batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] # Convert to ids - phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + phoneme_ids = [ + [self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes + ] - #Pad to match longest and make a mask tensor for the padding + # Pad to match longest and make a mask tensor for the padding longest = max([len(ids) for ids in phoneme_ids]) phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] - + phoneme_ids = torch.tensor(phoneme_ids).to(device) # Convert to embeddings phoneme_embeds = self.phoneme_embedder(phoneme_ids) - + phoneme_embeds = self.proj_out(phoneme_embeds) return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) - + + class TokenizerLUTConditioner(Conditioner): """ A conditioner that embeds text using a 
lookup table on a pretrained tokenizer's vocabulary @@ -390,17 +414,17 @@ class TokenizerLUTConditioner(Conditioner): """ def __init__( - self, - tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library - output_dim: int, - max_length: int = 1024, - project_out: bool = False, + self, + tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, ): super().__init__(output_dim, output_dim, project_out=project_out) - + from transformers import AutoTokenizer - # Suppress logging from transformers + # Suppress logging from transformers previous_level = logging.root.manager.disable logging.disable(logging.ERROR) with warnings.catch_warnings(): @@ -427,15 +451,16 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t input_ids = encoded["input_ids"].to(device) attention_mask = encoded["attention_mask"].to(device).to(torch.bool) - + embeddings = self.token_embedder(input_ids) - + embeddings = self.proj_out(embeddings) embeddings = embeddings * attention_mask.unsqueeze(-1).float() return embeddings, attention_mask + class PretransformConditioner(Conditioner): """ A conditioner that uses a pretransform's encoder for conditioning @@ -444,12 +469,17 @@ class PretransformConditioner(Conditioner): pretransform: an instantiated pretransform to use for conditioning output_dim: the dimension of the output embeddings """ + def __init__(self, pretransform: Pretransform, output_dim: int): super().__init__(pretransform.encoded_channels, output_dim) self.pretransform = pretransform - def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, + audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], + device: tp.Union[torch.device, str], + ) -> tp.Tuple[torch.Tensor, torch.Tensor]: self.pretransform.to(device) self.proj_out.to(device) @@ -459,13 +489,14 @@ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[ # Convert audio to pretransform input channels audio = set_audio_channels(audio, self.pretransform.io_channels) - + latents = self.pretransform.encode(audio) latents = self.proj_out(latents) return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + class MultiConditioner(nn.Module): """ A module that applies multiple conditioners to an input dictionary based on the keys @@ -474,13 +505,16 @@ class MultiConditioner(nn.Module): conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. 
{"prompt_t5": "prompt"}) """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): super().__init__() self.conditioners = nn.ModuleDict(conditioners) self.default_keys = default_keys - def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + def forward( + self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str] + ) -> tp.Dict[str, tp.Any]: output = {} for key, conditioner in self.conditioners.items(): @@ -496,19 +530,24 @@ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Unio else: raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") - #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list - if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + # Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list + if ( + isinstance(x[condition_key], list) + or isinstance(x[condition_key], tuple) + and len(x[condition_key]) == 1 + ): conditioner_input = x[condition_key][0] - + else: conditioner_input = x[condition_key] conditioner_inputs.append(conditioner_input) - + output[key] = conditioner(conditioner_inputs, device) return output - + + def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: """ Create a MultiConditioner from a conditioning config dictionary @@ -519,7 +558,7 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An """ conditioners = {} cond_dim = config["cond_dim"] - + default_keys = config.get("default_keys", {}) for conditioner_info in config["configs"]: @@ -528,7 +567,7 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An conditioner_type = conditioner_info["type"] conditioner_config = {"output_dim": cond_dim} - + conditioner_config.update(conditioner_info["config"]) if conditioner_type == "t5": @@ -549,7 +588,9 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An sample_rate = conditioner_config.pop("sample_rate", None) assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" - pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + pretransform = create_pretransform_from_config( + conditioner_config.pop("pretransform_config"), sample_rate=sample_rate + ) if conditioner_config.get("pretransform_ckpt_path", None) is not None: pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) @@ -558,4 +599,4 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An else: raise ValueError(f"Unknown conditioner type: {conditioner_type}") - return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file + return MultiConditioner(conditioners, default_keys=default_keys) diff --git a/stable_audio_tools/models/diffusion.py b/stable_audio_tools/models/diffusion.py index aead49e2..265f5c98 100644 --- a/stable_audio_tools/models/diffusion.py +++ b/stable_audio_tools/models/diffusion.py @@ -1,20 +1,33 @@ +import typing as tp +from functools import partial +from time import time + +import numpy as np import torch from torch import nn from torch.nn 
import functional as F -from functools import partial -import numpy as np -import typing as tp -from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes -from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from ..inference.generation import generate_diffusion_cond +from .adp import UNet1d, UNetCFG1d +from .blocks import ( + Downsample1d, + Downsample1d_2, + FourierFeatures, + ResConvBlock, + SelfAttention1d, + SkipBlock, + Upsample1d, + Upsample1d_2, + expand_to_planes, +) +from .conditioners import ( + MultiConditioner, + create_multi_conditioner_from_conditioning_config, +) from .dit import DiffusionTransformer from .factory import create_pretransform_from_config from .pretransforms import Pretransform -from ..inference.generation import generate_diffusion_cond - -from .adp import UNetCFG1d, UNet1d -from time import time class Profiler: @@ -33,6 +46,7 @@ def __repr__(self): rep += 80 * "=" + "\n\n\n" return rep + class DiffusionModel(nn.Module): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -40,15 +54,16 @@ def __init__(self, *args, **kwargs): def forward(self, x, t, **kwargs): raise NotImplementedError() + class DiffusionModelWrapper(nn.Module): def __init__( - self, - model: DiffusionModel, - io_channels, - sample_size, - sample_rate, - min_input_length, - pretransform: tp.Optional[Pretransform] = None, + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, ): super().__init__() self.io_channels = io_channels @@ -66,54 +81,61 @@ def __init__( def forward(self, x, t, **kwargs): return self.model(x, t, **kwargs) + class ConditionedDiffusionModel(nn.Module): - def __init__(self, - *args, - supports_cross_attention: bool = False, - supports_input_concat: bool = False, - supports_global_cond: bool = False, - supports_prepend_cond: bool = False, - **kwargs): + def __init__( + self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs, + ): super().__init__(*args, **kwargs) self.supports_cross_attention = supports_cross_attention self.supports_input_concat = supports_input_concat self.supports_global_cond = supports_global_cond self.supports_prepend_cond = supports_prepend_cond - def forward(self, - x: torch.Tensor, - t: torch.Tensor, - cross_attn_cond: torch.Tensor = None, - cross_attn_mask: torch.Tensor = None, - input_concat_cond: torch.Tensor = None, - global_embed: torch.Tensor = None, - prepend_cond: torch.Tensor = None, - prepend_cond_mask: torch.Tensor = None, - cfg_scale: float = 1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - **kwargs): + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs, + ): raise NotImplementedError() + class ConditionedDiffusionModelWrapper(nn.Module): """ A diffusion model that takes in conditioning """ + def __init__( - self, - model: ConditionedDiffusionModel, - 
conditioner: MultiConditioner, - io_channels, - sample_rate, - min_input_length: int, - diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", - pretransform: tp.Optional[Pretransform] = None, - cross_attn_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [], - input_concat_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - ): + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + ): super().__init__() self.model = model @@ -195,7 +217,7 @@ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], ne "negative_cross_attn_cond": cross_attention_input, "negative_cross_attn_mask": cross_attention_masks, "negative_global_cond": global_cond, - "negative_input_concat_cond": input_concat_cond + "negative_input_concat_cond": input_concat_cond, } else: return { @@ -204,7 +226,7 @@ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], ne "global_cond": global_cond, "input_concat_cond": input_concat_cond, "prepend_cond": prepend_cond, - "prepend_cond_mask": prepend_cond_mask + "prepend_cond_mask": prepend_cond_mask, } def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): @@ -213,12 +235,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], def generate(self, *args, **kwargs): return generate_diffusion_cond(self, *args, **kwargs) + class UNetCFG1DWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): + def __init__(self, *args, **kwargs): super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) self.model = UNetCFG1d(*args, **kwargs) @@ -227,24 +246,26 @@ def __init__( for param in self.model.parameters(): param *= 0.5 - def forward(self, - x, - t, - cross_attn_cond=None, - cross_attn_mask=None, - input_concat_cond=None, - global_cond=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - prepend_cond=None, - prepend_cond_mask=None, - **kwargs): + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs, + ): p = Profiler() p.tick("start") @@ -266,19 +287,17 @@ def forward(self, rescale_cfg=rescale_cfg, negative_embedding=negative_cross_attn_cond, negative_embedding_mask=negative_cross_attn_mask, - **kwargs) + **kwargs, + ) p.tick("UNetCFG1D forward") - #print(f"Profiler: {p}") + # print(f"Profiler: {p}") return outputs + class UNet1DCondWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): + def __init__(self, *args, **kwargs): super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) self.model = UNet1d(*args, **kwargs) 
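Note for reviewers: the hunks above only re-wrap the CFG-related keyword arguments (`cfg_scale`, `cfg_dropout_prob`, `rescale_cfg`, and the negative-conditioning inputs); behaviour is unchanged. For reference, here is a minimal sketch of how classifier-free guidance is conventionally combined from the conditional and unconditional branches, together with the kind of output rescaling that arguments such as `rescale_cfg` and `scale_phi` typically control. The helper name and the exact blend are illustrative assumptions, not a copy of this repo's code:

import torch

def cfg_combine(cond_out: torch.Tensor, uncond_out: torch.Tensor,
                cfg_scale: float, scale_phi: float = 0.0) -> torch.Tensor:
    # Classifier-free guidance: push the prediction away from the
    # unconditional branch by a factor of cfg_scale.
    cfg_out = uncond_out + cfg_scale * (cond_out - uncond_out)

    if scale_phi > 0.0:
        # Optional rescaling (illustrative): match the guided output's std to
        # the conditional branch's std, then blend by scale_phi. This is the
        # usual mitigation for over-saturation at high guidance scales.
        dims = tuple(range(1, cond_out.ndim))
        cond_std = cond_out.std(dim=dims, keepdim=True)
        cfg_std = cfg_out.std(dim=dims, keepdim=True)
        cfg_out = scale_phi * (cfg_out * cond_std / cfg_std) + (1.0 - scale_phi) * cfg_out

    return cfg_out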
@@ -287,50 +306,43 @@ def __init__( for param in self.model.parameters(): param *= 0.5 - def forward(self, - x, - t, - input_concat_cond=None, - global_cond=None, - cross_attn_cond=None, - cross_attn_mask=None, - prepend_cond=None, - prepend_cond_mask=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - **kwargs): + def forward( + self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs, + ): channels_list = None if input_concat_cond is not None: # Interpolate input_concat_cond to the same length as x if input_concat_cond.shape[2] != x.shape[2]: - input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode="nearest") channels_list = [input_concat_cond] - outputs = self.model( - x, - t, - features=global_cond, - channels_list=channels_list, - **kwargs) + outputs = self.model(x, t, features=global_cond, channels_list=channels_list, **kwargs) return outputs + class UNet1DUncondWrapper(DiffusionModel): - def __init__( - self, - in_channels, - *args, - **kwargs - ): + def __init__(self, in_channels, *args, **kwargs): super().__init__() self.model = UNet1d(in_channels=in_channels, *args, **kwargs) @@ -344,12 +356,9 @@ def __init__( def forward(self, x, t, **kwargs): return self.model(x, t, **kwargs) + class DAU1DCondWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): + def __init__(self, *args, **kwargs): super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) self.model = DiffusionAttnUnet1D(*args, **kwargs) @@ -358,40 +367,43 @@ def __init__( for param in self.model.parameters(): param *= 0.5 - def forward(self, - x, - t, - input_concat_cond=None, - cross_attn_cond=None, - cross_attn_mask=None, - global_cond=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - prepend_cond=None, - **kwargs): - - return self.model(x, t, cond = input_concat_cond) + def forward( + self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs, + ): + + return self.model(x, t, cond=input_concat_cond) + class DiffusionAttnUnet1D(nn.Module): def __init__( self, - io_channels = 2, + io_channels=2, depth=14, - n_attn_layers = 6, - channels = [128, 128, 256, 256] + [512] * 10, - cond_dim = 0, - cond_noise_aug = False, - kernel_size = 5, - learned_resample = False, - strides = [2] * 13, - conv_bias = True, - use_snake = False + n_attn_layers=6, + channels=[128, 128, 256, 256] + [512] * 10, + cond_dim=0, + 
cond_noise_aug=False, + kernel_size=5, + learned_resample=False, + strides=[2] * 13, + conv_bias=True, + use_snake=False, ): super().__init__() @@ -410,11 +422,11 @@ def __init__( block = nn.Identity() - conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias=conv_bias, use_snake=use_snake) for i in range(depth, 0, -1): c = channels[i - 1] - stride = strides[i-1] + stride = strides[i - 1] if stride > 2 and not learned_resample: raise ValueError("Must have stride 2 without learned resampling") @@ -422,27 +434,25 @@ def __init__( c_prev = channels[i - 2] add_attn = i >= attn_layer and n_attn_layers > 0 block = SkipBlock( - Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + ( + Downsample1d_2(c_prev, c_prev, stride) + if (learned_resample or stride == 1) + else Downsample1d("cubic") + ), conv_block(c_prev, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), + SelfAttention1d(c, c // 32) if add_attn else nn.Identity(), conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), + SelfAttention1d(c, c // 32) if add_attn else nn.Identity(), conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), + SelfAttention1d(c, c // 32) if add_attn else nn.Identity(), block, conv_block(c * 2 if i != depth else c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), + SelfAttention1d(c, c // 32) if add_attn else nn.Identity(), conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), + SelfAttention1d(c, c // 32) if add_attn else nn.Identity(), conv_block(c, c, c_prev), - SelfAttention1d(c_prev, c_prev // - 32) if add_attn else nn.Identity(), - Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + SelfAttention1d(c_prev, c_prev // 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic"), ) else: cond_embed_dim = 16 if not self.cond_noise_aug else 32 @@ -469,7 +479,7 @@ def forward(self, x, t, cond=None, cond_aug_scale=None): if cond is not None: if cond.shape[2] != x.shape[2]: - cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + cond = F.interpolate(cond, (x.shape[2],), mode="linear", align_corners=False) if self.cond_noise_aug: # Get a random number between 0 and 1, uniformly sampled @@ -492,12 +502,9 @@ def forward(self, x, t, cond=None, cond_aug_scale=None): return outputs + class DiTWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): + def __init__(self, *args, **kwargs): super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) self.model = DiffusionTransformer(*args, **kwargs) @@ -506,28 +513,30 @@ def __init__( for param in self.model.parameters(): param *= 0.5 - def forward(self, - x, - t, - cross_attn_cond=None, - cross_attn_mask=None, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - input_concat_cond=None, - negative_input_concat_cond=None, - global_cond=None, - negative_global_cond=None, - prepend_cond=None, - prepend_cond_mask=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = True, - rescale_cfg: bool = False, - scale_phi: float = 0.0, - **kwargs): + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + 
negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs, + ): assert batch_cfg, "batch_cfg must be True for DiTWrapper" - #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + # assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" return self.model( x, @@ -543,15 +552,12 @@ def forward(self, cfg_dropout_prob=cfg_dropout_prob, scale_phi=scale_phi, global_embed=global_cond, - **kwargs) + **kwargs, + ) + class DiTUncondWrapper(DiffusionModel): - def __init__( - self, - in_channels, - *args, - **kwargs - ): + def __init__(self, in_channels, *args, **kwargs): super().__init__() self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs) @@ -565,12 +571,13 @@ def __init__( def forward(self, x, t, **kwargs): return self.model(x, t, **kwargs) + def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): diffusion_uncond_config = config["model"] - model_type = diffusion_uncond_config.get('type', None) + model_type = diffusion_uncond_config.get("type", None) - diffusion_config = diffusion_uncond_config.get('config', {}) + diffusion_config = diffusion_uncond_config.get("config", {}) assert model_type is not None, "Must specify model type in config" @@ -588,32 +595,29 @@ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): else: min_input_length = 1 - if model_type == 'DAU1d': + if model_type == "DAU1d": + + model = DiffusionAttnUnet1D(**diffusion_config) - model = DiffusionAttnUnet1D( - **diffusion_config - ) - elif model_type == "adp_uncond_1d": - model = UNet1DUncondWrapper( - **diffusion_config - ) + model = UNet1DUncondWrapper(**diffusion_config) elif model_type == "dit": - model = DiTUncondWrapper( - **diffusion_config - ) + model = DiTUncondWrapper(**diffusion_config) else: - raise NotImplementedError(f'Unknown model type: {model_type}') + raise NotImplementedError(f"Unknown model type: {model_type}") + + return DiffusionModelWrapper( + model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length, + ) - return DiffusionModelWrapper(model, - io_channels=model.io_channels, - sample_size=sample_size, - sample_rate=sample_rate, - pretransform=pretransform, - min_input_length=min_input_length) def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): @@ -621,40 +625,40 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): model_type = config["model_type"] - diffusion_config = model_config.get('diffusion', None) + diffusion_config = model_config.get("diffusion", None) assert diffusion_config is not None, "Must specify diffusion config" - diffusion_model_type = diffusion_config.get('type', None) + diffusion_model_type = diffusion_config.get("type", None) assert diffusion_model_type is not None, "Must specify diffusion model type" - diffusion_model_config = diffusion_config.get('config', None) + diffusion_model_config = diffusion_config.get("config", None) assert diffusion_model_config is not None, "Must specify diffusion model config" - if diffusion_model_type == 'adp_cfg_1d': + if diffusion_model_type == "adp_cfg_1d": 
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) - elif diffusion_model_type == 'adp_1d': + elif diffusion_model_type == "adp_1d": diffusion_model = UNet1DCondWrapper(**diffusion_model_config) - elif diffusion_model_type == 'dit': + elif diffusion_model_type == "dit": diffusion_model = DiTWrapper(**diffusion_model_config) - io_channels = model_config.get('io_channels', None) + io_channels = model_config.get("io_channels", None) assert io_channels is not None, "Must specify io_channels in model config" - sample_rate = config.get('sample_rate', None) + sample_rate = config.get("sample_rate", None) assert sample_rate is not None, "Must specify sample_rate in config" - diffusion_objective = diffusion_config.get('diffusion_objective', 'v') + diffusion_objective = diffusion_config.get("diffusion_objective", "v") - conditioning_config = model_config.get('conditioning', None) + conditioning_config = model_config.get("conditioning", None) conditioner = None if conditioning_config is not None: conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) - cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) - global_cond_ids = diffusion_config.get('global_cond_ids', []) - input_concat_ids = diffusion_config.get('input_concat_ids', []) - prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + cross_attention_ids = diffusion_config.get("cross_attention_cond_ids", []) + global_cond_ids = diffusion_config.get("global_cond_ids", []) + input_concat_ids = diffusion_config.get("input_concat_ids", []) + prepend_cond_ids = diffusion_config.get("prepend_cond_ids", []) pretransform = model_config.get("pretransform", None) @@ -684,8 +688,9 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): if prior_type == "mono_stereo": from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior - + return wrapper_fn( diffusion_model, conditioner, @@ -697,5 +702,5 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): prepend_cond_ids=prepend_cond_ids, pretransform=pretransform, io_channels=io_channels, - **extra_kwargs - ) \ No newline at end of file + **extra_kwargs, + ) diff --git a/stable_audio_tools/models/diffusion_prior.py b/stable_audio_tools/models/diffusion_prior.py index 8529ad28..2e973406 100644 --- a/stable_audio_tools/models/diffusion_prior.py +++ b/stable_audio_tools/models/diffusion_prior.py @@ -1,30 +1,33 @@ -from enum import Enum import typing as tp - -from .diffusion import ConditionedDiffusionModelWrapper -from ..inference.generation import generate_diffusion_cond -from ..inference.utils import prepare_audio +from enum import Enum import torch from torch.nn import functional as F from torchaudio import transforms as T +from ..inference.generation import generate_diffusion_cond +from ..inference.utils import prepare_audio +from .diffusion import ConditionedDiffusionModelWrapper + + # Define prior types enum class PriorType(Enum): MonoToStereo = 1 + class DiffusionPrior(ConditionedDiffusionModelWrapper): - def __init__(self, *args, prior_type: PriorType=None, **kwargs): + def __init__(self, *args, prior_type: PriorType = None, **kwargs): super().__init__(*args, **kwargs) - self.prior_type = prior_type + self.prior_type = prior_type + class MonoToStereoDiffusionPrior(DiffusionPrior): def __init__(self, *args, **kwargs): super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) def stereoize( - self, - audio: torch.Tensor, # (batch, channels, time) + self, + 
audio: torch.Tensor, # (batch, channels, time) in_sr: int, steps: int, sampler_kwargs: dict = {}, @@ -67,13 +70,13 @@ def stereoize( conditioning = {"source": [dual_mono]} stereo_audio = generate_diffusion_cond( - self, + self, conditioning_tensors=conditioning, steps=steps, sample_size=padded_input_length, sample_rate=sample_rate, device=device, **sampler_kwargs, - ) + ) - return stereo_audio \ No newline at end of file + return stereo_audio diff --git a/stable_audio_tools/models/discriminators.py b/stable_audio_tools/models/discriminators.py index b593168d..ed460b83 100644 --- a/stable_audio_tools/models/discriminators.py +++ b/stable_audio_tools/models/discriminators.py @@ -1,18 +1,21 @@ +import typing as tp +from functools import reduce + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from functools import reduce -import typing as tp -from einops import rearrange from audiotools import AudioSignal, STFTParams from dac.model.discriminator import WNConv1d, WNConv2d +from einops import rearrange + def get_hinge_losses(score_real, score_fake): gen_loss = -score_fake.mean() dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() return dis_loss, gen_loss + class EncodecDiscriminator(nn.Module): def __init__(self, *args, **kwargs): @@ -27,12 +30,12 @@ def forward(self, x): return logits, features def loss(self, x, y): - feature_matching_distance = 0. + feature_matching_distance = 0.0 logits_true, feature_true = self.forward(x) logits_fake, feature_fake = self.forward(y) - dis_loss = torch.tensor(0.) - adv_loss = torch.tensor(0.) + dis_loss = torch.tensor(0.0) + adv_loss = torch.tensor(0.0) for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): @@ -41,7 +44,8 @@ def loss(self, x, y): lambda x, y: abs(x - y).mean(), scale_true, scale_fake, - )) / len(scale_true) + ) + ) / len(scale_true) _dis, _adv = get_hinge_losses( logits_true[i], @@ -53,12 +57,14 @@ def loss(self, x, y): return dis_loss, adv_loss, feature_matching_distance + # Discriminators from oobleck IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] TensorDict = tp.Dict[str, torch.Tensor] + class SharedDiscriminatorConvNet(nn.Module): def __init__( @@ -75,7 +81,7 @@ def __init__( ) -> None: super().__init__() channels = [in_size] - channels += list(capacity * 2**np.arange(n_layers)) + channels += list(capacity * 2 ** np.arange(n_layers)) if isinstance(stride, int): stride = n_layers * [stride] @@ -97,7 +103,9 @@ def __init__( kernel_size, stride=s, padding=pad, - ))) + ) + ) + ) net.append(activation()) net.append(convolution(channels[-1], out_size, 1)) @@ -116,10 +124,7 @@ def forward(self, x) -> IndividualDiscriminatorOut: class MultiScaleDiscriminator(nn.Module): - def __init__(self, - in_channels: int, - n_scales: int, - **conv_kwargs) -> None: + def __init__(self, in_channels: int, n_scales: int, **conv_kwargs) -> None: super().__init__() layers = [] for _ in range(n_scales): @@ -136,12 +141,10 @@ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: x = nn.functional.avg_pool1d(x, 2) return score, features + class MultiPeriodDiscriminator(nn.Module): - def __init__(self, - in_channels: int, - periods: tp.Sequence[int], - **conv_kwargs) -> None: + def __init__(self, in_channels: int, periods: tp.Sequence[int], **conv_kwargs) -> None: super().__init__() layers = [] self.periods = periods @@ -173,8 +176,7 @@ class MultiDiscriminator(nn.Module): Sequence[NxB C' T']. 
""" - def __init__(self, discriminator_list: tp.Sequence[nn.Module], - keys: tp.Sequence[str]) -> None: + def __init__(self, discriminator_list: tp.Sequence[nn.Module], keys: tp.Sequence[str]) -> None: super().__init__() self.discriminators = nn.ModuleList(discriminator_list) self.keys = keys @@ -206,7 +208,7 @@ def sum_dicts(dict_a, dict_b): out_dict = {} keys = set(list(dict_a.keys()) + list(dict_b.keys())) for k in keys: - out_dict[k] = 0. + out_dict[k] = 0.0 if k in dict_a: out_dict[k] = out_dict[k] + dict_a[k] if k in dict_b: @@ -236,13 +238,14 @@ def forward(self, inputs: TensorDict) -> TensorDict: inputs.update(all_features) return inputs - + + class OobleckDiscriminator(nn.Module): def __init__( - self, - in_channels=1, - ): + self, + in_channels=1, + ): super().__init__() multi_scale_discriminator = MultiScaleDiscriminator( @@ -250,10 +253,7 @@ def __init__( n_scales=3, ) - multi_period_discriminator = MultiPeriodDiscriminator( - in_channels=in_channels, - periods=[2, 3, 5, 7, 11] - ) + multi_period_discriminator = MultiPeriodDiscriminator(in_channels=in_channels, periods=[2, 3, 5, 7, 11]) # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( # filters=32, @@ -265,8 +265,8 @@ def __init__( # ) self.multi_discriminator = MultiDiscriminator( - [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], - ["reals", "fakes"] + [multi_scale_discriminator, multi_period_discriminator], # , multi_resolution_discriminator], + ["reals", "fakes"], ) def loss(self, reals, fakes): @@ -284,8 +284,8 @@ def loss(self, reals, fakes): features_fake = inputs["features_fakes"] dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) - - feature_matching_distance = torch.tensor(0.) + + feature_matching_distance = torch.tensor(0.0) for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): @@ -294,10 +294,11 @@ def loss(self, reals, fakes): lambda real, fake: abs(real - fake).mean(), scale_real, scale_fake, - )) / len(scale_real) - + ) + ) / len(scale_real) + return dis_loss, gen_loss, feature_matching_distance - + ## Discriminators from Descript Audio Codec repo ## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt @@ -315,9 +316,7 @@ def __init__(self, period, channels=1): WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), ] ) - self.conv_post = WNConv2d( - 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False - ) + self.conv_post = WNConv2d(1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False) def pad_to_period(self, x): t = x.shape[-1] @@ -384,7 +383,7 @@ def __init__( hop_factor: float = 0.25, sample_rate: int = 44100, bands: list = BANDS, - channels: int = 1 + channels: int = 1, ): """Complex multi-band spectrogram discriminator. 
Parameters @@ -499,6 +498,7 @@ def forward(self, x): fmaps = [d(x) for d in self.discriminators] return fmaps + class DACGANLoss(nn.Module): """ Computes a discriminator loss, given a discriminator on @@ -538,9 +538,9 @@ def generator_loss(self, fake, real): for j in range(len(d_fake[i]) - 1): loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) return loss_g, loss_feature - + def loss(self, fake, real): gen_loss, feature_distance = self.generator_loss(fake, real) dis_loss = self.discriminator_loss(fake, real) - return dis_loss, gen_loss, feature_distance \ No newline at end of file + return dis_loss, gen_loss, feature_distance diff --git a/stable_audio_tools/models/dit.py b/stable_audio_tools/models/dit.py index 4aef6c77..bca440f2 100644 --- a/stable_audio_tools/models/dit.py +++ b/stable_audio_tools/models/dit.py @@ -1,7 +1,6 @@ import typing as tp import torch - from einops import rearrange from torch import nn from torch.nn import functional as F @@ -10,9 +9,11 @@ from .blocks import FourierFeatures from .transformer import ContinuousTransformer + class DiffusionTransformer(nn.Module): - def __init__(self, - io_channels=32, + def __init__( + self, + io_channels=32, patch_size=1, embed_dim=768, cond_token_dim=0, @@ -25,10 +26,11 @@ def __init__(self, num_heads=8, transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers", global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", - **kwargs): + **kwargs, + ): super().__init__() - + self.cond_token_dim = cond_token_dim # Timestep embeddings @@ -49,7 +51,7 @@ def __init__(self, self.to_cond_embed = nn.Sequential( nn.Linear(cond_token_dim, cond_embed_dim, bias=False), nn.SiLU(), - nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False), ) else: cond_embed_dim = 0 @@ -60,7 +62,7 @@ def __init__(self, self.to_global_embed = nn.Sequential( nn.Linear(global_cond_dim, global_embed_dim, bias=False), nn.SiLU(), - nn.Linear(global_embed_dim, global_embed_dim, bias=False) + nn.Linear(global_embed_dim, global_embed_dim, bias=False), ) if prepend_cond_dim > 0: @@ -68,7 +70,7 @@ def __init__(self, self.to_prepend_embed = nn.Sequential( nn.Linear(prepend_cond_dim, embed_dim, bias=False), nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=False), ) self.input_concat_dim = input_concat_dim @@ -87,21 +89,21 @@ def __init__(self, self.transformer = ContinuousTransformerWrapper( dim_in=dim_in * patch_size, dim_out=io_channels * patch_size, - max_seq_len=0, #Not relevant without absolute positional embeds - attn_layers = Encoder( + max_seq_len=0, # Not relevant without absolute positional embeds + attn_layers=Encoder( dim=embed_dim, depth=depth, heads=num_heads, - attn_flash = True, - cross_attend = cond_token_dim > 0, + attn_flash=True, + cross_attend=cond_token_dim > 0, dim_context=None if cond_embed_dim == 0 else cond_embed_dim, zero_init_branch_output=True, - use_abs_pos_emb = False, + use_abs_pos_emb=False, rotary_pos_emb=True, - ff_swish = True, - ff_glu = True, - **kwargs - ) + ff_swish=True, + ff_glu=True, + **kwargs, + ), ) elif self.transformer_type == "continuous_transformer": @@ -118,12 +120,12 @@ def __init__(self, dim_heads=embed_dim // num_heads, dim_in=dim_in * patch_size, dim_out=io_channels * patch_size, - cross_attend = cond_token_dim > 0, - cond_token_dim = cond_embed_dim, + cross_attend=cond_token_dim > 0, + cond_token_dim=cond_embed_dim, global_cond_dim=global_dim, - **kwargs + **kwargs, ) - + 
else: raise ValueError(f"Unknown transformer type: {self.transformer_type}") @@ -133,9 +135,9 @@ def __init__(self, nn.init.zeros_(self.postprocess_conv.weight) def _forward( - self, - x, - t, + self, + x, + t, mask=None, cross_attn_cond=None, cross_attn_cond_mask=None, @@ -144,7 +146,8 @@ def _forward( prepend_cond=None, prepend_cond_mask=None, return_info=False, - **kwargs): + **kwargs, + ): if cross_attn_cond is not None: cross_attn_cond = self.to_cond_embed(cross_attn_cond) @@ -153,13 +156,13 @@ def _forward( # Project the global conditioning to the embedding dimension global_embed = self.to_global_embed(global_embed) - prepend_inputs = None + prepend_inputs = None prepend_mask = None prepend_length = 0 if prepend_cond is not None: # Project the prepend conditioning to the embedding dimension prepend_cond = self.to_prepend_embed(prepend_cond) - + prepend_inputs = prepend_cond if prepend_cond_mask is not None: prepend_mask = prepend_cond_mask @@ -168,12 +171,12 @@ def _forward( # Interpolate input_concat_cond to the same length as x if input_concat_cond.shape[2] != x.shape[2]: - input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode="nearest") x = torch.cat([x, input_concat_cond], dim=1) # Get the batch of timestep embeddings - timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists if global_embed is not None: @@ -190,7 +193,9 @@ def _forward( else: # Prepend inputs are the prepend conditioning + the global embed prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) - prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + prepend_mask = torch.cat( + [prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1 + ) prepend_length = prepend_inputs.shape[1] @@ -207,16 +212,37 @@ def _forward( x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) if self.transformer_type == "x-transformers": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) + output = self.transformer( + x, + prepend_embeds=prepend_inputs, + context=cross_attn_cond, + context_mask=cross_attn_cond_mask, + mask=mask, + prepend_mask=prepend_mask, + **extra_args, + **kwargs, + ) elif self.transformer_type == "continuous_transformer": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + output = self.transformer( + x, + prepend_embeds=prepend_inputs, + context=cross_attn_cond, + context_mask=cross_attn_cond_mask, + mask=mask, + prepend_mask=prepend_mask, + return_info=return_info, + **extra_args, + **kwargs, + ) if return_info: output, info = output elif self.transformer_type == "mm_transformer": - output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs) + output = self.transformer( + x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs + ) - output = rearrange(output, "b t c -> b c 
t")[:,:,prepend_length:] + output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:] if self.patch_size > 1: output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) @@ -229,9 +255,9 @@ def _forward( return output def forward( - self, - x, - t, + self, + x, + t, cross_attn_cond=None, cross_attn_cond_mask=None, negative_cross_attn_cond=None, @@ -247,14 +273,17 @@ def forward( scale_phi=0.0, mask=None, return_info=False, - **kwargs): + **kwargs, + ): assert causal == False, "Causal mode is not supported for DiffusionTransformer" if cross_attn_cond_mask is not None: cross_attn_cond_mask = cross_attn_cond_mask.bool() - cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + cross_attn_cond_mask = ( + None # Temporarily disabling conditioning masks due to kernel issue for flash attention + ) if prepend_cond_mask is not None: prepend_cond_mask = prepend_cond_mask.bool() @@ -263,18 +292,21 @@ def forward( if cfg_dropout_prob > 0.0: if cross_attn_cond is not None: null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + dropout_mask = torch.bernoulli( + torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device) + ).to(torch.bool) cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) if prepend_cond is not None: null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + dropout_mask = torch.bernoulli( + torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device) + ).to(torch.bool) prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) - if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None): # Classifier-free guidance - # Concatenate conditioned and unconditioned inputs on the batch dimension + # Concatenate conditioned and unconditioned inputs on the batch dimension batch_inputs = torch.cat([x, x], dim=0) batch_timestep = torch.cat([t, t], dim=0) @@ -290,7 +322,7 @@ def forward( batch_cond = None batch_cond_masks = None - + # Handle CFG for cross-attention conditioning if cross_attn_cond is not None: @@ -303,8 +335,10 @@ def forward( if negative_cross_attn_mask is not None: negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) - negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) - + negative_cross_attn_cond = torch.where( + negative_cross_attn_mask, negative_cross_attn_cond, null_embed + ) + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) else: @@ -312,7 +346,7 @@ def forward( if cross_attn_cond_mask is not None: batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) - + batch_prepend_cond = None batch_prepend_cond_mask = None @@ -321,28 +355,28 @@ def forward( null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) - + if prepend_cond_mask is not None: batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) - if mask is not None: batch_masks = torch.cat([mask, mask], dim=0) else: batch_masks = None - + batch_output = self._forward( - batch_inputs, - 
batch_timestep, - cross_attn_cond=batch_cond, - cross_attn_cond_mask=batch_cond_masks, - mask = batch_masks, - input_concat_cond=batch_input_concat_cond, - global_embed = batch_global_cond, - prepend_cond = batch_prepend_cond, - prepend_cond_mask = batch_prepend_cond_mask, - return_info = return_info, - **kwargs) + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask=batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed=batch_global_cond, + prepend_cond=batch_prepend_cond, + prepend_cond_mask=batch_prepend_cond_mask, + return_info=return_info, + **kwargs, + ) if return_info: batch_output, info = batch_output @@ -354,26 +388,26 @@ def forward( if scale_phi != 0.0: cond_out_std = cond_output.std(dim=1, keepdim=True) out_cfg_std = cfg_output.std(dim=1, keepdim=True) - output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output else: output = cfg_output - + if return_info: return output, info return output - + else: return self._forward( x, t, - cross_attn_cond=cross_attn_cond, - cross_attn_cond_mask=cross_attn_cond_mask, - input_concat_cond=input_concat_cond, - global_embed=global_embed, - prepend_cond=prepend_cond, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, mask=mask, return_info=return_info, - **kwargs - ) \ No newline at end of file + **kwargs, + ) diff --git a/stable_audio_tools/models/factory.py b/stable_audio_tools/models/factory.py index 41887030..5357ca03 100644 --- a/stable_audio_tools/models/factory.py +++ b/stable_audio_tools/models/factory.py @@ -1,40 +1,48 @@ import json + def create_model_from_config(model_config): - model_type = model_config.get('model_type', None) + model_type = model_config.get("model_type", None) - assert model_type is not None, 'model_type must be specified in model config' + assert model_type is not None, "model_type must be specified in model config" - if model_type == 'autoencoder': + if model_type == "autoencoder": from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) - elif model_type == 'diffusion_uncond': + elif model_type == "diffusion_uncond": from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) - elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == "diffusion_prior": from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) - elif model_type == 'diffusion_autoencoder': + elif model_type == "diffusion_autoencoder": from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) - elif model_type == 'lm': + elif model_type == "lm": from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) else: - raise NotImplementedError(f'Unknown model type: {model_type}') + raise NotImplementedError(f"Unknown model type: {model_type}") + def create_model_from_config_path(model_config_path): with open(model_config_path) as f: model_config = json.load(f) - + return 
create_model_from_config(model_config) + def create_pretransform_from_config(pretransform_config, sample_rate): - pretransform_type = pretransform_config.get('type', None) + pretransform_type = pretransform_config.get("type", None) - assert pretransform_type is not None, 'type must be specified in pretransform config' + assert pretransform_type is not None, "type must be specified in pretransform config" - if pretransform_type == 'autoencoder': + if pretransform_type == "autoencoder": from .autoencoders import create_autoencoder_from_config from .pretransforms import AutoencoderPretransform @@ -48,8 +56,10 @@ def create_pretransform_from_config(pretransform_config, sample_rate): iterate_batch = pretransform_config.get("iterate_batch", False) chunked = pretransform_config.get("chunked", False) - pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) - elif pretransform_type == 'wavelet': + pretransform = AutoencoderPretransform( + autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked + ) + elif pretransform_type == "wavelet": from .pretransforms import WaveletPretransform wavelet_config = pretransform_config["config"] @@ -58,12 +68,14 @@ def create_pretransform_from_config(pretransform_config, sample_rate): wavelet = wavelet_config["wavelet"] pretransform = WaveletPretransform(channels, levels, wavelet) - elif pretransform_type == 'pqmf': + elif pretransform_type == "pqmf": from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] pretransform = PQMFPretransform(**pqmf_config) - elif pretransform_type == 'dac_pretrained': + elif pretransform_type == "dac_pretrained": from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] pretransform = PretrainedDACPretransform(**pretrained_dac_config) elif pretransform_type == "audiocraft_pretrained": @@ -72,27 +84,30 @@ def create_pretransform_from_config(pretransform_config, sample_rate): audiocraft_config = pretransform_config["config"] pretransform = AudiocraftCompressionPretransform(**audiocraft_config) else: - raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') - - enable_grad = pretransform_config.get('enable_grad', False) + raise NotImplementedError(f"Unknown pretransform type: {pretransform_type}") + + enable_grad = pretransform_config.get("enable_grad", False) pretransform.enable_grad = enable_grad pretransform.eval().requires_grad_(pretransform.enable_grad) return pretransform + def create_bottleneck_from_config(bottleneck_config): - bottleneck_type = bottleneck_config.get('type', None) + bottleneck_type = bottleneck_config.get("type", None) - assert bottleneck_type is not None, 'type must be specified in bottleneck config' + assert bottleneck_type is not None, "type must be specified in bottleneck config" - if bottleneck_type == 'tanh': + if bottleneck_type == "tanh": from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() - elif bottleneck_type == 'vae': + elif bottleneck_type == "vae": from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() - elif bottleneck_type == 'rvq': + elif bottleneck_type == "rvq": from .bottleneck import RVQBottleneck quantizer_params = { @@ -112,8 +127,8 @@ def create_bottleneck_from_config(bottleneck_config): from .bottleneck import DACRVQBottleneck bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) - - elif bottleneck_type == 'rvq_vae': + + elif 
bottleneck_type == "rvq_vae": from .bottleneck import RVQVAEBottleneck quantizer_params = { @@ -129,23 +144,27 @@ def create_bottleneck_from_config(bottleneck_config): quantizer_params.update(bottleneck_config["config"]) bottleneck = RVQVAEBottleneck(**quantizer_params) - - elif bottleneck_type == 'dac_rvq_vae': + + elif bottleneck_type == "dac_rvq_vae": from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) - elif bottleneck_type == 'l2_norm': + elif bottleneck_type == "l2_norm": from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() elif bottleneck_type == "wasserstein": from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) elif bottleneck_type == "fsq": from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) else: - raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') - - requires_grad = bottleneck_config.get('requires_grad', True) + raise NotImplementedError(f"Unknown bottleneck type: {bottleneck_type}") + + requires_grad = bottleneck_config.get("requires_grad", True) if not requires_grad: for param in bottleneck.parameters(): param.requires_grad = False diff --git a/stable_audio_tools/models/lm.py b/stable_audio_tools/models/lm.py index 1897fa72..c880eb34 100644 --- a/stable_audio_tools/models/lm.py +++ b/stable_audio_tools/models/lm.py @@ -1,27 +1,40 @@ +import typing as tp from dataclasses import dataclass + import torch -from tqdm.auto import trange -import typing as tp from einops import rearrange from torch import nn - -from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config -from .factory import create_pretransform_from_config -from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone -from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform -from .utils import multinomial, sample_top_k, sample_top_p +from tqdm.auto import trange from .codebook_patterns import ( CodebooksPatternProvider, DelayedPatternProvider, MusicLMPattern, ParallelPatternProvider, - UnrolledPatternProvider + UnrolledPatternProvider, +) +from .conditioners import ( + MultiConditioner, + create_multi_conditioner_from_conditioning_config, +) +from .factory import create_pretransform_from_config +from .lm_backbone import ( + AudioLMBackbone, + ContinuousTransformerAudioLMBackbone, + XTransformersAudioLMBackbone, ) +from .pretransforms import ( + AudiocraftCompressionPretransform, + AutoencoderPretransform, + PretrainedDACPretransform, + Pretransform, +) +from .utils import multinomial, sample_top_k, sample_top_p # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license # License can be found in LICENSES/LICENSE_META.txt + @dataclass class LMOutput: # The logits are already re-aligned with the input codes @@ -29,16 +42,17 @@ class LMOutput: logits: torch.Tensor # [B, K, T, card] mask: torch.Tensor # [B, K, T] + # Wrapper for a multi-codebook language model # Handles patterns and quantizer heads class AudioLanguageModel(nn.Module): def __init__( - self, - pattern_provider: CodebooksPatternProvider, - backbone: AudioLMBackbone, - num_quantizers: int, - codebook_size: int - ): + self, + pattern_provider: CodebooksPatternProvider, + backbone: AudioLMBackbone, + num_quantizers: int, + codebook_size: 
int, + ): super().__init__() self.pattern_provider = pattern_provider @@ -50,26 +64,33 @@ def __init__( # Per-quantizer embedders # Add one for the mask embed - self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) + self.embeds = nn.ModuleList( + [nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)] + ) # Per-quantizer output heads - self.quantizer_heads = nn.ModuleList([ - nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) - ]) - - def forward(self, - sequence: torch.Tensor, #[batch, seq_len, - prepend_cond=None, #[batch, seq, channels] - prepend_cond_mask=None, - cross_attn_cond=None, #[batch, seq, channels], - **kwargs - ): + self.quantizer_heads = nn.ModuleList( + [nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)] + ) + + def forward( + self, + sequence: torch.Tensor, # [batch, seq_len, + prepend_cond=None, # [batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, # [batch, seq, channels], + **kwargs, + ): batch, num_quantizers, seq_len = sequence.shape - assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" + assert ( + num_quantizers == self.num_quantizers + ), "Number of quantizers in sequence must match number of quantizers in model" - backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] + backbone_input = sum( + [self.embeds[i](sequence[:, i]) for i in range(num_quantizers)] + ) # [batch, seq_len, embed_dim] dtype = next(self.parameters()).dtype @@ -81,7 +102,7 @@ def forward(self, if prepend_cond_mask is not None: prepend_cond_mask = prepend_cond_mask.to(dtype) - + backbone_input = backbone_input.to(dtype) output = self.backbone( @@ -89,34 +110,29 @@ def forward(self, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, - **kwargs - ) # [batch, seq_len, embed_dim] + **kwargs, + ) # [batch, seq_len, embed_dim] # Run output through quantizer heads - logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] + logits = torch.stack( + [self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1 + ) # [batch, num_quantizers, seq_len, codebook_size] return logits - - def compute_logits( - self, - codes, #[batch, num_quantizers, seq_len] - **kwargs): + + def compute_logits(self, codes, **kwargs): # [batch, num_quantizers, seq_len] """ Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning Handles translation between input sequence and pattern-shifted sequence Only used during training """ - + batch, _, seq_len = codes.shape pattern = self.pattern_provider.get_pattern(seq_len) # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps - shifted_codes, _, _ = pattern.build_pattern_sequence( - codes, - self.masked_token_id, - keep_only_valid_steps=True - ) + shifted_codes, _, _ = pattern.build_pattern_sequence(codes, self.masked_token_id, keep_only_valid_steps=True) # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] logits = self(shifted_codes, **kwargs) @@ -125,32 +141,31 @@ def compute_logits( logits = rearrange(logits, "b n s c -> b c n s") # Revert sequence logits back to original sequence length, removing masked steps - logits, 
_, logits_mask = pattern.revert_pattern_logits( - logits, float('nan'), keep_only_valid_steps=True - ) + logits, _, logits_mask = pattern.revert_pattern_logits(logits, float("nan"), keep_only_valid_steps=True) logits = rearrange(logits, "b c n t -> b n t c") - logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] + logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] return LMOutput(logits=logits, mask=logits_mask) + # Conditioning and generation wrapper for a multi-codebook language model # Handles conditioning, CFG, generation, and encoding/decoding class AudioLanguageModelWrapper(nn.Module): def __init__( - self, - pretransform: Pretransform, - lm: AudioLanguageModel, - sample_rate: int, - min_input_length: int, - conditioner: MultiConditioner = None, - cross_attn_cond_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [] - ): + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + ): super().__init__() - + assert pretransform.is_discrete, "Pretransform must be discrete" self.pretransform = pretransform @@ -179,7 +194,7 @@ def __init__( self.cross_attn_cond_ids = cross_attn_cond_ids self.prepend_cond_ids = prepend_cond_ids self.global_cond_ids = global_cond_ids - + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): cross_attention_input = None prepend_cond = None @@ -209,23 +224,17 @@ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): "negative_cross_attn_cond": cross_attention_input, "negative_prepend_cond": prepend_cond, "negative_prepend_cond_mask": prepend_cond_mask, - "negative_global_cond": global_cond + "negative_global_cond": global_cond, } else: return { "cross_attn_cond": cross_attention_input, "prepend_cond": prepend_cond, "prepend_cond_mask": prepend_cond_mask, - "global_cond": global_cond + "global_cond": global_cond, } - - def compute_logits( - self, - codes, - condition_tensors=None, - cfg_dropout_prob=0.0, - **kwargs - ): + + def compute_logits(self, codes, condition_tensors=None, cfg_dropout_prob=0.0, **kwargs): """ Compute logits for a batch of codes, and translates from conditioning inputs to model inputs Handles CFG dropout @@ -244,34 +253,47 @@ def compute_logits( if cfg_dropout_prob > 0.0: if cross_attn_cond is not None: null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + dropout_mask = torch.bernoulli( + torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device) + ).to(torch.bool) cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) - + if prepend_cond is not None: null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + dropout_mask = torch.bernoulli( + torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device) + ).to(torch.bool) prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) if global_cond is not None: null_embed = torch.zeros_like(global_cond, 
device=global_cond.device) - dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + dropout_mask = torch.bernoulli( + torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device) + ).to(torch.bool) global_cond = torch.where(dropout_mask, null_embed, global_cond) - return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) - + return self.lm.compute_logits( + codes, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + global_cond=global_cond, + **kwargs, + ) + def _sample_next_token( - self, - sequence, #[batch, num_quantizers, seq_len] - conditioning_tensors=None, - cross_attn_use_cfg=True, - prepend_use_cfg=True, - global_use_cfg=True, - cfg_scale=1.0, - top_k=250, - top_p=0.0, - temp=1.0, - **kwargs - ): + self, + sequence, # [batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs, + ): """ Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs Handles CFG inference @@ -288,7 +310,7 @@ def _sample_next_token( global_cond = conditioning_inputs["global_cond"] if cfg_scale != 1.0: - + # Batch size is doubled to account for negative samples sequence = torch.cat([sequence, sequence], dim=0) @@ -296,11 +318,11 @@ def _sample_next_token( null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) - + if prepend_cond is not None and prepend_use_cfg: null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) if prepend_cond_mask is not None: prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) @@ -310,17 +332,24 @@ def _sample_next_token( global_cond = torch.cat([global_cond, null_embed], dim=0) - logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + logits = self.lm( + sequence, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + global_cond=global_cond, + **kwargs, + ) if cfg_scale != 1.0: cond_logits, uncond_logits = logits.chunk(2, dim=0) logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale - logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] - + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + # Grab the logits for the last step - logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] # Apply top-k or top-p sampling @@ -335,7 +364,7 @@ def _sample_next_token( next_token = multinomial(probs, num_samples=1) else: - next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] return next_token @@ -350,7 +379,7 @@ def generate( callback: tp.Optional[tp.Callable[[int, int], None]] = None, use_cache: bool = True, cfg_scale: float = 1.0, - **kwargs + 
**kwargs, ): device = next(self.parameters()).device @@ -371,10 +400,12 @@ def generate( else: possible_batch_sizes.append(1) - assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + assert [ + x == possible_batch_sizes[0] for x in possible_batch_sizes + ], "Batch size must be consistent across inputs" batch_size = possible_batch_sizes[0] - + if init_data is None: # Initialize with zeros assert batch_size > 0 @@ -390,10 +421,14 @@ def generate( unknown_token = -1 # Initialize the generated codes with the init data, padded with unknown tokens - gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) - gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + gen_codes = torch.full( + (batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long + ) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] - gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + gen_sequence, _, mask = pattern.build_pattern_sequence( + gen_codes, self.lm.masked_token_id + ) # [batch, num_quantizers, gen_sequence_len] start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) assert start_offset_sequence is not None @@ -416,17 +451,17 @@ def generate( conditioning_tensors=conditioning_tensors, use_cache=use_cache, cfg_scale=cfg_scale, - **kwargs + **kwargs, ) - valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + valid_mask = mask[..., offset : offset + 1].expand(batch_size, -1, -1) next_token[~valid_mask] = self.lm.masked_token_id # Update the generated sequence with the next token - gen_sequence[..., offset:offset+1] = torch.where( - gen_sequence[..., offset:offset+1] == unknown_token, - next_token, - gen_sequence[..., offset:offset+1] + gen_sequence[..., offset : offset + 1] = torch.where( + gen_sequence[..., offset : offset + 1] == unknown_token, + next_token, + gen_sequence[..., offset : offset + 1], ) if use_cache and self.lm.backbone.use_generation_cache: @@ -448,15 +483,11 @@ def generate( assert (out_codes[..., :max_gen_len] != unknown_token).all() assert (out_mask[..., :max_gen_len] == 1).all() - #out_codes = out_codes[..., 0:max_gen_len] + # out_codes = out_codes[..., 0:max_gen_len] return out_codes - - def generate_audio( - self, - **kwargs - ): + def generate_audio(self, **kwargs): """ Generate audio from a batch of codes """ @@ -469,26 +500,26 @@ def generate_audio( def create_audio_lm_from_config(config): - model_config = config.get('model', None) - assert model_config is not None, 'model config must be specified in config' + model_config = config.get("model", None) + assert model_config is not None, "model config must be specified in config" - sample_rate = config.get('sample_rate', None) + sample_rate = config.get("sample_rate", None) assert sample_rate is not None, "Must specify sample_rate in config" - - lm_config = model_config.get('lm', None) - assert lm_config is not None, 'lm config must be specified in model config' + + lm_config = model_config.get("lm", None) + assert lm_config is not None, "lm config must be specified in model config" codebook_pattern = lm_config.get("codebook_pattern", "delay") pattern_providers = { - 'parallel': ParallelPatternProvider, - 'delay': DelayedPatternProvider, - 'unroll': UnrolledPatternProvider, - 'musiclm': MusicLMPattern, + "parallel": 
ParallelPatternProvider, + "delay": DelayedPatternProvider, + "unroll": UnrolledPatternProvider, + "musiclm": MusicLMPattern, } pretransform_config = model_config.get("pretransform", None) - + pretransform = create_pretransform_from_config(pretransform_config, sample_rate) assert pretransform.is_discrete, "Pretransform must be discrete" @@ -497,15 +528,15 @@ def create_audio_lm_from_config(config): pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) - conditioning_config = model_config.get('conditioning', None) + conditioning_config = model_config.get("conditioning", None) conditioner = None if conditioning_config is not None: conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) - cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) - prepend_cond_ids = lm_config.get('prepend_cond_ids', []) - global_cond_ids = lm_config.get('global_cond_ids', []) + cross_attn_cond_ids = lm_config.get("cross_attention_cond_ids", []) + prepend_cond_ids = lm_config.get("prepend_cond_ids", []) + global_cond_ids = lm_config.get("global_cond_ids", []) lm_type = lm_config.get("type", None) lm_model_config = lm_config.get("config", None) @@ -524,7 +555,7 @@ def create_audio_lm_from_config(config): pattern_provider=pattern_provider, backbone=backbone, num_quantizers=pretransform.num_quantizers, - codebook_size=pretransform.codebook_size + codebook_size=pretransform.codebook_size, ) model = AudioLanguageModelWrapper( @@ -535,7 +566,7 @@ def create_audio_lm_from_config(config): min_input_length=min_input_length, cross_attn_cond_ids=cross_attn_cond_ids, prepend_cond_ids=prepend_cond_ids, - global_cond_ids=global_cond_ids + global_cond_ids=global_cond_ids, ) - return model \ No newline at end of file + return model diff --git a/stable_audio_tools/models/lm_backbone.py b/stable_audio_tools/models/lm_backbone.py index c80cce60..2746a738 100644 --- a/stable_audio_tools/models/lm_backbone.py +++ b/stable_audio_tools/models/lm_backbone.py @@ -3,6 +3,7 @@ from .transformer import ContinuousTransformer + # Interface for backbone of a language model # Handles conditioning and cross-attention # Does not have to deal with patterns or quantizer heads @@ -14,55 +15,44 @@ def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): self.use_generation_cache = use_generation_cache def forward( - self, - x, - cross_attn_cond=None, - prepend_cond=None, + self, + x, + cross_attn_cond=None, + prepend_cond=None, prepend_cond_mask=None, global_cond=None, use_cache=False, **kwargs - ): - raise NotImplementedError - - def reset_generation_cache( - self, - max_seq_len, - batch_size, - dtype=None ): + raise NotImplementedError + + def reset_generation_cache(self, max_seq_len, batch_size, dtype=None): pass - def update_generation_cache( - self, - seqlen_offset - ): + def update_generation_cache(self, seqlen_offset): pass + class XTransformersAudioLMBackbone(AudioLMBackbone): - def __init__(self, - embed_dim: int, - cross_attn_cond_dim: int = 0, - prepend_cond_dim: int = 0, - **kwargs): + def __init__(self, embed_dim: int, cross_attn_cond_dim: int = 0, prepend_cond_dim: int = 0, **kwargs): super().__init__(embed_dim=embed_dim) # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer self.model = ContinuousTransformerWrapper( dim_in=embed_dim, dim_out=embed_dim, - max_seq_len=0, #Not relevant without absolute positional embeds, + max_seq_len=0, # Not relevant without absolute positional embeds, attn_layers=Decoder( 
dim=embed_dim, - attn_flash = True, - cross_attend = cross_attn_cond_dim > 0, + attn_flash=True, + cross_attend=cross_attn_cond_dim > 0, zero_init_branch_output=True, - use_abs_pos_emb = False, + use_abs_pos_emb=False, rotary_pos_emb=True, - ff_swish = True, - ff_glu = True, + ff_swish=True, + ff_glu=True, **kwargs - ) + ), ) if prepend_cond_dim > 0: @@ -70,7 +60,7 @@ def __init__(self, self.to_prepend_embed = nn.Sequential( nn.Linear(prepend_cond_dim, embed_dim, bias=False), nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=False), ) if cross_attn_cond_dim > 0: @@ -78,10 +68,19 @@ def __init__(self, self.to_cross_attn_embed = nn.Sequential( nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=False), ) - def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + def forward( + self, + x, + mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cross_attn_cond=None, + global_cond=None, + use_cache=False, + ): prepend_length = 0 if prepend_cond is not None: @@ -97,15 +96,20 @@ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross # Project the cross-attention conditioning to the embedding dimension cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) - return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] - + return self.model( + x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask + )[:, prepend_length:, :] + + class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): - def __init__(self, - embed_dim: int, - cross_attn_cond_dim: int = 0, - prepend_cond_dim: int = 0, - project_cross_attn_cond: bool = False, - **kwargs): + def __init__( + self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + project_cross_attn_cond: bool = False, + **kwargs + ): super().__init__(embed_dim=embed_dim) # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer @@ -113,8 +117,8 @@ def __init__(self, dim=embed_dim, dim_in=embed_dim, dim_out=embed_dim, - cross_attend = cross_attn_cond_dim > 0, - cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, + cross_attend=cross_attn_cond_dim > 0, + cond_token_dim=embed_dim if project_cross_attn_cond else cross_attn_cond_dim, causal=True, **kwargs ) @@ -124,7 +128,7 @@ def __init__(self, self.to_prepend_embed = nn.Sequential( nn.Linear(prepend_cond_dim, embed_dim, bias=False), nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=False), ) if cross_attn_cond_dim > 0 and project_cross_attn_cond: @@ -132,12 +136,21 @@ def __init__(self, self.to_cross_attn_embed = nn.Sequential( nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) + nn.Linear(embed_dim, embed_dim, bias=False), ) else: self.to_cross_attn_embed = nn.Identity() - def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + def forward( + self, + x, + mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cross_attn_cond=None, + global_cond=None, + use_cache=False, + ): prepend_length = 0 if prepend_cond is not None: @@ -156,4 +169,6 @@ def forward(self, x, 
mask=None, prepend_cond=None, prepend_cond_mask=None, cross # Project the cross-attention conditioning to the embedding dimension cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) - return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] \ No newline at end of file + return self.model( + x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask + )[:, prepend_length:, :] diff --git a/stable_audio_tools/models/local_attention.py b/stable_audio_tools/models/local_attention.py index 893ce11f..d7ccfa51 100644 --- a/stable_audio_tools/models/local_attention.py +++ b/stable_audio_tools/models/local_attention.py @@ -1,15 +1,16 @@ import torch - from einops import rearrange from torch import nn from .blocks import AdaRMSNorm -from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm +from .transformer import Attention, FeedForward, LayerNorm, RotaryEmbedding + def checkpoint(function, *args, **kwargs): kwargs.setdefault("use_reentrant", False) return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py class ContinuousLocalTransformer(nn.Module): def __init__( @@ -17,19 +18,19 @@ def __init__( *, dim, depth, - dim_in = None, - dim_out = None, - causal = False, - local_attn_window_size = 64, - heads = 8, - ff_mult = 2, - cond_dim = 0, - cross_attn_cond_dim = 0, + dim_in=None, + dim_out=None, + causal=False, + local_attn_window_size=64, + heads=8, + ff_mult=2, + cond_dim=0, + cross_attn_cond_dim=0, **kwargs ): super().__init__() - - dim_head = dim//heads + + dim_head = dim // heads self.layers = nn.ModuleList([]) @@ -44,30 +45,35 @@ def __init__( self.cross_attn_cond_dim = cross_attn_cond_dim self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) - + for _ in range(depth): - self.layers.append(nn.ModuleList([ - AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), - Attention( - dim=dim, - dim_heads=dim_head, - causal=causal, - zero_init_output=True, - natten_kernel_size=local_attn_window_size, - ), - Attention( - dim=dim, - dim_heads=dim_head, - dim_context = cross_attn_cond_dim, - zero_init_output=True - ) if self.cross_attn_cond_dim > 0 else nn.Identity(), - AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), - FeedForward(dim = dim, mult = ff_mult, no_bias=True) - ])) - - def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): - + self.layers.append( + nn.ModuleList( + [ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + ( + Attention( + dim=dim, dim_heads=dim_head, dim_context=cross_attn_cond_dim, zero_init_output=True + ) + if self.cross_attn_cond_dim > 0 + else nn.Identity() + ), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim=dim, mult=ff_mult, no_bias=True), + ] + ) + ) + + def forward(self, x, mask=None, cond=None, cross_attn_cond=None, cross_attn_cond_mask=None, prepend_cond=None): + x = checkpoint(self.project_in, x) if prepend_cond is not None: @@ -83,7 +89,7 @@ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_att else: x = checkpoint(attn_norm, x) - x = checkpoint(attn, x, mask = 
mask, rotary_pos_emb=pos_emb) + residual + x = checkpoint(attn, x, mask=mask, rotary_pos_emb=pos_emb) + residual if cross_attn_cond is not None: x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x @@ -99,34 +105,23 @@ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_att return checkpoint(self.project_out, x) + class TransformerDownsampleBlock1D(nn.Module): def __init__( - self, - in_channels, - embed_dim = 768, - depth = 3, - heads = 12, - downsample_ratio = 2, - local_attn_window_size = 64, - **kwargs + self, in_channels, embed_dim=768, depth=3, heads=12, downsample_ratio=2, local_attn_window_size=64, **kwargs ): super().__init__() self.downsample_ratio = downsample_ratio self.transformer = ContinuousLocalTransformer( - dim=embed_dim, - depth=depth, - heads=heads, - local_attn_window_size=local_attn_window_size, - **kwargs + dim=embed_dim, depth=depth, heads=heads, local_attn_window_size=local_attn_window_size, **kwargs ) self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) - - + def forward(self, x): x = checkpoint(self.project_in, x) @@ -142,33 +137,23 @@ def forward(self, x): return x + class TransformerUpsampleBlock1D(nn.Module): def __init__( - self, - in_channels, - embed_dim, - depth = 3, - heads = 12, - upsample_ratio = 2, - local_attn_window_size = 64, - **kwargs + self, in_channels, embed_dim, depth=3, heads=12, upsample_ratio=2, local_attn_window_size=64, **kwargs ): super().__init__() self.upsample_ratio = upsample_ratio self.transformer = ContinuousLocalTransformer( - dim=embed_dim, - depth=depth, - heads=heads, - local_attn_window_size = local_attn_window_size, - **kwargs + dim=embed_dim, depth=depth, heads=heads, local_attn_window_size=local_attn_window_size, **kwargs ) self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) - + def forward(self, x): # Project to embed dim @@ -181,42 +166,42 @@ def forward(self, x): x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) # Compute - x = self.transformer(x) + x = self.transformer(x) return x - + class TransformerEncoder1D(nn.Module): def __init__( self, in_channels, out_channels, - embed_dims = [96, 192, 384, 768], - heads = [12, 12, 12, 12], - depths = [3, 3, 3, 3], - ratios = [2, 2, 2, 2], - local_attn_window_size = 64, + embed_dims=[96, 192, 384, 768], + heads=[12, 12, 12, 12], + depths=[3, 3, 3, 3], + ratios=[2, 2, 2, 2], + local_attn_window_size=64, **kwargs ): super().__init__() - + layers = [] - + for layer in range(len(depths)): prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] layers.append( TransformerDownsampleBlock1D( - in_channels = prev_dim, - embed_dim = embed_dims[layer], - heads = heads[layer], - depth = depths[layer], - downsample_ratio = ratios[layer], - local_attn_window_size = local_attn_window_size, + in_channels=prev_dim, + embed_dim=embed_dims[layer], + heads=heads[layer], + depth=depths[layer], + downsample_ratio=ratios[layer], + local_attn_window_size=local_attn_window_size, **kwargs ) ) - + self.layers = nn.Sequential(*layers) self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) @@ -237,33 +222,33 @@ def __init__( self, in_channels, out_channels, - embed_dims = [768, 384, 192, 96], - heads = [12, 12, 12, 12], - 
depths = [3, 3, 3, 3], - ratios = [2, 2, 2, 2], - local_attn_window_size = 64, + embed_dims=[768, 384, 192, 96], + heads=[12, 12, 12, 12], + depths=[3, 3, 3, 3], + ratios=[2, 2, 2, 2], + local_attn_window_size=64, **kwargs ): super().__init__() layers = [] - + for layer in range(len(depths)): prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] layers.append( TransformerUpsampleBlock1D( - in_channels = prev_dim, - embed_dim = embed_dims[layer], - heads = heads[layer], - depth = depths[layer], - upsample_ratio = ratios[layer], - local_attn_window_size = local_attn_window_size, + in_channels=prev_dim, + embed_dim=embed_dims[layer], + heads=heads[layer], + depth=depths[layer], + upsample_ratio=ratios[layer], + local_attn_window_size=local_attn_window_size, **kwargs ) ) - + self.layers = nn.Sequential(*layers) self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) @@ -275,4 +260,4 @@ def forward(self, x): x = self.layers(x) x = checkpoint(self.project_out, x) x = rearrange(x, "b n c -> b c n") - return x \ No newline at end of file + return x diff --git a/stable_audio_tools/models/pqmf.py b/stable_audio_tools/models/pqmf.py index 007fdb51..4cc34ef9 100644 --- a/stable_audio_tools/models/pqmf.py +++ b/stable_audio_tools/models/pqmf.py @@ -1,4 +1,5 @@ import math + import numpy as np import torch import torch.nn as nn @@ -6,11 +7,12 @@ from scipy.optimize import fmin from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + class PQMF(nn.Module): """ Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. - Uses polyphase representation which is computationally more efficient for real-time. - + Uses polyphase representation which is computationally more efficient for real-time. + Parameters: - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. - num_bands (int): Number of desired frequency bands. It must be a power of 2. @@ -18,16 +20,16 @@ class PQMF(nn.Module): def __init__(self, attenuation, num_bands): super(PQMF, self).__init__() - - # Ensure num_bands is a power of 2 - is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + + # Ensure num_bands is a power of 2 + is_power_of_2 = math.log2(num_bands) == int(math.log2(num_bands)) assert is_power_of_2, "'num_bands' must be a power of 2." - + # Create the prototype filter prototype_filter = design_prototype_filter(attenuation, num_bands) filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) - + # Register filters and settings self.register_buffer("filter_bank", padded_filter_bank) self.register_buffer("prototype", prototype_filter) @@ -35,7 +37,7 @@ def __init__(self, attenuation, num_bands): def forward(self, signal): """Decompose the signal into multiple frequency bands.""" - # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it signal = prepare_signal_dimensions(signal) # The signal length must be a multiple of num_bands. Pad it with zeros. signal = pad_signal(signal, self.num_bands) @@ -51,13 +53,13 @@ def inverse(self, bands): def prepare_signal_dimensions(signal): """ - Rearrange signal into Batch x Channels x Length. - + Rearrange signal into Batch x Channels x Length. + Parameters ---------- signal : torch.Tensor or numpy.ndarray The input signal. 
- + Returns ------- torch.Tensor @@ -66,11 +68,11 @@ def prepare_signal_dimensions(signal): # Convert numpy to torch tensor if isinstance(signal, np.ndarray): signal = torch.from_numpy(signal) - + # Ensure tensor if not isinstance(signal, torch.Tensor): raise ValueError("Input should be either a numpy array or a PyTorch tensor.") - + # Modify dimension of signal to Batch x Channels x Length if signal.dim() == 1: # This is just a mono signal. Unsqueeze to 1 x 1 x Length @@ -80,10 +82,11 @@ def prepare_signal_dimensions(signal): # Rearrange so that larger dimension (Length) is last if signal.shape[0] > signal.shape[1]: signal = signal.T - # Unsqueeze to 1 x Channels x Length + # Unsqueeze to 1 x Channels x Length signal = signal.unsqueeze(0) return signal - + + def pad_signal(signal, num_bands): """ Pads the signal to make its length divisible by the given number of bands. @@ -108,50 +111,49 @@ def pad_signal(signal, num_bands): signal = nn.functional.pad(signal, (0, padding_size)) return signal + def generate_modulated_filter_bank(prototype_filter, num_bands): """ - Generate a QMF bank of cosine modulated filters based on a given prototype filter. - + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + Parameters ---------- prototype_filter : torch.Tensor The prototype filter used as the basis for modulation. num_bands : int The number of desired subbands or filters. - + Returns ------- torch.Tensor A bank of cosine modulated filters. """ - + # Initialize indices for modulation. subband_indices = torch.arange(num_bands).reshape(-1, 1) - + # Calculate the length of the prototype filter. filter_length = prototype_filter.shape[-1] - + # Generate symmetric time indices centered around zero. time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) - + # Calculate phase offsets to ensure orthogonality between subbands. - phase_offsets = (-1)**subband_indices * np.pi / 4 - + phase_offsets = (-1) ** subband_indices * np.pi / 4 + # Compute the cosine modulation function. - modulation = torch.cos( - (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets - ) - + modulation = torch.cos((2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets) + # Apply modulation to the prototype filter. modulated_filters = 2 * prototype_filter * modulation - + return modulated_filters def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): """ Design a lowpass filter using the Kaiser window. - + Parameters ---------- angular_cutoff : float @@ -160,28 +162,28 @@ def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): The desired stopband attenuation in decibels (dB). filter_length : int, optional Desired length of the filter. If not provided, it's computed based on the given specs. - + Returns ------- ndarray The designed lowpass filter coefficients. """ - + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) - + # Ensure the estimated length is odd. 
estimated_length = 2 * (estimated_length // 2) + 1 - + if filter_length is None: filter_length = estimated_length - - return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + return firwin(filter_length, angular_cutoff, window=("kaiser", beta), scale=False, nyq=np.pi) def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): """ Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 - + Parameters ---------- angular_cutoff : float @@ -192,23 +194,23 @@ def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_len Number of bands for the multiband filter system. filter_length : int, optional Desired length of the filter. - + Returns ------- float The computed objective (loss) value for the given filter specs. """ - + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") - - return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2 :: 2 * num_bands][1:])) def design_prototype_filter(attenuation, num_bands, filter_length=None): """ Design the optimal prototype filter for a multiband system given the desired specs. - + Parameters ---------- attenuation : float @@ -217,96 +219,103 @@ def design_prototype_filter(attenuation, num_bands, filter_length=None): Number of bands for the multiband filter system. filter_length : int, optional Desired length of the filter. If not provided, it's computed based on the given specs. - + Returns ------- ndarray The optimal prototype filter coefficients. """ - - optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), - 1 / num_bands, disp=0)[0] - + + optimal_angular_cutoff = fmin( + lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, + disp=0, + )[0] + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) return torch.tensor(prototype_filter, dtype=torch.float32) + def pad_to_nearest_power_of_two(x): """ - Pads the input tensor 'x' on both sides such that its last dimension + Pads the input tensor 'x' on both sides such that its last dimension becomes the nearest larger power of two. - + Parameters: ----------- x : torch.Tensor The input tensor to be padded. - + Returns: -------- torch.Tensor The padded tensor. """ current_length = x.shape[-1] - target_length = 2**math.ceil(math.log2(current_length)) - + target_length = 2 ** math.ceil(math.log2(current_length)) + total_padding = target_length - current_length left_padding = total_padding // 2 right_padding = total_padding - left_padding - + return nn.functional.pad(x, (left_padding, right_padding)) + def apply_alias_cancellation(x): """ - Applies alias cancellation by inverting the sign of every - second element of every second row, starting from the second + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second row's first element in a tensor. 
- - This operation helps ensure that the aliasing introduced in - each band during the decomposition will be counteracted during + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during the reconstruction. - + Parameters: ----------- x : torch.Tensor The input tensor. - + Returns: -------- torch.Tensor Tensor with specific elements' sign inverted for alias cancellation. """ - + # Create a mask of the same shape as 'x', initialized with all ones mask = torch.ones_like(x) - + # Update specific elements in the mask to -1 to perform inversion mask[..., 1::2, ::2] = -1 - + # Apply the mask to the input tensor 'x' return x * mask + def ensure_odd_length(tensor): """ Pads the last dimension of a tensor to ensure its size is odd. - + Parameters: ----------- tensor : torch.Tensor Input tensor whose last dimension might need padding. - + Returns: -------- torch.Tensor - The original tensor if its last dimension was already odd, + The original tensor if its last dimension was already odd, or the padded tensor with an odd-sized last dimension. """ - + last_dim_size = tensor.shape[-1] - + if last_dim_size % 2 == 0: tensor = nn.functional.pad(tensor, (0, 1)) - + return tensor + def polyphase_analysis(signal, filter_bank): """ Applies the polyphase method to efficiently analyze the signal using a filter bank. @@ -315,31 +324,31 @@ def polyphase_analysis(signal, filter_bank): ----------- signal : torch.Tensor Input signal tensor with shape (Batch x Channels x Length). - + filter_bank : torch.Tensor Filter bank tensor with shape (Bands x Length). - + Returns: -------- torch.Tensor Signal split into sub-bands. (Batch x Channels x Bands x Length) """ - + num_bands = filter_bank.shape[0] num_channels = signal.shape[1] - - # Rearrange signal for polyphase processing. - # Also combine Batch x Channel into one dimension for now. - #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + # signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) - + # Rearrange the filter bank for matching signal shape filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) - + # Apply convolution with appropriate padding to maintain spatial dimensions padding = filter_bank.shape[-1] // 2 filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) - + # Truncate the last dimension post-convolution to adjust the output shape filtered_signal = filtered_signal[..., :-1] # Rearrange the first dimension back into Batch x Channels @@ -347,18 +356,19 @@ def polyphase_analysis(signal, filter_bank): return filtered_signal + def polyphase_synthesis(signal, filter_bank): """ - Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. - + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + Parameters ---------- signal : torch.Tensor Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). - + filter_bank : torch.Tensor Analysis filter bank (shape: Bands x Length). - + should_rearrange : bool, optional Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. 
@@ -367,11 +377,11 @@ def polyphase_synthesis(signal, filter_bank): torch.Tensor Reconstructed signal (shape: Batch x Channels X Length) """ - + num_bands = filter_bank.shape[0] num_channels = signal.shape[1] - # Rearrange the filter bank + # Rearrange the filter bank filter_bank = filter_bank.flip(-1) filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) @@ -381,13 +391,13 @@ def polyphase_synthesis(signal, filter_bank): # Apply convolution with appropriate padding padding_amount = filter_bank.shape[-1] // 2 + 1 reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) - + # Scale the result reconstructed_signal = reconstructed_signal[..., :-1] * num_bands # Reorganize the output and truncate reconstructed_signal = reconstructed_signal.flip(1) reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) - reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] - - return reconstructed_signal \ No newline at end of file + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1] :] + + return reconstructed_signal diff --git a/stable_audio_tools/models/pretrained.py b/stable_audio_tools/models/pretrained.py index e83af343..7ec1f97a 100644 --- a/stable_audio_tools/models/pretrained.py +++ b/stable_audio_tools/models/pretrained.py @@ -1,13 +1,14 @@ import json +from huggingface_hub import hf_hub_download + from .factory import create_model_from_config from .utils import load_ckpt_state_dict -from huggingface_hub import hf_hub_download def get_pretrained_model(name: str): - - model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') + + model_config_path = hf_hub_download(name, filename="model_config.json", repo_type="model") with open(model_config_path) as f: model_config = json.load(f) @@ -16,10 +17,10 @@ def get_pretrained_model(name: str): # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file try: - model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') + model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type="model") except Exception as e: - model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') + model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type="model") model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) - return model, model_config \ No newline at end of file + return model, model_config diff --git a/stable_audio_tools/models/pretransforms.py b/stable_audio_tools/models/pretransforms.py index c9942db5..917374aa 100644 --- a/stable_audio_tools/models/pretransforms.py +++ b/stable_audio_tools/models/pretransforms.py @@ -2,6 +2,7 @@ from einops import rearrange from torch import nn + class Pretransform(nn.Module): def __init__(self, enable_grad, io_channels, is_discrete): super().__init__() @@ -18,37 +19,46 @@ def encode(self, x): def decode(self, z): raise NotImplementedError - + def tokenize(self, x): raise NotImplementedError - + def decode_tokens(self, tokens): raise NotImplementedError + class AutoencoderPretransform(Pretransform): def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): - super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + super().__init__( + enable_grad=False, + io_channels=model.io_channels, + 
is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete, + ) self.model = model self.model.requires_grad_(False).eval() - self.scale=scale + self.scale = scale self.downsampling_ratio = model.downsampling_ratio self.io_channels = model.io_channels self.sample_rate = model.sample_rate - + self.model_half = model_half self.iterate_batch = iterate_batch self.encoded_channels = model.latent_dim self.chunked = chunked - self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None - self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.num_quantizers = ( + model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + ) + self.codebook_size = ( + model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + ) if self.model_half: self.model.half() - + def encode(self, x, **kwargs): - + if self.model_half: x = x.half() self.model.to(torch.float16) @@ -73,48 +83,50 @@ def decode(self, z, **kwargs): decoded = decoded.float() return decoded - + def tokenize(self, x, **kwargs): assert self.model.is_discrete, "Cannot tokenize with a continuous model" - _, info = self.model.encode(x, return_info = True, **kwargs) + _, info = self.model.encode(x, return_info=True, **kwargs) return info[self.model.bottleneck.tokens_id] - + def decode_tokens(self, tokens, **kwargs): assert self.model.is_discrete, "Cannot decode tokens with a continuous model" return self.model.decode_tokens(tokens, **kwargs) - + def load_state_dict(self, state_dict, strict=True): self.model.load_state_dict(state_dict, strict=strict) + class WaveletPretransform(Pretransform): def __init__(self, channels, levels, wavelet): super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) - from .wavelets import WaveletEncode1d, WaveletDecode1d + from .wavelets import WaveletDecode1d, WaveletEncode1d self.encoder = WaveletEncode1d(channels, levels, wavelet) self.decoder = WaveletDecode1d(channels, levels, wavelet) - self.downsampling_ratio = 2 ** levels + self.downsampling_ratio = 2**levels self.io_channels = channels self.encoded_channels = channels * self.downsampling_ratio - + def encode(self, x): return self.encoder(x) - + def decode(self, z): return self.decoder(z) - + + class PQMFPretransform(Pretransform): def __init__(self, attenuation=100, num_bands=16): # TODO: Fix PQMF to take in in-channels super().__init__(enable_grad=False, io_channels=1, is_discrete=False) from .pqmf import PQMF - self.pqmf = PQMF(attenuation, num_bands) + self.pqmf = PQMF(attenuation, num_bands) def encode(self, x): # x is (Batch x Channels x Time) @@ -125,19 +137,22 @@ def encode(self, x): return rearrange(x, "b c n t -> b (c n) t") def decode(self, x): - # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) - # returns (Batch x Channels x Time) + # returns (Batch x Channels x Time) return self.pqmf.inverse(x) - + + class PretrainedDACPretransform(Pretransform): - def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + def __init__( + self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, 
chunked=True + ): super().__init__(enable_grad=False, io_channels=1, is_discrete=True) - + import dac - + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) - + self.model = dac.DAC.load(model_path) self.quantize_on_decode = quantize_on_decode @@ -168,14 +183,14 @@ def encode(self, x): else: z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) output = z - + if self.scale != 1.0: output = output / self.scale - + return output def decode(self, z): - + if self.scale != 1.0: z = z * self.scale @@ -186,20 +201,21 @@ def decode(self, z): def tokenize(self, x): return self.model.encode(x)[1] - + def decode_tokens(self, tokens): latents = self.model.quantizer.from_codes(tokens) return self.model.decode(latents) - + + class AudiocraftCompressionPretransform(Pretransform): def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): super().__init__(enable_grad=False, io_channels=1, is_discrete=True) - + try: from audiocraft.models import CompressionModel except ImportError: raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") - + self.model = CompressionModel.get_pretrained(model_type) self.quantize_on_decode = quantize_on_decode @@ -212,7 +228,7 @@ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_d self.scale = scale - #self.encoded_channels = self.model.latent_dim + # self.encoded_channels = self.model.latent_dim self.num_quantizers = self.model.num_codebooks @@ -231,14 +247,14 @@ def encode(self, x): # else: # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) # output = z - + # if self.scale != 1.0: # output = output / self.scale - + # return output def decode(self, z): - + assert False, "Audiocraft compression models do not support continuous decoding" # if self.scale != 1.0: @@ -252,7 +268,7 @@ def decode(self, z): def tokenize(self, x): with torch.cuda.amp.autocast(enabled=False): return self.model.encode(x.to(torch.float16))[0] - + def decode_tokens(self, tokens): with torch.cuda.amp.autocast(enabled=False): return self.model.decode(tokens) diff --git a/stable_audio_tools/models/transformer.py b/stable_audio_tools/models/transformer.py index 65965b49..d1480973 100644 --- a/stable_audio_tools/models/transformer.py +++ b/stable_audio_tools/models/transformer.py @@ -1,19 +1,19 @@ -from functools import reduce, partial -from packaging import version +from functools import partial, reduce +from typing import Callable, Literal -from einops import rearrange, repeat -from einops.layers.torch import Rearrange import torch import torch.nn.functional as F -from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from packaging import version +from torch import einsum, nn from torch.cuda.amp import autocast -from typing import Callable, Literal try: from flash_attn import flash_attn_func, flash_attn_kvpacked_func except ImportError as e: print(e) - print('flash_attn not installed, disabling Flash Attention') + print("flash_attn not installed, disabling Flash Attention") flash_attn_kvpacked_func = None flash_attn_func = None @@ -22,6 +22,7 @@ except ImportError: natten = None + def checkpoint(function, *args, **kwargs): kwargs.setdefault("use_reentrant", False) return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) @@ -30,8 +31,10 @@ def checkpoint(function, *args, **kwargs): # Copied and modified from 
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + def create_causal_mask(i, j, device): - return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) + def or_reduce(masks): head, *body = masks @@ -39,62 +42,62 @@ def or_reduce(masks): head = head | rest return head + # positional embeddings + class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.max_seq_len = max_seq_len self.emb = nn.Embedding(max_seq_len, dim) - def forward(self, x, pos = None, seq_start_pos = None): + def forward(self, x, pos=None, seq_start_pos=None): seq_len, device = x.shape[1], x.device - assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + assert ( + seq_len <= self.max_seq_len + ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" if pos is None: - pos = torch.arange(seq_len, device = device) + pos = torch.arange(seq_len, device=device) if seq_start_pos is not None: - pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + pos = (pos - seq_start_pos[..., None]).clamp(min=0) pos_emb = self.emb(pos) pos_emb = pos_emb * self.scale return pos_emb + class ScaledSinusoidalEmbedding(nn.Module): - def __init__(self, dim, theta = 10000): + def __init__(self, dim, theta=10000): super().__init__() - assert (dim % 2) == 0, 'dimension must be divisible by 2' - self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + assert (dim % 2) == 0, "dimension must be divisible by 2" + self.scale = nn.Parameter(torch.ones(1) * dim**-0.5) half_dim = dim // 2 freq_seq = torch.arange(half_dim).float() / half_dim - inv_freq = theta ** -freq_seq - self.register_buffer('inv_freq', inv_freq, persistent = False) + inv_freq = theta**-freq_seq + self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, x, pos = None, seq_start_pos = None): + def forward(self, x, pos=None, seq_start_pos=None): seq_len, device = x.shape[1], x.device if pos is None: - pos = torch.arange(seq_len, device = device) + pos = torch.arange(seq_len, device=device) if seq_start_pos is not None: pos = pos - seq_start_pos[..., None] - emb = einsum('i, j -> i j', pos, self.inv_freq) - emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + emb = einsum("i, j -> i j", pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb * self.scale - + + class RotaryEmbedding(nn.Module): def __init__( - self, - dim, - use_xpos = False, - scale_base = 512, - interpolation_factor = 1., - base = 10000, - base_rescale_factor = 1. + self, dim, use_xpos=False, scale_base=512, interpolation_factor=1.0, base=10000, base_rescale_factor=1.0 ): super().__init__() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning @@ -102,28 +105,28 @@ def __init__( # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ base *= base_rescale_factor ** (dim / (dim - 2)) - inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) - assert interpolation_factor >= 1. + assert interpolation_factor >= 1.0 self.interpolation_factor = interpolation_factor if not use_xpos: - self.register_buffer('scale', None) + self.register_buffer("scale", None) return scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) self.scale_base = scale_base - self.register_buffer('scale', scale) + self.register_buffer("scale", scale) def forward_from_seq_len(self, seq_len): device = self.inv_freq.device - t = torch.arange(seq_len, device = device) + t = torch.arange(seq_len, device=device) return self.forward(t) - @autocast(enabled = False) + @autocast(enabled=False) def forward(self, t): device = self.inv_freq.device @@ -131,25 +134,27 @@ def forward(self, t): t = t / self.interpolation_factor - freqs = torch.einsum('i , j -> i j', t, self.inv_freq) - freqs = torch.cat((freqs, freqs), dim = -1) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) if self.scale is None: - return freqs, 1. + return freqs, 1.0 - power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base - scale = self.scale ** rearrange(power, 'n -> n 1') - scale = torch.cat((scale, scale), dim = -1) + power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, "n -> n 1") + scale = torch.cat((scale, scale), dim=-1) return freqs, scale + def rotate_half(x): - x = rearrange(x, '... (j d) -> ... j d', j = 2) - x1, x2 = x.unbind(dim = -2) - return torch.cat((-x2, x1), dim = -1) + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) -@autocast(enabled = False) -def apply_rotary_pos_emb(t, freqs, scale = 1): + +@autocast(enabled=False) +def apply_rotary_pos_emb(t, freqs, scale=1): out_dtype = t.dtype # cast to float32 if necessary for numerical stability @@ -159,7 +164,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1): freqs = freqs[-seq_len:, :] if t.ndim == 4 and freqs.ndim == 3: - freqs = rearrange(freqs, 'b n d -> b 1 n d') + freqs = rearrange(freqs, "b n d -> b 1 n d") # partial rotary embeddings, Wang et al. 
GPT-J t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] @@ -167,7 +172,8 @@ def apply_rotary_pos_emb(t, freqs, scale = 1): t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) - return torch.cat((t, t_unrotated), dim = -1) + return torch.cat((t, t_unrotated), dim=-1) + # norms class LayerNorm(nn.Module): @@ -187,48 +193,54 @@ def __init__(self, dim, bias=False, fix_scale=False): else: self.register_buffer("beta", torch.zeros(dim)) - def forward(self, x): return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + # feedforward + class GLU(nn.Module): def __init__( self, dim_in, dim_out, activation: Callable, - use_conv = False, - conv_kernel_size = 3, + use_conv=False, + conv_kernel_size=3, ): super().__init__() self.act = activation - self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.proj = ( + nn.Linear(dim_in, dim_out * 2) + if not use_conv + else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding=(conv_kernel_size // 2)) + ) self.use_conv = use_conv def forward(self, x): if self.use_conv: - x = rearrange(x, 'b n d -> b d n') + x = rearrange(x, "b n d -> b d n") x = self.proj(x) - x = rearrange(x, 'b d n -> b n d') + x = rearrange(x, "b d n -> b n d") else: x = self.proj(x) - x, gate = x.chunk(2, dim = -1) + x, gate = x.chunk(2, dim=-1) return x * self.act(gate) + class FeedForward(nn.Module): def __init__( self, dim, - dim_out = None, - mult = 4, - no_bias = False, - glu = True, - use_conv = False, - conv_kernel_size = 3, - zero_init_output = True, + dim_out=None, + mult=4, + no_bias=False, + glu=True, + use_conv=False, + conv_kernel_size=3, + zero_init_output=True, ): super().__init__() inner_dim = int(dim * mult) @@ -243,13 +255,21 @@ def __init__( linear_in = GLU(dim, inner_dim, activation) else: linear_in = nn.Sequential( - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), - nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), - activation + Rearrange("b n d -> b d n") if use_conv else nn.Identity(), + ( + nn.Linear(dim, inner_dim, bias=not no_bias) + if not use_conv + else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias) + ), + Rearrange("b n d -> b d n") if use_conv else nn.Identity(), + activation, ) - linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + linear_out = ( + nn.Linear(inner_dim, dim_out, bias=not no_bias) + if not use_conv + else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias) + ) # init last linear layer to 0 if zero_init_output: @@ -257,27 +277,27 @@ def __init__( if not no_bias: nn.init.zeros_(linear_out.bias) - self.ff = nn.Sequential( linear_in, - Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + Rearrange("b d n -> b n d") if use_conv else nn.Identity(), linear_out, - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + Rearrange("b n d -> b d n") if use_conv else nn.Identity(), ) def forward(self, x): return self.ff(x) + class Attention(nn.Module): def __init__( self, dim, - dim_heads = 64, - dim_context = None, - causal = False, + dim_heads=64, + dim_context=None, + causal=False, 
zero_init_output=True, - qk_norm: Literal['l2', 'ln', 'none'] = 'none', - natten_kernel_size = None + qk_norm: Literal["l2", "ln", "none"] = "none", + natten_kernel_size=None, ): super().__init__() self.dim = dim @@ -285,7 +305,7 @@ def __init__( self.causal = causal dim_kv = dim_context if dim_context is not None else dim - + self.num_heads = dim // dim_heads self.kv_heads = dim_kv // dim_heads @@ -311,24 +331,13 @@ def __init__( if natten_kernel_size is not None: return - self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse("2.0.0") self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None - self.sdp_kwargs = dict( - enable_flash = True, - enable_math = True, - enable_mem_efficient = True - ) + self.sdp_kwargs = dict(enable_flash=True, enable_math=True, enable_mem_efficient=True) - def flash_attn( - self, - q, - k, - v, - mask = None, - causal = None - ): + def flash_attn(self, q, k, v, mask=None, causal=None): batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device kv_heads = k.shape[1] # Recommended for multi-query single-key-value attention by Tri Dao @@ -337,19 +346,19 @@ def flash_attn( if heads != kv_heads: # Repeat interleave kv_heads to match q_heads heads_per_kv_head = heads // kv_heads - k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) if k.ndim == 3: - k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) if v.ndim == 3: - v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + v = rearrange(v, "b ... -> b 1 ...").expand_as(q) causal = self.causal if causal is None else causal if q_len == 1 and causal: causal = False - + if mask is not None: assert mask.ndim == 4 mask = mask.expand(batch, heads, q_len, k_len) @@ -357,7 +366,7 @@ def flash_attn( # handle kv cache - this should be bypassable in updated flash attention 2 if k_len > q_len and causal: - causal_mask = self.create_causal_mask(q_len, k_len, device = device) + causal_mask = self.create_causal_mask(q_len, k_len, device=device) if mask is None: mask = ~causal_mask else: @@ -369,56 +378,44 @@ def flash_attn( row_is_entirely_masked = None if mask is not None and causal: - causal_mask = self.create_causal_mask(q_len, k_len, device = device) + causal_mask = self.create_causal_mask(q_len, k_len, device=device) mask = mask & ~causal_mask # protect against an entire row being masked out - row_is_entirely_masked = ~mask.any(dim = -1) + row_is_entirely_masked = ~mask.any(dim=-1) mask[..., 0] = mask[..., 0] | row_is_entirely_masked causal = False - + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): - out = F.scaled_dot_product_attention( - q, k, v, - attn_mask = mask, - is_causal = causal - ) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=causal) # for a row that is entirely masked out, should zero out the output of that row token if row_is_entirely_masked is not None: - out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.0) return out - def forward( - self, - x, - context = None, - mask = None, - context_mask = None, - rotary_pos_emb = None, - causal = None - ): + def forward(self, x, context=None, mask=None, context_mask=None, rotary_pos_emb=None, causal=None): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None kv_input = context if has_context else x - if hasattr(self, 'to_q'): + if hasattr(self, "to_q"): # Use separate linear projections for q and k/v q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = h) + q = rearrange(q, "b n (h d) -> b h n d", h=h) k, v = self.to_kv(kv_input).chunk(2, dim=-1) - k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=kv_h), (k, v)) else: # Use fused linear projection q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) - + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + # Normalize q and k for cosine sim attention if self.qk_norm == "l2": q = F.normalize(q, dim=-1) @@ -442,18 +439,18 @@ def forward( q = q.to(q_dtype) k = k.to(k_dtype) - - input_mask = context_mask + + input_mask = context_mask if input_mask is None and not has_context: input_mask = mask # determine masking masks = [] - final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account if input_mask is not None: - input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + input_mask = rearrange(input_mask, "b j -> b 1 1 j") masks.append(~input_mask) # Other masks will be added here later @@ -470,34 +467,34 @@ def forward( if self.natten_kernel_size is not None: if natten is None: - raise ImportError('natten not installed, please install natten to use neighborhood attention') - + raise ImportError("natten not installed, please install natten to use neighborhood attention") + dtype_in = q.dtype q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) - attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1) + attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1) if final_attn_mask is not None: attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) attn = F.softmax(attn, dim=-1, dtype=torch.float32) - out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) + out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in) # Prioritize Flash Attention 2 elif self.use_fa_flash: - assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + assert final_attn_mask is None, "masking not yet supported for Flash Attention 2" # Flash Attention 2 requires FP16 inputs fa_dtype_in = q.dtype - q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) - - out = flash_attn_func(q, k, v, causal = causal) - - out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + q, k, v = map(lambda t: rearrange(t, "b h n d -> b n h d").to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal=causal) + + out = rearrange(out.to(fa_dtype_in), "b n h d -> b h n d") # Fall back to PyTorch implementation elif self.use_pt_flash: - out = self.flash_attn(q, k, v, 
causal = causal, mask = final_attn_mask) + out = self.flash_attn(q, k, v, causal=causal, mask=final_attn_mask) else: # Fall back to custom implementation @@ -505,14 +502,14 @@ def forward( if h != kv_h: # Repeat interleave kv_heads to match q_heads heads_per_kv_head = h // kv_h - k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) + + scale = 1.0 / (q.shape[-1] ** 0.5) - scale = 1. / (q.shape[-1] ** 0.5) + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" - kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + dots = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale - dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale - i, j, dtype = *dots.shape[-2:], dots.dtype mask_value = -torch.finfo(dots.dtype).max @@ -521,19 +518,19 @@ def forward( dots = dots.masked_fill(~final_attn_mask, mask_value) if causal: - causal_mask = self.create_causal_mask(i, j, device = device) + causal_mask = self.create_causal_mask(i, j, device=device) dots = dots.masked_fill(causal_mask, mask_value) attn = F.softmax(dots, dim=-1, dtype=torch.float32) attn = attn.type(dtype) - out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) # merge heads - out = rearrange(out, ' b h n d -> b n (h d)') + out = rearrange(out, " b h n d -> b n (h d)") # Communicate between heads - + # with autocast(enabled = False): # out_dtype = out.dtype # out = out.to(torch.float32) @@ -541,65 +538,69 @@ def forward( out = self.to_out(out) if mask is not None: - mask = rearrange(mask, 'b n -> b n 1') - out = out.masked_fill(~mask, 0.) + mask = rearrange(mask, "b n -> b n 1") + out = out.masked_fill(~mask, 0.0) return out + class ConformerModule(nn.Module): def __init__( self, dim, - norm_kwargs = {}, - ): + norm_kwargs={}, + ): super().__init__() self.dim = dim - + self.in_norm = LayerNorm(dim, **norm_kwargs) self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) self.glu = GLU(dim, dim, nn.SiLU()) self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) - self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.mid_norm = LayerNorm( + dim, **norm_kwargs + ) # This is a batch norm in the original but I don't like batch norm self.swish = nn.SiLU() self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) def forward(self, x): x = self.in_norm(x) - x = rearrange(x, 'b n d -> b d n') + x = rearrange(x, "b n d -> b d n") x = self.pointwise_conv(x) - x = rearrange(x, 'b d n -> b n d') + x = rearrange(x, "b d n -> b n d") x = self.glu(x) - x = rearrange(x, 'b n d -> b d n') + x = rearrange(x, "b n d -> b d n") x = self.depthwise_conv(x) - x = rearrange(x, 'b d n -> b n d') + x = rearrange(x, "b d n -> b n d") x = self.mid_norm(x) x = self.swish(x) - x = rearrange(x, 'b n d -> b d n') + x = rearrange(x, "b n d -> b d n") x = self.pointwise_conv_2(x) - x = rearrange(x, 'b d n -> b n d') + x = rearrange(x, "b d n -> b n d") return x + class TransformerBlock(nn.Module): def __init__( - self, - dim, - dim_heads = 64, - cross_attend = False, - dim_context = None, - global_cond_dim = None, - causal = False, - zero_init_branch_outputs = True, - conformer = False, - layer_ix = -1, - remove_norms = False, - attn_kwargs = {}, - ff_kwargs = {}, - norm_kwargs = {} + self, + dim, + dim_heads=64, + cross_attend=False, + dim_context=None, + 
global_cond_dim=None, + causal=False, + zero_init_branch_outputs=True, + conformer=False, + layer_ix=-1, + remove_norms=False, + attn_kwargs={}, + ff_kwargs={}, + norm_kwargs={}, ): - + super().__init__() self.dim = dim self.dim_heads = dim_heads @@ -610,24 +611,20 @@ def __init__( self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() self.self_attn = Attention( - dim, - dim_heads = dim_heads, - causal = causal, - zero_init_output=zero_init_branch_outputs, - **attn_kwargs + dim, dim_heads=dim_heads, causal=causal, zero_init_output=zero_init_branch_outputs, **attn_kwargs ) if cross_attend: self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() self.cross_attn = Attention( dim, - dim_heads = dim_heads, + dim_heads=dim_heads, dim_context=dim_context, - causal = causal, + causal=causal, zero_init_output=zero_init_branch_outputs, - **attn_kwargs + **attn_kwargs, ) - + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) @@ -638,37 +635,28 @@ def __init__( self.global_cond_dim = global_cond_dim if global_cond_dim is not None: - self.to_scale_shift_gate = nn.Sequential( - nn.SiLU(), - nn.Linear(global_cond_dim, dim * 6, bias=False) - ) + self.to_scale_shift_gate = nn.Sequential(nn.SiLU(), nn.Linear(global_cond_dim, dim * 6, bias=False)) nn.init.zeros_(self.to_scale_shift_gate[1].weight) - #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + # nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) - def forward( - self, - x, - context = None, - global_cond=None, - mask = None, - context_mask = None, - rotary_pos_emb = None - ): + def forward(self, x, context=None, global_cond=None, mask=None, context_mask=None, rotary_pos_emb=None): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: - - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = ( + self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim=-1) + ) # self-attention with adaLN residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) if self.conformer is not None: x = x + self.conformer(x) @@ -682,10 +670,10 @@ def forward( x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) if self.conformer is not None: x = x + self.conformer(x) @@ -693,16 +681,17 @@ def forward( x = x + self.ff(self.ff_norm(x)) return x - + + class ContinuousTransformer(nn.Module): def __init__( self, dim, depth, *, - dim_in = None, - dim_out = None, - dim_heads = 64, + dim_in=None, + dim_out=None, + 
dim_heads=64, cross_attend=False, cond_token_dim=None, global_cond_dim=None, @@ -713,8 +702,8 @@ def __init__( use_sinusoidal_emb=False, use_abs_pos_emb=False, abs_pos_emb_max_length=10000, - **kwargs - ): + **kwargs, + ): super().__init__() @@ -743,27 +732,20 @@ def __init__( self.layers.append( TransformerBlock( dim, - dim_heads = dim_heads, - cross_attend = cross_attend, - dim_context = cond_token_dim, - global_cond_dim = global_cond_dim, - causal = causal, - zero_init_branch_outputs = zero_init_branch_outputs, + dim_heads=dim_heads, + cross_attend=cross_attend, + dim_context=cond_token_dim, + global_cond_dim=global_cond_dim, + causal=causal, + zero_init_branch_outputs=zero_init_branch_outputs, conformer=conformer, layer_ix=i, - **kwargs + **kwargs, ) ) - + def forward( - self, - x, - mask = None, - prepend_embeds = None, - prepend_mask = None, - global_cond = None, - return_info = False, - **kwargs + self, x, mask=None, prepend_embeds=None, prepend_mask=None, global_cond=None, return_info=False, **kwargs ): batch, seq, device = *x.shape[:2], x.device @@ -776,17 +758,21 @@ def forward( if prepend_embeds is not None: prepend_length, prepend_dim = prepend_embeds.shape[1:] - assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + assert prepend_dim == x.shape[-1], "prepend dimension must match sequence dimension" - x = torch.cat((prepend_embeds, x), dim = -2) + x = torch.cat((prepend_embeds, x), dim=-2) if prepend_mask is not None or mask is not None: - mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool) - prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool) + prepend_mask = ( + prepend_mask + if prepend_mask is not None + else torch.ones((batch, prepend_length), device=device, dtype=torch.bool) + ) - mask = torch.cat((prepend_mask, mask), dim = -1) + mask = torch.cat((prepend_mask, mask), dim=-1) - # Attention layers + # Attention layers if self.rotary_pos_emb is not None: rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) @@ -798,8 +784,8 @@ def forward( # Iterate over the transformer layers for layer in self.layers: - #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) - x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + # x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + x = checkpoint(layer, x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: info["hidden_states"].append(x) @@ -808,5 +794,5 @@ def forward( if return_info: return x, info - + return x diff --git a/stable_audio_tools/models/utils.py b/stable_audio_tools/models/utils.py index c7f48b7c..434513c6 100644 --- a/stable_audio_tools/models/utils.py +++ b/stable_audio_tools/models/utils.py @@ -1,16 +1,17 @@ import torch from safetensors.torch import load_file - from torch.nn.utils import remove_weight_norm + def load_ckpt_state_dict(ckpt_path): if ckpt_path.endswith(".safetensors"): state_dict = load_file(ckpt_path) else: state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] - + return state_dict + def remove_weight_norm_from_model(model): for module in model.modules(): if hasattr(module, "weight"): @@ -19,9 +20,11 @@ def remove_weight_norm_from_model(model): return model + # Sampling functions copied from 
https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license # License can be found in LICENSES/LICENSE_META.txt + def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. @@ -82,8 +85,10 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: next_token = torch.gather(probs_idx, -1, next_token) return next_token + def next_power_of_two(n): return 2 ** (n - 1).bit_length() + def next_multiple_of_64(n): - return ((n + 63) // 64) * 64 \ No newline at end of file + return ((n + 63) // 64) * 64 diff --git a/stable_audio_tools/models/wavelets.py b/stable_audio_tools/models/wavelets.py index a359e391..76bee1b2 100644 --- a/stable_audio_tools/models/wavelets.py +++ b/stable_audio_tools/models/wavelets.py @@ -1,11 +1,12 @@ """The 1D discrete wavelet transform for PyTorch.""" -from einops import rearrange +from typing import Literal + import pywt import torch +from einops import rearrange from torch import nn from torch.nn import functional as F -from typing import Literal def get_filter_bank(wavelet): @@ -14,11 +15,14 @@ def get_filter_bank(wavelet): filt = filt[:, 1:] return filt + class WaveletEncode1d(nn.Module): - def __init__(self, - channels, - levels, - wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + def __init__( + self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4", + ): super().__init__() self.wavelet = wavelet self.channels = channels @@ -39,18 +43,18 @@ def forward(self, x): pad = self.kernel.shape[-1] // 2 low = F.pad(low, (pad, pad), "reflect") low = F.conv1d(low, self.kernel, stride=2) - rest = rearrange( - rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels - ) + rest = rearrange(rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels) x = torch.cat([low, rest], dim=1) return x class WaveletDecode1d(nn.Module): - def __init__(self, - channels, - levels, - wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + def __init__( + self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4", + ): super().__init__() self.wavelet = wavelet self.channels = channels @@ -71,12 +75,8 @@ def forward(self, x): low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) low = F.pad(low, (pad, pad), "reflect") low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) - low = F.conv_transpose1d( - low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 - ) + low = F.conv_transpose1d(low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2) low = low[..., pad - 1 : -pad] - rest = rearrange( - rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels - ) + rest = rearrange(rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels) x = torch.cat([low, rest], dim=1) - return x \ No newline at end of file + return x diff --git a/stable_audio_tools/training/__init__.py b/stable_audio_tools/training/__init__.py index f77486b0..23884758 100644 --- a/stable_audio_tools/training/__init__.py +++ b/stable_audio_tools/training/__init__.py @@ -1 +1,4 @@ -from .factory import create_training_wrapper_from_config, create_demo_callback_from_config +from .factory import ( + create_demo_callback_from_config, + create_training_wrapper_from_config, +) diff 
--git a/stable_audio_tools/training/autoencoders.py b/stable_audio_tools/training/autoencoders.py index c592010e..92e29d23 100644 --- a/stable_audio_tools/training/autoencoders.py +++ b/stable_audio_tools/training/autoencoders.py @@ -1,36 +1,47 @@ +import pytorch_lightning as pl import torch import torchaudio import wandb +from aeiou.viz import audio_spectrogram_image, pca_point_cloud, tokens_spectrogram_image from einops import rearrange -from safetensors.torch import save_file, save_model from ema_pytorch import EMA -from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss -import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from safetensors.torch import save_file, save_model + from ..models.autoencoders import AudioAutoencoder -from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss -from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck -from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss +from ..models.bottleneck import ( + DACRVQBottleneck, + DACRVQVAEBottleneck, + RVQBottleneck, + RVQVAEBottleneck, + VAEBottleneck, + WassersteinBottleneck, +) +from ..models.discriminators import ( + DACGANLoss, + EncodecDiscriminator, + OobleckDiscriminator, +) +from .losses import AuralossLoss, L1Loss, MultiLoss, ValueLoss +from .losses.auraloss import MultiResolutionSTFTLoss, SumAndDifferenceSTFTLoss from .utils import create_optimizer_from_config, create_scheduler_from_config -from pytorch_lightning.utilities.rank_zero import rank_zero_only -from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image - class AutoencoderTrainingWrapper(pl.LightningModule): def __init__( - self, - autoencoder: AudioAutoencoder, - lr: float = 1e-4, - warmup_steps: int = 0, - encoder_freeze_on_warmup: bool = False, - sample_rate=48000, - loss_config: dict = None, - optimizer_configs: dict = None, - use_ema: bool = True, - ema_copy = None, - force_input_mono = False, - latent_mask_ratio = 0.0, - teacher_model: AudioAutoencoder = None + self, + autoencoder: AudioAutoencoder, + lr: float = 1e-4, + warmup_steps: int = 0, + encoder_freeze_on_warmup: bool = False, + sample_rate=48000, + loss_config: dict = None, + optimizer_configs: dict = None, + use_ema: bool = True, + ema_copy=None, + force_input_mono=False, + latent_mask_ratio=0.0, + teacher_model: AudioAutoencoder = None, ): super().__init__() @@ -48,28 +59,11 @@ def __init__( self.teacher_model = teacher_model if optimizer_configs is None: - optimizer_configs ={ - "autoencoder": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": lr, - "betas": (.8, .99) - } - } - }, - "discriminator": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": lr, - "betas": (.8, .99) - } - } - } - - } - + optimizer_configs = { + "autoencoder": {"optimizer": {"type": "AdamW", "config": {"lr": lr, "betas": (0.8, 0.99)}}}, + "discriminator": {"optimizer": {"type": "AdamW", "config": {"lr": lr, "betas": (0.8, 0.99)}}}, + } + self.optimizer_configs = optimizer_configs if loss_config is None: @@ -80,20 +74,15 @@ def __init__( for s in scales: hop_sizes.append(int(s * (1 - overlap))) win_lengths.append(s) - + loss_config = { "discriminator": { "type": "encodec", - "config": { - "n_ffts": scales, - "hop_lengths": hop_sizes, - "win_lengths": win_lengths, - "filters": 32 - }, + "config": {"n_ffts": scales, "hop_lengths": hop_sizes, "win_lengths": win_lengths, 
"filters": 32}, "weights": { "adversarial": 0.1, "feature_matching": 5.0, - } + }, }, "spectral": { "type": "mrstft", @@ -101,26 +90,26 @@ def __init__( "fft_sizes": scales, "hop_sizes": hop_sizes, "win_lengths": win_lengths, - "perceptual_weighting": True + "perceptual_weighting": True, }, "weights": { "mrstft": 1.0, - } + }, }, "time": { "type": "l1", "config": {}, "weights": { "l1": 0.0, - } - } + }, + }, } - + self.loss_config = loss_config - + # Spectral reconstruction loss - stft_loss_args = loss_config['spectral']['config'] + stft_loss_args = loss_config["spectral"]["config"] if self.autoencoder.out_channels == 2: self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) @@ -130,53 +119,112 @@ def __init__( # Discriminator - if loss_config['discriminator']['type'] == 'oobleck': - self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) - elif loss_config['discriminator']['type'] == 'encodec': - self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) - elif loss_config['discriminator']['type'] == 'dac': - self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + if loss_config["discriminator"]["type"] == "oobleck": + self.discriminator = OobleckDiscriminator(**loss_config["discriminator"]["config"]) + elif loss_config["discriminator"]["type"] == "encodec": + self.discriminator = EncodecDiscriminator( + in_channels=self.autoencoder.out_channels, **loss_config["discriminator"]["config"] + ) + elif loss_config["discriminator"]["type"] == "dac": + self.discriminator = DACGANLoss( + channels=self.autoencoder.out_channels, + sample_rate=sample_rate, + **loss_config["discriminator"]["config"], + ) self.gen_loss_modules = [] # Adversarial and feature matching losses self.gen_loss_modules += [ - ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), - ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), + ValueLoss( + key="loss_adv", weight=self.loss_config["discriminator"]["weights"]["adversarial"], name="loss_adv" + ), + ValueLoss( + key="feature_matching_distance", + weight=self.loss_config["discriminator"]["weights"]["feature_matching"], + name="feature_matching", + ), ] if self.teacher_model is not None: # Distillation losses - stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + stft_loss_weight = self.loss_config["spectral"]["weights"]["mrstft"] * 0.25 self.gen_loss_modules += [ - AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss - AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder - AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder - AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder + AuralossLoss( + self.sdstft, "reals", "decoded", name="mrstft_loss", weight=stft_loss_weight + ), # Reconstruction loss + AuralossLoss( + self.sdstft, "decoded", 
"teacher_decoded", name="mrstft_loss_distill", weight=stft_loss_weight + ), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss( + self.sdstft, + "reals", + "own_latents_teacher_decoded", + name="mrstft_loss_own_latents_teacher", + weight=stft_loss_weight, + ), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss( + self.sdstft, + "reals", + "teacher_latents_own_decoded", + name="mrstft_loss_teacher_latents_own", + weight=stft_loss_weight, + ), # Teacher's encoder is compatible with distilled model's decoder ] else: # Reconstruction loss self.gen_loss_modules += [ - AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + AuralossLoss( + self.sdstft, + "reals", + "decoded", + name="mrstft_loss", + weight=self.loss_config["spectral"]["weights"]["mrstft"], + ), ] if self.autoencoder.out_channels == 2: # Add left and right channel reconstruction losses in addition to the sum and difference self.gen_loss_modules += [ - AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), - AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), + AuralossLoss( + self.lrstft, + "reals_left", + "decoded_left", + name="stft_loss_left", + weight=self.loss_config["spectral"]["weights"]["mrstft"] / 2, + ), + AuralossLoss( + self.lrstft, + "reals_right", + "decoded_right", + name="stft_loss_right", + weight=self.loss_config["spectral"]["weights"]["mrstft"] / 2, + ), ] self.gen_loss_modules += [ - AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + AuralossLoss( + self.sdstft, + "reals", + "decoded", + name="mrstft_loss", + weight=self.loss_config["spectral"]["weights"]["mrstft"], + ), ] - if self.loss_config['time']['weights']['l1'] > 0.0: - self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) + if self.loss_config["time"]["weights"]["l1"] > 0.0: + self.gen_loss_modules.append( + L1Loss( + key_a="reals", + key_b="decoded", + weight=self.loss_config["time"]["weights"]["l1"], + name="l1_time_loss", + ) + ) if self.autoencoder.bottleneck is not None: self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) @@ -184,40 +232,42 @@ def __init__( self.losses_gen = MultiLoss(self.gen_loss_modules) self.disc_loss_modules = [ - ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ValueLoss(key="loss_dis", weight=1.0, name="discriminator_loss"), ] self.losses_disc = MultiLoss(self.disc_loss_modules) # Set up EMA for model weights self.autoencoder_ema = None - + self.use_ema = use_ema if self.use_ema: self.autoencoder_ema = EMA( - self.autoencoder, - ema_model=ema_copy, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1 + self.autoencoder, ema_model=ema_copy, beta=0.9999, power=3 / 4, update_every=1, update_after_step=1 ) self.latent_mask_ratio = latent_mask_ratio def configure_optimizers(self): - opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) - opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) - - if "scheduler" in 
self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: - sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) - sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + opt_gen = create_optimizer_from_config( + self.optimizer_configs["autoencoder"]["optimizer"], self.autoencoder.parameters() + ) + opt_disc = create_optimizer_from_config( + self.optimizer_configs["discriminator"]["optimizer"], self.discriminator.parameters() + ) + + if ( + "scheduler" in self.optimizer_configs["autoencoder"] + and "scheduler" in self.optimizer_configs["discriminator"] + ): + sched_gen = create_scheduler_from_config(self.optimizer_configs["autoencoder"]["scheduler"], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs["discriminator"]["scheduler"], opt_disc) return [opt_gen, opt_disc], [sched_gen, sched_disc] return [opt_gen, opt_disc] - + def training_step(self, batch, batch_idx): reals, _ = batch @@ -255,7 +305,7 @@ def training_step(self, batch, batch_idx): if self.teacher_model is not None: with torch.no_grad(): teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) - loss_info['teacher_latents'] = teacher_latents + loss_info["teacher_latents"] = teacher_latents # Optionally mask out some latents for noise resistance if self.latent_mask_ratio > 0.0: @@ -276,20 +326,23 @@ def training_step(self, batch, batch_idx): if self.teacher_model is not None: with torch.no_grad(): teacher_decoded = self.teacher_model.decode(teacher_latents) - own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher - teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + own_latents_teacher_decoded = self.teacher_model.decode( + latents + ) # Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode( + teacher_latents + ) # Teacher's latents decoded by distilled model - loss_info['teacher_decoded'] = teacher_decoded - loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded - loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + loss_info["teacher_decoded"] = teacher_decoded + loss_info["own_latents_teacher_decoded"] = own_latents_teacher_decoded + loss_info["teacher_latents_own_decoded"] = teacher_latents_own_decoded - if self.warmed_up: loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) else: - loss_dis = torch.tensor(0.).to(reals) - loss_adv = torch.tensor(0.).to(reals) - feature_matching_distance = torch.tensor(0.).to(reals) + loss_dis = torch.tensor(0.0).to(reals) + loss_adv = torch.tensor(0.0).to(reals) + feature_matching_distance = torch.tensor(0.0).to(reals) loss_info["loss_dis"] = loss_dis loss_info["loss_adv"] = loss_adv @@ -309,9 +362,7 @@ def training_step(self, batch, batch_idx): if self.global_step % 2 and self.warmed_up: loss, losses = self.losses_disc(loss_info) - log_dict = { - 'train/disc_lr': opt_disc.param_groups[0]['lr'] - } + log_dict = {"train/disc_lr": opt_disc.param_groups[0]["lr"]} opt_disc.zero_grad() self.manual_backward(loss) @@ -321,7 +372,7 @@ def training_step(self, batch, batch_idx): # sched step every step sched_disc.step() - # Train the generator + # Train the generator else: loss, losses = self.losses_gen(loss_info) @@ -338,39 +389,33 @@ def training_step(self, batch, batch_idx): 
sched_gen.step() log_dict = { - 'train/loss': loss.detach(), - 'train/latent_std': latents.std().detach(), - 'train/data_std': data_std.detach(), - 'train/gen_lr': opt_gen.param_groups[0]['lr'] + "train/loss": loss.detach(), + "train/latent_std": latents.std().detach(), + "train/data_std": data_std.detach(), + "train/gen_lr": opt_gen.param_groups[0]["lr"], } for loss_name, loss_value in losses.items(): - log_dict[f'train/{loss_name}'] = loss_value.detach() + log_dict[f"train/{loss_name}"] = loss_value.detach() self.log_dict(log_dict, prog_bar=True, on_step=True) return loss - + def export_model(self, path, use_safetensors=False): if self.autoencoder_ema is not None: model = self.autoencoder_ema.ema_model else: model = self.autoencoder - + if use_safetensors: save_model(model, path) else: torch.save({"state_dict": model.state_dict()}, path) - + class AutoencoderDemoCallback(pl.Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - sample_size=65536, - sample_rate=48000 - ): + def __init__(self, demo_dl, demo_every=2000, sample_size=65536, sample_rate=48000): super().__init__() self.demo_every = demo_every self.demo_samples = sample_size @@ -380,10 +425,10 @@ def __init__( @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return - + self.last_demo_step = trainer.global_step module.eval() @@ -396,7 +441,7 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): demo_reals = demo_reals[0] encoder_input = demo_reals - + encoder_input = encoder_input.to(module.device) if module.force_input_mono: @@ -415,63 +460,66 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): fakes = module.autoencoder.decode(latents) - #Interleave reals and fakes - reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + # Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], "i b d n -> (b i) d n") # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + reals_fakes = rearrange(reals_fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'recon_{trainer.global_step:08}.wav' + + filename = f"recon_{trainer.global_step:08}.wav" reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, reals_fakes, self.sample_rate) - log_dict[f'recon'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) - log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + log_dict[f"recon"] = wandb.Audio(filename, sample_rate=self.sample_rate, caption=f"Reconstructed") - log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + log_dict[f"embeddings_3dpca"] = pca_point_cloud(latents) + log_dict[f"embeddings_spec"] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f"recon_melspec_left"] = wandb.Image(audio_spectrogram_image(reals_fakes)) trainer.logger.experiment.log(log_dict) except Exception as e: - print(f'{type(e).__name__}: {e}') + print(f"{type(e).__name__}: {e}") raise e finally: module.train() + def create_loss_modules_from_bottleneck(bottleneck, loss_config): losses = [] - - if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or 
isinstance(bottleneck, RVQVAEBottleneck): + + if ( + isinstance(bottleneck, VAEBottleneck) + or isinstance(bottleneck, DACRVQVAEBottleneck) + or isinstance(bottleneck, RVQVAEBottleneck) + ): try: - kl_weight = loss_config['bottleneck']['weights']['kl'] + kl_weight = loss_config["bottleneck"]["weights"]["kl"] except: kl_weight = 1e-6 - kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + kl_loss = ValueLoss(key="kl", weight=kl_weight, name="kl_loss") losses.append(kl_loss) if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): - quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + quantizer_loss = ValueLoss(key="quantizer_loss", weight=1.0, name="quantizer_loss") losses.append(quantizer_loss) if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): - codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') - commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + codebook_loss = ValueLoss(key="vq/codebook_loss", weight=1.0, name="codebook_loss") + commitment_loss = ValueLoss(key="vq/commitment_loss", weight=0.25, name="commitment_loss") losses.append(codebook_loss) losses.append(commitment_loss) if isinstance(bottleneck, WassersteinBottleneck): try: - mmd_weight = loss_config['bottleneck']['weights']['mmd'] + mmd_weight = loss_config["bottleneck"]["weights"]["mmd"] except: mmd_weight = 100 - mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + mmd_loss = ValueLoss(key="mmd", weight=mmd_weight, name="mmd_loss") losses.append(mmd_loss) - - return losses \ No newline at end of file + + return losses diff --git a/stable_audio_tools/training/diffusion.py b/stable_audio_tools/training/diffusion.py index 91058ade..b01f196f 100644 --- a/stable_audio_tools/training/diffusion.py +++ b/stable_audio_tools/training/diffusion.py @@ -1,29 +1,30 @@ -import pytorch_lightning as pl -import sys, gc +import gc import random -import torch -import torchaudio +import sys import typing as tp -import wandb +from time import time -from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image import auraloss -from ema_pytorch import EMA +import pytorch_lightning as pl +import torch +import torchaudio +import wandb +from aeiou.viz import audio_spectrogram_image, pca_point_cloud, tokens_spectrogram_image from einops import rearrange +from ema_pytorch import EMA +from pytorch_lightning.utilities.rank_zero import rank_zero_only from safetensors.torch import save_file from torch import optim from torch.nn import functional as F -from pytorch_lightning.utilities.rank_zero import rank_zero_only from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler -from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper from ..models.autoencoders import DiffusionAutoencoder +from ..models.diffusion import ConditionedDiffusionModelWrapper, DiffusionModelWrapper from ..models.diffusion_prior import PriorType from .autoencoders import create_loss_modules_from_bottleneck from .losses import AuralossLoss, MSELoss, MultiLoss from .utils import create_optimizer_from_config, create_scheduler_from_config -from time import time class Profiler: @@ -42,39 +43,24 @@ def __repr__(self): rep += 80 * "=" + "\n\n\n" return rep + class DiffusionUncondTrainingWrapper(pl.LightningModule): - ''' + """ Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). 
- ''' - def __init__( - self, - model: DiffusionModelWrapper, - lr: float = 1e-4, - pre_encoded: bool = False - ): + """ + + def __init__(self, model: DiffusionModelWrapper, lr: float = 1e-4, pre_encoded: bool = False): super().__init__() self.diffusion = model - - self.diffusion_ema = EMA( - self.diffusion.model, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1 - ) + + self.diffusion_ema = EMA(self.diffusion.model, beta=0.9999, power=3 / 4, update_every=1, update_after_step=1) self.lr = lr self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] + loss_modules = [MSELoss("v", "targets", weight=1.0, name="mse_loss")] self.losses = MultiLoss(loss_modules) @@ -88,7 +74,7 @@ def training_step(self, batch, batch_idx): if reals.ndim == 4 and reals.shape[0] == 1: reals = reals[0] - + diffusion_input = reals loss_info = {} @@ -100,7 +86,7 @@ def training_step(self, batch, batch_idx): if not self.pre_encoded: with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): diffusion_input = self.diffusion.pretransform.encode(diffusion_input) - else: + else: # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: diffusion_input = diffusion_input / self.diffusion.pretransform.scale @@ -123,16 +109,13 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): v = self.diffusion(noised_inputs, t) - loss_info.update({ - "v": v, - "targets": targets - }) + loss_info.update({"v": v, "targets": targets}) loss, losses = self.losses(loss_info) log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': diffusion_input.std(), + "train/loss": loss.detach(), + "train/std_data": diffusion_input.std(), } for loss_name, loss_value in losses.items(): @@ -140,26 +123,22 @@ def training_step(self, batch, batch_idx): self.log_dict(log_dict, prog_bar=True, on_step=True) return loss - + def on_before_zero_grad(self, *args, **kwargs): self.diffusion_ema.update() def export_model(self, path, use_safetensors=False): self.diffusion.model = self.diffusion_ema.ema_model - + if use_safetensors: save_file(self.diffusion.state_dict(), path) else: torch.save({"state_dict": self.diffusion.state_dict()}, path) + class DiffusionUncondDemoCallback(pl.Callback): - def __init__(self, - demo_every=2000, - num_demos=8, - demo_steps=250, - sample_rate=48000 - ): + def __init__(self, demo_every=2000, num_demos=8, demo_steps=250, sample_rate=48000): super().__init__() self.demo_every = demo_every @@ -167,14 +146,14 @@ def __init__(self, self.demo_steps = demo_steps self.sample_rate = sample_rate self.last_demo_step = -1 - + @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return - + self.last_demo_step = trainer.global_step demo_samples = module.diffusion.sample_size @@ -192,46 +171,46 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): fakes = module.diffusion.pretransform.decode(fakes) # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') + fakes = rearrange(fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'demo_{trainer.global_step:08}.wav' + + filename = 
f"demo_{trainer.global_step:08}.wav" fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, fakes, self.sample_rate) - log_dict[f'demo'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + log_dict[f"demo"] = wandb.Audio(filename, sample_rate=self.sample_rate, caption=f"Reconstructed") + + log_dict[f"demo_melspec_left"] = wandb.Image(audio_spectrogram_image(fakes)) trainer.logger.experiment.log(log_dict) del fakes - + except Exception as e: - print(f'{type(e).__name__}: {e}') + print(f"{type(e).__name__}: {e}") finally: gc.collect() torch.cuda.empty_cache() + class DiffusionCondTrainingWrapper(pl.LightningModule): - ''' + """ Wrapper for training a conditional audio diffusion model. - ''' + """ + def __init__( - self, - model: ConditionedDiffusionModelWrapper, - lr: float = None, - mask_padding: bool = False, - mask_padding_dropout: float = 0.0, - use_ema: bool = True, - log_loss_info: bool = False, - optimizer_configs: dict = None, - pre_encoded: bool = False, - cfg_dropout_prob = 0.1, - timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + self, + model: ConditionedDiffusionModelWrapper, + lr: float = None, + mask_padding: bool = False, + mask_padding_dropout: float = 0.0, + use_ema: bool = True, + log_loss_info: bool = False, + optimizer_configs: dict = None, + pre_encoded: bool = False, + cfg_dropout_prob=0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", ): super().__init__() @@ -241,10 +220,10 @@ def __init__( self.diffusion_ema = EMA( self.diffusion.model, beta=0.9999, - power=3/4, + power=3 / 4, update_every=1, update_after_step=1, - include_online_model=False + include_online_model=False, ) else: self.diffusion_ema = None @@ -261,11 +240,8 @@ def __init__( self.diffusion_objective = model.diffusion_objective self.loss_modules = [ - MSELoss("output", - "targets", - weight=1.0, - mask_key="padding_mask" if self.mask_padding else None, - name="mse_loss" + MSELoss( + "output", "targets", weight=1.0, mask_key="padding_mask" if self.mask_padding else None, name="mse_loss" ) ] @@ -273,37 +249,29 @@ def __init__( self.log_loss_info = log_loss_info - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + assert ( + lr is not None or optimizer_configs is not None + ), "Must specify either lr or optimizer_configs in training config" if optimizer_configs is None: - optimizer_configs = { - "diffusion": { - "optimizer": { - "type": "Adam", - "config": { - "lr": lr - } - } - } - } + optimizer_configs = {"diffusion": {"optimizer": {"type": "Adam", "config": {"lr": lr}}}} else: if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + print( + f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs." 
+ ) self.optimizer_configs = optimizer_configs self.pre_encoded = pre_encoded def configure_optimizers(self): - diffusion_opt_config = self.optimizer_configs['diffusion'] - opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + diffusion_opt_config = self.optimizer_configs["diffusion"] + opt_diff = create_optimizer_from_config(diffusion_opt_config["optimizer"], self.diffusion.parameters()) if "scheduler" in diffusion_opt_config: - sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) - sched_diff_config = { - "scheduler": sched_diff, - "interval": "step" - } + sched_diff = create_scheduler_from_config(diffusion_opt_config["scheduler"], opt_diff) + sched_diff_config = {"scheduler": sched_diff, "interval": "step"} return [opt_diff], [sched_diff_config] return [opt_diff] @@ -327,13 +295,15 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): conditioning = self.diffusion.conditioner(metadata, self.device) - + # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout # Create batch tensor of attention masks from the "mask" field of the metadata array if use_padding_mask: - padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to( + self.device + ) # Shape (batch_size, sequence_length) p.tick("conditioning") @@ -343,14 +313,20 @@ def training_step(self, batch, batch_idx): if not self.pre_encoded: with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) - + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) p.tick("pretransform") # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input if use_padding_mask: - padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() - else: + padding_masks = ( + F.interpolate( + padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest" + ) + .squeeze(1) + .bool() + ) + else: # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: diffusion_input = diffusion_input / self.diffusion.pretransform.scale @@ -360,12 +336,12 @@ def training_step(self, batch, batch_idx): t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) elif self.timestep_sampler == "logit_normal": t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) - + # Calculate the noise schedule parameters for those timesteps if self.diffusion_objective == "v": alphas, sigmas = get_alphas_sigmas(t) elif self.diffusion_objective == "rectified_flow": - alphas, sigmas = 1-t, t + alphas, sigmas = 1 - t, t # Combine the ground truth data and the noise alphas = alphas[:, None, None] @@ -387,14 +363,18 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): p.tick("amp") - output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + output = self.diffusion( + noised_inputs, t, cond=conditioning, cfg_dropout_prob=self.cfg_dropout_prob, 
**extra_args + ) p.tick("diffusion") - loss_info.update({ - "output": output, - "targets": targets, - "padding_mask": padding_masks if use_padding_mask else None, - }) + loss_info.update( + { + "output": output, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + } + ) loss, losses = self.losses(loss_info) @@ -412,20 +392,26 @@ def training_step(self, batch, batch_idx): loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size - loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + loss_all = torch.stack( + [ + loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() + for i in torch.arange(0, 1, bucket_size).to(self.device) + ] + ) # Log bucketed losses with corresponding sigma bucket values, if it's not NaN debug_log_dict = { - f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() + for i in range(num_loss_buckets) + if not torch.isnan(loss_all[i]) } self.log_dict(debug_log_dict) - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': diffusion_input.std(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + "train/loss": loss.detach(), + "train/std_data": diffusion_input.std(), + "train/lr": self.trainer.optimizers[0].param_groups[0]["lr"], } for loss_name, loss_value in losses.items(): @@ -433,9 +419,9 @@ def training_step(self, batch, batch_idx): self.log_dict(log_dict, prog_bar=True, on_step=True) p.tick("log") - #print(f"Profiler: {p}") + # print(f"Profiler: {p}") return loss - + def on_before_zero_grad(self, *args, **kwargs): if self.diffusion_ema is not None: self.diffusion_ema.update() @@ -443,23 +429,25 @@ def on_before_zero_grad(self, *args, **kwargs): def export_model(self, path, use_safetensors=False): if self.diffusion_ema is not None: self.diffusion.model = self.diffusion_ema.ema_model - + if use_safetensors: save_file(self.diffusion.state_dict(), path) else: torch.save({"state_dict": self.diffusion.state_dict()}, path) + class DiffusionCondDemoCallback(pl.Callback): - def __init__(self, - demo_every=2000, - num_demos=8, - sample_size=65536, - demo_steps=250, - sample_rate=48000, - demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], - demo_cond_from_batch: bool = False, - display_audio_cond: bool = False + def __init__( + self, + demo_every=2000, + num_demos=8, + sample_size=65536, + demo_steps=250, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + demo_cond_from_batch: bool = False, + display_audio_cond: bool = False, ): super().__init__() @@ -480,7 +468,7 @@ def __init__(self, @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return @@ -496,7 +484,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp if self.demo_cond_from_batch: # Get metadata from the batch - demo_cond = batch[1][:self.num_demos] + 
demo_cond = batch[1][: self.num_demos] if module.diffusion.pretransform is not None: demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio @@ -514,47 +502,53 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp if self.display_audio_cond: audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) - audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') + audio_inputs = rearrange(audio_inputs, "b d n -> d (b n)") - filename = f'demo_audio_cond_{trainer.global_step:08}.wav' + filename = f"demo_audio_cond_{trainer.global_step:08}.wav" audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, audio_inputs, self.sample_rate) - log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") + log_dict[f"demo_audio_cond"] = wandb.Audio( + filename, sample_rate=self.sample_rate, caption="Audio conditioning" + ) log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) trainer.logger.experiment.log(log_dict) for cfg_scale in self.demo_cfg_scales: print(f"Generating demo for cfg scale {cfg_scale}") - + with torch.cuda.amp.autocast(): model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model if module.diffusion_objective == "v": - fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + fakes = sample( + model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True + ) elif module.diffusion_objective == "rectified_flow": - fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) - + fakes = sample_discrete_euler( + model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True + ) + if module.diffusion.pretransform is not None: fakes = module.diffusion.pretransform.decode(fakes) # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') + fakes = rearrange(fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + + filename = f"demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav" fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, fakes, self.sample_rate) - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + log_dict[f"demo_cfg_{cfg_scale}"] = wandb.Audio( + filename, sample_rate=self.sample_rate, caption=f"Reconstructed" + ) + + log_dict[f"demo_melspec_left_cfg_{cfg_scale}"] = wandb.Image(audio_spectrogram_image(fakes)) trainer.logger.experiment.log(log_dict) - + del fakes except Exception as e: @@ -564,36 +558,38 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp torch.cuda.empty_cache() module.train() + class DiffusionCondInpaintTrainingWrapper(pl.LightningModule): - ''' + """ Wrapper for training a conditional audio diffusion model. 
- ''' + """ + def __init__( - self, - model: ConditionedDiffusionModelWrapper, - lr: float = 1e-4, - max_mask_segments = 10, - log_loss_info: bool = False, - optimizer_configs: dict = None, - use_ema: bool = True, - pre_encoded: bool = False, - cfg_dropout_prob = 0.1, - timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + max_mask_segments=10, + log_loss_info: bool = False, + optimizer_configs: dict = None, + use_ema: bool = True, + pre_encoded: bool = False, + cfg_dropout_prob=0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", ): super().__init__() self.diffusion = model - + self.use_ema = use_ema if self.use_ema: self.diffusion_ema = EMA( self.diffusion.model, beta=0.9999, - power=3/4, + power=3 / 4, update_every=1, update_after_step=1, - include_online_model=False + include_online_model=False, ) else: self.diffusion_ema = None @@ -604,54 +600,40 @@ def __init__( self.max_mask_segments = max_mask_segments self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - + self.timestep_sampler = timestep_sampler self.diffusion_objective = model.diffusion_objective - self.loss_modules = [ - MSELoss("output", - "targets", - weight=1.0, - name="mse_loss" - ) - ] + self.loss_modules = [MSELoss("output", "targets", weight=1.0, name="mse_loss")] self.losses = MultiLoss(self.loss_modules) self.log_loss_info = log_loss_info - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + assert ( + lr is not None or optimizer_configs is not None + ), "Must specify either lr or optimizer_configs in training config" if optimizer_configs is None: - optimizer_configs = { - "diffusion": { - "optimizer": { - "type": "Adam", - "config": { - "lr": lr - } - } - } - } + optimizer_configs = {"diffusion": {"optimizer": {"type": "Adam", "config": {"lr": lr}}}} else: if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + print( + f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs." 
+ ) self.optimizer_configs = optimizer_configs self.pre_encoded = pre_encoded def configure_optimizers(self): - diffusion_opt_config = self.optimizer_configs['diffusion'] - opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + diffusion_opt_config = self.optimizer_configs["diffusion"] + opt_diff = create_optimizer_from_config(diffusion_opt_config["optimizer"], self.diffusion.parameters()) if "scheduler" in diffusion_opt_config: - sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) - sched_diff_config = { - "scheduler": sched_diff, - "interval": "step" - } + sched_diff = create_scheduler_from_config(diffusion_opt_config["scheduler"], opt_diff) + sched_diff_config = {"scheduler": sched_diff, "interval": "step"} return [opt_diff], [sched_diff_config] return [opt_diff] @@ -670,11 +652,11 @@ def random_mask(self, sequence, max_mask_length): max_segment_length = max_mask_length // num_segments segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) - + mask = torch.ones((1, 1, sequence_length)) for length in segment_lengths: mask_start = random.randint(0, sequence_length - length) - mask[:, :, mask_start:mask_start + length] = 0 + mask[:, :, mask_start : mask_start + length] = 0 elif mask_type == 1: # Full mask mask = torch.zeros((1, 1, sequence_length)) @@ -728,7 +710,7 @@ def training_step(self, batch, batch_idx): # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input # if use_padding_mask: # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() - else: + else: # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: diffusion_input = diffusion_input / self.diffusion.pretransform.scale @@ -739,20 +721,20 @@ def training_step(self, batch, batch_idx): # Create a mask of random length for a random slice of the input masked_input, mask = self.random_mask(diffusion_input, max_mask_length) - conditioning['inpaint_mask'] = [mask] - conditioning['inpaint_masked_input'] = [masked_input] + conditioning["inpaint_mask"] = [mask] + conditioning["inpaint_masked_input"] = [masked_input] if self.timestep_sampler == "uniform": # Draw uniformly distributed continuous timesteps t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) elif self.timestep_sampler == "logit_normal": t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) - + # Calculate the noise schedule parameters for those timesteps if self.diffusion_objective == "v": alphas, sigmas = get_alphas_sigmas(t) elif self.diffusion_objective == "rectified_flow": - alphas, sigmas = 1-t, t + alphas, sigmas = 1 - t, t # Combine the ground truth data and the noise alphas = alphas[:, None, None] @@ -771,13 +753,17 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): p.tick("amp") - output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + output = self.diffusion( + noised_inputs, t, cond=conditioning, cfg_dropout_prob=self.cfg_dropout_prob, **extra_args + ) p.tick("diffusion") - loss_info.update({ - "output": output, - "targets": targets, - }) + loss_info.update( + { + "output": output, + "targets": targets, + } + ) loss, losses = self.losses(loss_info) @@ -793,19 +779,26 @@ def training_step(self, batch, 
batch_idx): loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size - loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + loss_all = torch.stack( + [ + loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() + for i in torch.arange(0, 1, bucket_size).to(self.device) + ] + ) # Log bucketed losses with corresponding sigma bucket values, if it's not NaN debug_log_dict = { - f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() + for i in range(num_loss_buckets) + if not torch.isnan(loss_all[i]) } self.log_dict(debug_log_dict) log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': diffusion_input.std(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + "train/loss": loss.detach(), + "train/std_data": diffusion_input.std(), + "train/lr": self.trainer.optimizers[0].param_groups[0]["lr"], } for loss_name, loss_value in losses.items(): @@ -813,9 +806,9 @@ def training_step(self, batch, batch_idx): self.log_dict(log_dict, prog_bar=True, on_step=True) p.tick("log") - #print(f"Profiler: {p}") + # print(f"Profiler: {p}") return loss - + def on_before_zero_grad(self, *args, **kwargs): if self.diffusion_ema is not None: self.diffusion_ema.update() @@ -823,21 +816,22 @@ def on_before_zero_grad(self, *args, **kwargs): def export_model(self, path, use_safetensors=False): if self.diffusion_ema is not None: self.diffusion.model = self.diffusion_ema.ema_model - + if use_safetensors: save_file(self.diffusion.state_dict(), path) else: torch.save({"state_dict": self.diffusion.state_dict()}, path) + class DiffusionCondInpaintDemoCallback(pl.Callback): def __init__( - self, - demo_dl, + self, + demo_dl, demo_every=2000, demo_steps=250, sample_size=65536, sample_rate=48000, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], ): super().__init__() self.demo_every = demo_every @@ -850,10 +844,10 @@ def __init__( @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return - + self.last_demo_step = trainer.global_step try: @@ -869,7 +863,9 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp if not module.pre_encoded: # Log the real audio - log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + log_dict[f"demo_reals_melspec_left"] = wandb.Image( + audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()) + ) # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") if module.diffusion.pretransform is not None: @@ -884,13 +880,17 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) - conditioning['inpaint_mask'] = [mask] - 
conditioning['inpaint_masked_input'] = [masked_input] + conditioning["inpaint_mask"] = [mask] + conditioning["inpaint_masked_input"] = [masked_input] if module.diffusion.pretransform is not None: - log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) + log_dict[f"demo_masked_input"] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) else: - log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + log_dict[f"demo_masked_input"] = wandb.Image( + audio_spectrogram_image( + rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu() + ) + ) cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) @@ -905,68 +905,62 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp if module.diffusion_objective == "v": fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) elif module.diffusion_objective == "rectified_flow": - fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + fakes = sample_discrete_euler( + model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True + ) if module.diffusion.pretransform is not None: with torch.cuda.amp.autocast(): fakes = module.diffusion.pretransform.decode(fakes) # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') + fakes = rearrange(fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + + filename = f"demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav" fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, fakes, self.sample_rate) - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + log_dict[f"demo_cfg_{cfg_scale}"] = wandb.Audio( + filename, sample_rate=self.sample_rate, caption=f"Reconstructed" + ) + + log_dict[f"demo_melspec_left_cfg_{cfg_scale}"] = wandb.Image(audio_spectrogram_image(fakes)) trainer.logger.experiment.log(log_dict) except Exception as e: - print(f'{type(e).__name__}: {e}') + print(f"{type(e).__name__}: {e}") raise e + class DiffusionAutoencoderTrainingWrapper(pl.LightningModule): - ''' + """ Wrapper for training a diffusion autoencoder - ''' + """ + def __init__( - self, - model: DiffusionAutoencoder, - lr: float = 1e-4, - ema_copy = None, - use_reconstruction_loss: bool = False + self, model: DiffusionAutoencoder, lr: float = 1e-4, ema_copy=None, use_reconstruction_loss: bool = False ): super().__init__() self.diffae = model - + self.diffae_ema = EMA( self.diffae, ema_model=ema_copy, beta=0.9999, - power=3/4, + power=3 / 4, update_every=1, update_after_step=1, - include_online_model=False + include_online_model=False, ) self.lr = lr self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] + loss_modules = [MSELoss("v", "targets", weight=1.0, name="mse_loss")] if model.bottleneck is not None: # TODO: Use loss config for configurable bottleneck weights and reconstruction losses @@ -989,7 +983,7 @@ def __init__( "fft_sizes": scales, "hop_sizes": hop_sizes, "win_lengths": win_lengths, - "perceptual_weighting": True + 
"perceptual_weighting": True, } out_channels = model.out_channels @@ -1003,7 +997,9 @@ def __init__( self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) loss_modules.append( - AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + AuralossLoss( + self.sdstft, "audio_reals", "audio_pred", name="mrstft_loss", weight=0.1 + ), # Reconstruction loss ) self.losses = MultiLoss(loss_modules) @@ -1020,14 +1016,14 @@ def training_step(self, batch, batch_idx): loss_info = {} loss_info["audio_reals"] = reals - + if self.diffae.pretransform is not None: with torch.no_grad(): reals = self.diffae.pretransform.encode(reals) loss_info["reals"] = reals - #Encode reals, skipping the pretransform since it was already applied + # Encode reals, skipping the pretransform since it was already applied latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) loss_info["latents"] = latents @@ -1035,10 +1031,10 @@ def training_step(self, batch, batch_idx): if self.diffae.decoder is not None: latents = self.diffae.decoder(latents) - + # Upsample latents to match diffusion length if latents.shape[2] != reals.shape[2]: - latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') + latents = F.interpolate(latents, size=reals.shape[2], mode="nearest") loss_info["latents_upsampled"] = latents @@ -1057,11 +1053,8 @@ def training_step(self, batch, batch_idx): with torch.cuda.amp.autocast(): v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) - - loss_info.update({ - "v": v, - "targets": targets - }) + + loss_info.update({"v": v, "targets": targets}) if self.use_reconstruction_loss: pred = noised_reals * alphas - v * sigmas @@ -1075,9 +1068,9 @@ def training_step(self, batch, batch_idx): loss, losses = self.losses(loss_info) log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': reals.std(), - 'train/latent_std': latents.std(), + "train/loss": loss.detach(), + "train/std_data": reals.std(), + "train/latent_std": latents.std(), } for loss_name, loss_value in losses.items(): @@ -1085,28 +1078,22 @@ def training_step(self, batch, batch_idx): self.log_dict(log_dict, prog_bar=True, on_step=True) return loss - + def on_before_zero_grad(self, *args, **kwargs): self.diffae_ema.update() def export_model(self, path, use_safetensors=False): model = self.diffae_ema.ema_model - + if use_safetensors: save_file(model.state_dict(), path) else: torch.save({"state_dict": model.state_dict()}, path) + class DiffusionAutoencoderDemoCallback(pl.Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - demo_steps=250, - sample_size=65536, - sample_rate=48000 - ): + def __init__(self, demo_dl, demo_every=2000, demo_steps=250, sample_size=65536, sample_rate=48000): super().__init__() self.demo_every = demo_every self.demo_steps = demo_steps @@ -1117,10 +1104,10 @@ def __init__( @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return - + self.last_demo_step = trainer.global_step demo_reals, _ = next(self.demo_dl) @@ -1130,7 +1117,7 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe demo_reals = demo_reals[0] encoder_input = demo_reals - 
+ encoder_input = encoder_input.to(module.device) demo_reals = demo_reals.to(module.device) @@ -1139,60 +1126,60 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe latents = module.diffae_ema.ema_model.encode(encoder_input).float() fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) - #Interleave reals and fakes - reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + # Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], "i b d n -> (b i) d n") # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + reals_fakes = rearrange(reals_fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'recon_{trainer.global_step:08}.wav' - reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + + filename = f"recon_{trainer.global_step:08}.wav" + reals_fakes = ( + reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + ) torchaudio.save(filename, reals_fakes, self.sample_rate) - log_dict[f'recon'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') + log_dict[f"recon"] = wandb.Audio(filename, sample_rate=self.sample_rate, caption=f"Reconstructed") - log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) - log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + log_dict[f"embeddings_3dpca"] = pca_point_cloud(latents) + log_dict[f"embeddings_spec"] = wandb.Image(tokens_spectrogram_image(latents)) - log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + log_dict[f"recon_melspec_left"] = wandb.Image(audio_spectrogram_image(reals_fakes)) if module.diffae_ema.ema_model.pretransform is not None: with torch.no_grad() and torch.cuda.amp.autocast(): initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) - first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') + first_stage_fakes = rearrange(first_stage_fakes, "b d n -> d (b n)") first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() - first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' + first_stage_filename = f"first_stage_{trainer.global_step:08}.wav" torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) - log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) + log_dict[f"first_stage_latents"] = wandb.Image(tokens_spectrogram_image(initial_latents)) + + log_dict[f"first_stage"] = wandb.Audio( + first_stage_filename, sample_rate=self.sample_rate, caption=f"First Stage Reconstructed" + ) - log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, - sample_rate=self.sample_rate, - caption=f'First Stage Reconstructed') - - log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) - + log_dict[f"first_stage_melspec_left"] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) trainer.logger.experiment.log(log_dict) + def create_source_mixture(reals, num_sources=2): # Create a fake mixture source by mixing elements from the training batch together with random offsets source = torch.zeros_like(reals) for i in range(reals.shape[0]): sources_added = 0 - + js = list(range(reals.shape[0])) random.shuffle(js) for j in js: if i == j or (i != j and sources_added < num_sources): # Randomly 
offset the mixed element between 0 and the length of the source seq_len = reals.shape[2] - offset = random.randint(0, seq_len-1) + offset = random.randint(0, seq_len - 1) source[i, :, offset:] += reals[j, :, :-offset] if i == j: # If this is the real one, shift the reals as well to ensure alignment @@ -1203,33 +1190,35 @@ def create_source_mixture(reals, num_sources=2): return source + class DiffusionPriorTrainingWrapper(pl.LightningModule): - ''' + """ Wrapper for training a diffusion prior for inverse problems Prior types: mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version - ''' + """ + def __init__( - self, - model: ConditionedDiffusionModelWrapper, - lr: float = 1e-4, - ema_copy = None, - prior_type: PriorType = PriorType.MonoToStereo, - use_reconstruction_loss: bool = False, - log_loss_info: bool = False, + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + ema_copy=None, + prior_type: PriorType = PriorType.MonoToStereo, + use_reconstruction_loss: bool = False, + log_loss_info: bool = False, ): super().__init__() self.diffusion = model - + self.diffusion_ema = EMA( self.diffusion, ema_model=ema_copy, beta=0.9999, - power=3/4, + power=3 / 4, update_every=1, update_after_step=1, - include_online_model=False + include_online_model=False, ) self.lr = lr @@ -1238,13 +1227,7 @@ def __init__( self.log_loss_info = log_loss_info - loss_modules = [ - MSELoss("v", - "targets", - weight=1.0, - name="mse_loss" - ) - ] + loss_modules = [MSELoss("v", "targets", weight=1.0, name="mse_loss")] self.use_reconstruction_loss = use_reconstruction_loss @@ -1263,7 +1246,7 @@ def __init__( "fft_sizes": scales, "hop_sizes": hop_sizes, "win_lengths": win_lengths, - "perceptual_weighting": True + "perceptual_weighting": True, } out_channels = model.io_channels @@ -1279,15 +1262,17 @@ def __init__( # Add left and right channel reconstruction losses in addition to the sum and difference self.loss_modules += [ - AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), - AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), + AuralossLoss(self.lrstft, "audio_reals_left", "pred_left", name="stft_loss_left", weight=0.05), + AuralossLoss(self.lrstft, "audio_reals_right", "pred_right", name="stft_loss_right", weight=0.05), ] else: self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) self.loss_modules.append( - AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + AuralossLoss( + self.sdstft, "audio_reals", "audio_pred", name="mrstft_loss", weight=0.1 + ), # Reconstruction loss ) self.losses = MultiLoss(loss_modules) @@ -1312,7 +1297,7 @@ def training_step(self, batch, batch_idx): loss_info["audio_reals_mono"] = source else: raise ValueError(f"Unknown prior type {self.prior_type}") - + if self.diffusion.pretransform is not None: with torch.no_grad(): reals = self.diffusion.pretransform.encode(reals) @@ -1342,15 +1327,12 @@ def training_step(self, batch, batch_idx): targets = noise * alphas - reals * sigmas with torch.cuda.amp.autocast(): - - conditioning['source'] = [source] - v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) - - loss_info.update({ - "v": v, - "targets": targets - }) + conditioning["source"] = [source] + + v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob=0.1) + + loss_info.update({"v": v, "targets": 
targets}) if self.use_reconstruction_loss: pred = noised_reals * alphas - v * sigmas @@ -1381,48 +1363,46 @@ def training_step(self, batch, batch_idx): loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size - loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + loss_all = torch.stack( + [ + loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() + for i in torch.arange(0, 1, bucket_size).to(self.device) + ] + ) # Log bucketed losses with corresponding sigma bucket values, if it's not NaN debug_log_dict = { - f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() + for i in range(num_loss_buckets) + if not torch.isnan(loss_all[i]) } self.log_dict(debug_log_dict) - log_dict = { - 'train/loss': loss.detach(), - 'train/std_data': reals.std() - } + log_dict = {"train/loss": loss.detach(), "train/std_data": reals.std()} for loss_name, loss_value in losses.items(): log_dict[f"train/{loss_name}"] = loss_value.detach() self.log_dict(log_dict, prog_bar=True, on_step=True) return loss - + def on_before_zero_grad(self, *args, **kwargs): self.diffusion_ema.update() def export_model(self, path, use_safetensors=False): - #model = self.diffusion_ema.ema_model + # model = self.diffusion_ema.ema_model model = self.diffusion - + if use_safetensors: save_file(model.state_dict(), path) else: torch.save({"state_dict": model.state_dict()}, path) + class DiffusionPriorDemoCallback(pl.Callback): - def __init__( - self, - demo_dl, - demo_every=2000, - demo_steps=250, - sample_size=65536, - sample_rate=48000 - ): + def __init__(self, demo_dl, demo_every=2000, demo_steps=250, sample_size=65536, sample_rate=48000): super().__init__() self.demo_every = demo_every self.demo_steps = demo_steps @@ -1433,10 +1413,10 @@ def __init__( @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return - + self.last_demo_step = trainer.global_step demo_reals, metadata = next(self.demo_dl) @@ -1456,7 +1436,6 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe else: conditioning_tensors = {} - with torch.no_grad() and torch.cuda.amp.autocast(): if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) @@ -1467,41 +1446,45 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe else: source_input = source - conditioning_tensors['source'] = [source_input] + conditioning_tensors["source"] = [source_input] - fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) + fakes = sample( + module.diffusion_ema.model, + torch.randn_like(encoder_input), + self.demo_steps, + 0, + cond=conditioning_tensors, + ) if module.diffusion.pretransform is not None: fakes = module.diffusion.pretransform.decode(fakes) - #Interleave reals and fakes - reals_fakes = 
rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + # Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], "i b d n -> (b i) d n") # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + reals_fakes = rearrange(reals_fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'recon_{trainer.global_step:08}.wav' - reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + + filename = f"recon_{trainer.global_step:08}.wav" + reals_fakes = ( + reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + ) torchaudio.save(filename, reals_fakes, self.sample_rate) - log_dict[f'recon'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') + log_dict[f"recon"] = wandb.Audio(filename, sample_rate=self.sample_rate, caption=f"Reconstructed") - log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + log_dict[f"recon_melspec_left"] = wandb.Image(audio_spectrogram_image(reals_fakes)) - #Log the source - filename = f'source_{trainer.global_step:08}.wav' - source = rearrange(source, 'b d n -> d (b n)') + # Log the source + filename = f"source_{trainer.global_step:08}.wav" + source = rearrange(source, "b d n -> d (b n)") source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() torchaudio.save(filename, source, self.sample_rate) - log_dict[f'source'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Source') + log_dict[f"source"] = wandb.Audio(filename, sample_rate=self.sample_rate, caption=f"Source") - log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) + log_dict[f"source_melspec_left"] = wandb.Image(audio_spectrogram_image(source)) - trainer.logger.experiment.log(log_dict) \ No newline at end of file + trainer.logger.experiment.log(log_dict) diff --git a/stable_audio_tools/training/factory.py b/stable_audio_tools/training/factory.py index c3216d14..d182a954 100644 --- a/stable_audio_tools/training/factory.py +++ b/stable_audio_tools/training/factory.py @@ -1,22 +1,26 @@ import torch from torch.nn import Parameter + from ..models.factory import create_model_from_config + def create_training_wrapper_from_config(model_config, model): - model_type = model_config.get('model_type', None) - assert model_type is not None, 'model_type must be specified in model config' + model_type = model_config.get("model_type", None) + assert model_type is not None, "model_type must be specified in model config" - training_config = model_config.get('training', None) - assert training_config is not None, 'training config must be specified in model config' + training_config = model_config.get("training", None) + assert training_config is not None, "training config must be specified in model config" - if model_type == 'autoencoder': + if model_type == "autoencoder": from .autoencoders import AutoencoderTrainingWrapper ema_copy = None if training_config.get("use_ema", False): ema_copy = create_model_from_config(model_config) - ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + ema_copy = create_model_from_config( + model_config + ) # I don't know why this needs to be called twice but it broke when I called it once # Copy each weight to the ema copy for name, param in model.state_dict().items(): if isinstance(param, Parameter): @@ -40,9 +44,9 @@ def 
create_training_wrapper_from_config(model_config, model): raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") return AutoencoderTrainingWrapper( - model, + model, lr=training_config["learning_rate"], - warmup_steps=training_config.get("warmup_steps", 0), + warmup_steps=training_config.get("warmup_steps", 0), encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), sample_rate=model_config["sample_rate"], loss_config=training_config.get("loss_configs", None), @@ -51,35 +55,37 @@ def create_training_wrapper_from_config(model_config, model): ema_copy=ema_copy if use_ema else None, force_input_mono=training_config.get("force_input_mono", False), latent_mask_ratio=latent_mask_ratio, - teacher_model=teacher_model + teacher_model=teacher_model, ) - elif model_type == 'diffusion_uncond': + elif model_type == "diffusion_uncond": from .diffusion import DiffusionUncondTrainingWrapper + return DiffusionUncondTrainingWrapper( - model, + model, lr=training_config["learning_rate"], pre_encoded=training_config.get("pre_encoded", False), ) - elif model_type == 'diffusion_cond': + elif model_type == "diffusion_cond": from .diffusion import DiffusionCondTrainingWrapper + return DiffusionCondTrainingWrapper( - model, + model, lr=training_config.get("learning_rate", None), mask_padding=training_config.get("mask_padding", False), mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), - use_ema = training_config.get("use_ema", True), + use_ema=training_config.get("use_ema", True), log_loss_info=training_config.get("log_loss_info", False), optimizer_configs=training_config.get("optimizer_configs", None), pre_encoded=training_config.get("pre_encoded", False), - cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), - timestep_sampler = training_config.get("timestep_sampler", "uniform") + cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler=training_config.get("timestep_sampler", "uniform"), ) - elif model_type == 'diffusion_prior': - from .diffusion import DiffusionPriorTrainingWrapper + elif model_type == "diffusion_prior": from ..models.diffusion_prior import PriorType + from .diffusion import DiffusionPriorTrainingWrapper ema_copy = create_model_from_config(model_config) - + # Copy each weight to the ema copy for name, param in model.state_dict().items(): if isinstance(param, Parameter): @@ -95,31 +101,32 @@ def create_training_wrapper_from_config(model_config, model): raise ValueError(f"Unknown prior type: {prior_type}") return DiffusionPriorTrainingWrapper( - model, + model, lr=training_config["learning_rate"], ema_copy=ema_copy, prior_type=prior_type_enum, log_loss_info=training_config.get("log_loss_info", False), use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), ) - elif model_type == 'diffusion_cond_inpaint': + elif model_type == "diffusion_cond_inpaint": from .diffusion import DiffusionCondInpaintTrainingWrapper + return DiffusionCondInpaintTrainingWrapper( - model, + model, lr=training_config.get("learning_rate", None), - max_mask_segments = training_config.get("max_mask_segments", 10), + max_mask_segments=training_config.get("max_mask_segments", 10), log_loss_info=training_config.get("log_loss_info", False), optimizer_configs=training_config.get("optimizer_configs", None), use_ema=training_config.get("use_ema", True), pre_encoded=training_config.get("pre_encoded", False), - cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), - timestep_sampler = 
training_config.get("timestep_sampler", "uniform") + cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler=training_config.get("timestep_sampler", "uniform"), ) - elif model_type == 'diffusion_autoencoder': + elif model_type == "diffusion_autoencoder": from .diffusion import DiffusionAutoencoderTrainingWrapper ema_copy = create_model_from_config(model_config) - + # Copy each weight to the ema copy for name, param in model.state_dict().items(): if isinstance(param, Parameter): @@ -131,9 +138,9 @@ def create_training_wrapper_from_config(model_config, model): model, ema_copy=ema_copy, lr=training_config["learning_rate"], - use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), ) - elif model_type == 'lm': + elif model_type == "lm": from .lm import AudioLanguageModelTrainingWrapper ema_copy = create_model_from_config(model_config) @@ -154,58 +161,63 @@ def create_training_wrapper_from_config(model_config, model): ) else: - raise NotImplementedError(f'Unknown model type: {model_type}') + raise NotImplementedError(f"Unknown model type: {model_type}") + def create_demo_callback_from_config(model_config, **kwargs): - model_type = model_config.get('model_type', None) - assert model_type is not None, 'model_type must be specified in model config' + model_type = model_config.get("model_type", None) + assert model_type is not None, "model_type must be specified in model config" - training_config = model_config.get('training', None) - assert training_config is not None, 'training config must be specified in model config' + training_config = model_config.get("training", None) + assert training_config is not None, "training config must be specified in model config" demo_config = training_config.get("demo", {}) - if model_type == 'autoencoder': + if model_type == "autoencoder": from .autoencoders import AutoencoderDemoCallback + return AutoencoderDemoCallback( - demo_every=demo_config.get("demo_every", 2000), - sample_size=model_config["sample_size"], + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], - **kwargs + **kwargs, ) - elif model_type == 'diffusion_uncond': + elif model_type == "diffusion_uncond": from .diffusion import DiffusionUncondDemoCallback + return DiffusionUncondDemoCallback( - demo_every=demo_config.get("demo_every", 2000), - demo_steps=demo_config.get("demo_steps", 250), - sample_rate=model_config["sample_rate"] + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"], ) elif model_type == "diffusion_autoencoder": from .diffusion import DiffusionAutoencoderDemoCallback + return DiffusionAutoencoderDemoCallback( - demo_every=demo_config.get("demo_every", 2000), + demo_every=demo_config.get("demo_every", 2000), demo_steps=demo_config.get("demo_steps", 250), sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], - **kwargs + **kwargs, ) elif model_type == "diffusion_prior": from .diffusion import DiffusionPriorDemoCallback + return DiffusionPriorDemoCallback( - demo_every=demo_config.get("demo_every", 2000), + demo_every=demo_config.get("demo_every", 2000), demo_steps=demo_config.get("demo_steps", 250), sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], - **kwargs + **kwargs, ) elif model_type == "diffusion_cond": from .diffusion import 
DiffusionCondDemoCallback return DiffusionCondDemoCallback( - demo_every=demo_config.get("demo_every", 2000), + demo_every=demo_config.get("demo_every", 2000), sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], - demo_steps=demo_config.get("demo_steps", 250), + demo_steps=demo_config.get("demo_steps", 250), num_demos=demo_config["num_demos"], demo_cfg_scales=demo_config["demo_cfg_scales"], demo_conditioning=demo_config.get("demo_cond", {}), @@ -216,25 +228,25 @@ def create_demo_callback_from_config(model_config, **kwargs): from .diffusion import DiffusionCondInpaintDemoCallback return DiffusionCondInpaintDemoCallback( - demo_every=demo_config.get("demo_every", 2000), + demo_every=demo_config.get("demo_every", 2000), sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], demo_steps=demo_config.get("demo_steps", 250), demo_cfg_scales=demo_config["demo_cfg_scales"], - **kwargs + **kwargs, ) - + elif model_type == "lm": from .lm import AudioLanguageModelDemoCallback return AudioLanguageModelDemoCallback( - demo_every=demo_config.get("demo_every", 2000), + demo_every=demo_config.get("demo_every", 2000), sample_size=model_config["sample_size"], sample_rate=model_config["sample_rate"], demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), demo_conditioning=demo_config.get("demo_cond", None), num_demos=demo_config.get("num_demos", 8), - **kwargs + **kwargs, ) else: - raise NotImplementedError(f'Unknown model type: {model_type}') \ No newline at end of file + raise NotImplementedError(f"Unknown model type: {model_type}") diff --git a/stable_audio_tools/training/lm.py b/stable_audio_tools/training/lm.py index e1fa9f71..e154d72b 100644 --- a/stable_audio_tools/training/lm.py +++ b/stable_audio_tools/training/lm.py @@ -1,32 +1,34 @@ -import pytorch_lightning as pl -import sys, gc +import gc import random +import sys +import typing as tp + +import pytorch_lightning as pl import torch import torchaudio -import typing as tp import wandb - -from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image -from ema_pytorch import EMA +from aeiou.viz import audio_spectrogram_image, pca_point_cloud, tokens_spectrogram_image from einops import rearrange +from ema_pytorch import EMA +from pytorch_lightning.utilities.rank_zero import rank_zero_only from safetensors.torch import save_file from torch import optim from torch.nn import functional as F -from pytorch_lightning.utilities.rank_zero import rank_zero_only from ..models.lm import AudioLanguageModelWrapper from .utils import create_optimizer_from_config, create_scheduler_from_config + class AudioLanguageModelTrainingWrapper(pl.LightningModule): def __init__( - self, - model: AudioLanguageModelWrapper, - lr = 1e-4, - use_ema=False, - ema_copy=None, - optimizer_configs: dict = None, - pre_encoded=False - ): + self, + model: AudioLanguageModelWrapper, + lr=1e-4, + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + pre_encoded=False, + ): super().__init__() self.model = model @@ -37,43 +39,35 @@ def __init__( if use_ema: self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) - assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + assert ( + lr is not None or optimizer_configs is not None + ), "Must specify either lr or optimizer_configs in training config" if optimizer_configs is None: optimizer_configs = { - "lm": { - "optimizer": { - "type": "AdamW", - "config": { 
- "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1 - } - } - } + "lm": {"optimizer": {"type": "AdamW", "config": {"lr": lr, "betas": (0.9, 0.95), "weight_decay": 0.1}}} } else: if lr is not None: - print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + print( + f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs." + ) self.optimizer_configs = optimizer_configs self.pre_encoded = pre_encoded def configure_optimizers(self): - lm_opt_config = self.optimizer_configs['lm'] - opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + lm_opt_config = self.optimizer_configs["lm"] + opt_lm = create_optimizer_from_config(lm_opt_config["optimizer"], self.model.parameters()) if "scheduler" in lm_opt_config: - sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) - sched_lm_config = { - "scheduler": sched_lm, - "interval": "step" - } + sched_lm = create_scheduler_from_config(lm_opt_config["scheduler"], opt_lm) + sched_lm_config = {"scheduler": sched_lm, "interval": "step"} return [opt_lm], [sched_lm_config] return [opt_lm] - + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license # License can be found in LICENSES/LICENSE_META.txt @@ -128,12 +122,12 @@ def training_step(self, batch, batch_idx): padding_masks.append(md["padding_mask"]) else: padding_masks.append(md["padding_mask"][0]) - - padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) # Interpolate padding masks to the same length as the codes - padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() - + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode="nearest").bool() + condition_tensors = None # If the model is conditioned, get the conditioning tensors @@ -142,8 +136,8 @@ def training_step(self, batch, batch_idx): lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) - logits = lm_output.logits # [b, k, t, c] - logits_mask = lm_output.mask # [b, k, t] + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] logits_mask = logits_mask & padding_masks @@ -152,15 +146,15 @@ def training_step(self, batch, batch_idx): loss = cross_entropy log_dict = { - 'train/loss': loss.detach(), - 'train/cross_entropy': cross_entropy.detach(), - 'train/perplexity': torch.exp(cross_entropy).detach(), - 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + "train/loss": loss.detach(), + "train/cross_entropy": cross_entropy.detach(), + "train/perplexity": torch.exp(cross_entropy).detach(), + "train/lr": self.trainer.optimizers[0].param_groups[0]["lr"], } for k, ce_q in enumerate(cross_entropy_per_codebook): - log_dict[f'cross_entropy_q{k + 1}'] = ce_q - log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + log_dict[f"cross_entropy_q{k + 1}"] = ce_q + log_dict[f"perplexity_q{k + 1}"] = torch.exp(ce_q) self.log_dict(log_dict, prog_bar=True, on_step=True) return loss @@ -170,24 +164,25 @@ def on_before_zero_grad(self, *args, **kwargs): self.model_ema.update() def export_model(self, path, use_safetensors=False): - + model = 
self.model_ema.ema_model if self.model_ema is not None else self.model if use_safetensors: save_file(model.state_dict(), path) else: torch.save({"state_dict": model.state_dict()}, path) - + class AudioLanguageModelDemoCallback(pl.Callback): - def __init__(self, - demo_every=2000, - num_demos=8, - sample_size=65536, - sample_rate=48000, - demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, - demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], - **kwargs + def __init__( + self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs, ): super().__init__() @@ -201,7 +196,7 @@ def __init__(self, @rank_zero_only @torch.no_grad() - def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: return @@ -213,49 +208,49 @@ def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio - #demo_reals = batch[0][:self.num_demos] + # demo_reals = batch[0][:self.num_demos] # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: # demo_reals = demo_reals[0] - #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + # demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) ##Limit to first 50 tokens - #demo_reals_tokens = demo_reals_tokens[:, :, :50] + # demo_reals_tokens = demo_reals_tokens[:, :, :50] try: print("Getting conditioning") for cfg_scale in self.demo_cfg_scales: - model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model print(f"Generating demo for cfg scale {cfg_scale}") fakes = model.generate_audio( batch_size=self.num_demos, - max_gen_len=demo_length_tokens, - conditioning=self.demo_conditioning, - #init_data = demo_reals_tokens, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + # init_data = demo_reals_tokens, cfg_scale=cfg_scale, temp=1.0, - top_p=0.95 + top_p=0.95, ) # Put the demos together - fakes = rearrange(fakes, 'b d n -> d (b n)') + fakes = rearrange(fakes, "b d n -> d (b n)") log_dict = {} - - filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + + filename = f"demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav" fakes = fakes / fakes.abs().max() fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() torchaudio.save(filename, fakes, self.sample_rate) - log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, - sample_rate=self.sample_rate, - caption=f'Reconstructed') - - log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + log_dict[f"demo_cfg_{cfg_scale}"] = wandb.Audio( + filename, sample_rate=self.sample_rate, caption=f"Reconstructed" + ) + + log_dict[f"demo_melspec_left_cfg_{cfg_scale}"] = wandb.Image(audio_spectrogram_image(fakes)) trainer.logger.experiment.log(log_dict) @@ -264,4 +259,4 @@ def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, finally: gc.collect() torch.cuda.empty_cache() - module.train() \ No newline at end of file + module.train() diff --git 
a/stable_audio_tools/training/losses/__init__.py b/stable_audio_tools/training/losses/__init__.py index 37fdea0e..7d39f7af 100644 --- a/stable_audio_tools/training/losses/__init__.py +++ b/stable_audio_tools/training/losses/__init__.py @@ -1 +1 @@ -from .losses import * \ No newline at end of file +from .losses import * diff --git a/stable_audio_tools/training/losses/auraloss.py b/stable_audio_tools/training/losses/auraloss.py index 9ab5405d..e669a318 100644 --- a/stable_audio_tools/training/losses/auraloss.py +++ b/stable_audio_tools/training/losses/auraloss.py @@ -1,10 +1,12 @@ # Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0 # You can find the license at LICENSES/LICENSE_AURALOSS.txt -import torch +from typing import Any, List + import numpy as np -from typing import List, Any import scipy.signal +import torch + def apply_reduction(losses, reduction="none"): """Apply reduction to collection of losses.""" @@ -14,6 +16,7 @@ def apply_reduction(losses, reduction="none"): losses = losses.sum() return losses + def get_window(win_type: str, win_length: int): """Return a window function. @@ -34,6 +37,7 @@ def get_window(win_type: str, win_length: int): return win + class SumAndDifference(torch.nn.Module): """Sum and difference signal extraction module.""" @@ -122,9 +126,7 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2], [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2], ) - DENs = np.polymul( - np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2] - ) + DENs = np.polymul(np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]) # convert analog filter to digital filter b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs) @@ -136,14 +138,13 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) # now implement this digital FIR filter as a Conv1d layer - self.fir = torch.nn.Conv1d( - 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 - ) + self.fir = torch.nn.Conv1d(1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2) self.fir.weight.requires_grad = False self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) if plot: from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) def forward(self, input, target): @@ -154,14 +155,11 @@ def forward(self, input, target): Returns: Tensor: Filtered signal. """ - input = torch.nn.functional.conv1d( - input, self.fir.weight.data, padding=self.ntaps // 2 - ) - target = torch.nn.functional.conv1d( - target, self.fir.weight.data, padding=self.ntaps // 2 - ) + input = torch.nn.functional.conv1d(input, self.fir.weight.data, padding=self.ntaps // 2) + target = torch.nn.functional.conv1d(target, self.fir.weight.data, padding=self.ntaps // 2) return input, target + class SpectralConvergenceLoss(torch.nn.Module): """Spectral convergence loss module. @@ -174,6 +172,7 @@ def __init__(self): def forward(self, x_mag, y_mag): return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean() + class STFTMagnitudeLoss(torch.nn.Module): """STFT magnitude loss module. 
@@ -281,7 +280,7 @@ def __init__( reduction: str = "mean", mag_distance: str = "L1", device: Any = None, - **kwargs + **kwargs, ): super().__init__() self.fft_size = fft_size @@ -306,18 +305,8 @@ def __init__( self.phs_used = bool(self.w_phs) self.spectralconv = SpectralConvergenceLoss() - self.logstft = STFTMagnitudeLoss( - log=True, - reduction=reduction, - distance=mag_distance, - **kwargs - ) - self.linstft = STFTMagnitudeLoss( - log=False, - reduction=reduction, - distance=mag_distance, - **kwargs - ) + self.logstft = STFTMagnitudeLoss(log=True, reduction=reduction, distance=mag_distance, **kwargs) + self.linstft = STFTMagnitudeLoss(log=False, reduction=reduction, distance=mag_distance, **kwargs) # setup mel filterbank if scale is not None: @@ -336,14 +325,10 @@ def __init__( elif self.scale == "chroma": assert sample_rate != None # Must set sample rate to use chroma scale assert n_bins <= fft_size # Must be more FFT bins than chroma bins - fb = librosa.filters.chroma( - sr=sample_rate, n_fft=fft_size, n_chroma=n_bins - ) + fb = librosa.filters.chroma(sr=sample_rate, n_fft=fft_size, n_chroma=n_bins) else: - raise ValueError( - f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'." - ) + raise ValueError(f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'.") self.register_buffer("fb", fb) @@ -352,9 +337,7 @@ def __init__( if self.perceptual_weighting: if sample_rate is None: - raise ValueError( - f"`sample_rate` must be supplied when `perceptual_weighting = True`." - ) + raise ValueError(f"`sample_rate` must be supplied when `perceptual_weighting = True`.") self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) def stft(self, x): @@ -374,9 +357,7 @@ def stft(self, x): self.window, return_complex=True, ) - x_mag = torch.sqrt( - torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) - ) + x_mag = torch.sqrt(torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)) # torch.angle is expensive, so it is only evaluated if the values are used in the loss if self.phs_used: @@ -440,6 +421,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): elif self.output == "full": return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + class MultiResolutionSTFTLoss(torch.nn.Module): """Multi resolution STFT loss module. 
@@ -604,4 +586,4 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): if self.output == "loss": return loss elif self.output == "full": - return loss, sum_loss, diff_loss \ No newline at end of file + return loss, sum_loss, diff_loss diff --git a/stable_audio_tools/training/losses/losses.py b/stable_audio_tools/training/losses/losses.py index 55b15461..28402b80 100644 --- a/stable_audio_tools/training/losses/losses.py +++ b/stable_audio_tools/training/losses/losses.py @@ -1,7 +1,8 @@ import typing as tp -from torch.nn import functional as F from torch import nn +from torch.nn import functional as F + class LossModule(nn.Module): def __init__(self, name: str, weight: float = 1.0): @@ -12,27 +13,29 @@ def __init__(self, name: str, weight: float = 1.0): def forward(self, info, *args, **kwargs): raise NotImplementedError - + + class ValueLoss(LossModule): def __init__(self, key: str, name, weight: float = 1.0): super().__init__(name=name, weight=weight) self.key = key - + def forward(self, info): return self.weight * info[self.key] + class L1Loss(LossModule): - def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = "l1_loss"): super().__init__(name=name, weight=weight) self.key_a = key_a self.key_b = key_b self.mask_key = mask_key - + def forward(self, info): - mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') + mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction="none") if self.mask_key is not None and self.mask_key in info: mse_loss = mse_loss[info[self.mask_key]] @@ -40,18 +43,19 @@ def forward(self, info): mse_loss = mse_loss.mean() return self.weight * mse_loss - + + class MSELoss(LossModule): - def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = "mse_loss"): super().__init__(name=name, weight=weight) self.key_a = key_a self.key_b = key_b self.mask_key = mask_key - + def forward(self, info): - mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') + mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction="none") if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: mask = info[self.mask_key] @@ -67,7 +71,8 @@ def forward(self, info): mse_loss = mse_loss.mean() return self.weight * mse_loss - + + class AuralossLoss(LossModule): def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): super().__init__(name, weight) @@ -81,7 +86,8 @@ def forward(self, info): loss = self.auraloss_module(info[self.input_key], info[self.target_key]) return self.weight * loss - + + class MultiLoss(nn.Module): def __init__(self, losses: tp.List[LossModule]): super().__init__() @@ -98,4 +104,4 @@ def forward(self, info): total_loss += module_loss losses[loss_module.name] = module_loss - return total_loss, losses \ No newline at end of file + return total_loss, losses diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py index 38a3fccc..8f940f33 100644 --- a/stable_audio_tools/training/utils.py +++ b/stable_audio_tools/training/utils.py @@ -1,9 +1,11 @@ -import torch import os +import torch + + def get_rank(): """Get rank of current process.""" - + print(os.environ.keys()) if "SLURM_PROCID" in os.environ: @@ 
-11,9 +13,10 @@ def get_rank(): if not torch.distributed.is_available() or not torch.distributed.is_initialized(): return 0 - + return torch.distributed.get_rank() + class InverseLR(torch.optim.lr_scheduler._LRScheduler): """Implements an inverse decay learning rate schedule with an optional exponential warmup. When last_epoch=-1, sets initial lr as lr. @@ -31,12 +34,11 @@ class InverseLR(torch.optim.lr_scheduler._LRScheduler): each update. Default: ``False``. """ - def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., - last_epoch=-1, verbose=False): + def __init__(self, optimizer, inv_gamma=1.0, power=1.0, warmup=0.0, final_lr=0.0, last_epoch=-1, verbose=False): self.inv_gamma = inv_gamma self.power = power - if not 0. <= warmup < 1: - raise ValueError('Invalid value for warmup') + if not 0.0 <= warmup < 1: + raise ValueError("Invalid value for warmup") self.warmup = warmup self.final_lr = final_lr super().__init__(optimizer, last_epoch, verbose) @@ -44,16 +46,16 @@ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., def get_lr(self): if not self._get_lr_called_within_step: import warnings - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.") + + warnings.warn("To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.") return self._get_closed_form_lr() def _get_closed_form_lr(self): warmup = 1 - self.warmup ** (self.last_epoch + 1) lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power - return [warmup * max(self.final_lr, base_lr * lr_mult) - for base_lr in self.base_lrs] + return [warmup * max(self.final_lr, base_lr * lr_mult) for base_lr in self.base_lrs] + def copy_state_dict(model, state_dict): """Load state_dict to model, but only for keys that match exactly. @@ -69,9 +71,10 @@ def copy_state_dict(model, state_dict): # backwards compatibility for serialized parameters state_dict[key] = state_dict[key].data model_state_dict[key] = state_dict[key] - + model.load_state_dict(model_state_dict, strict=False) + def create_optimizer_from_config(optimizer_config, parameters): """Create optimizer from config. @@ -87,12 +90,14 @@ def create_optimizer_from_config(optimizer_config, parameters): if optimizer_type == "FusedAdam": from deepspeed.ops.adam import FusedAdam + optimizer = FusedAdam(parameters, **optimizer_config["config"]) else: optimizer_fn = getattr(torch.optim, optimizer_type) optimizer = optimizer_fn(parameters, **optimizer_config["config"]) return optimizer + def create_scheduler_from_config(scheduler_config, optimizer): """Create scheduler from config. 
@@ -108,4 +113,4 @@ def create_scheduler_from_config(scheduler_config, optimizer): else: scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) - return scheduler \ No newline at end of file + return scheduler diff --git a/train.py b/train.py index 8dbbfbbb..803b1d82 100644 --- a/train.py +++ b/train.py @@ -1,19 +1,28 @@ -from prefigure.prefigure import get_all_args, push_wandb_config import json import os -import torch -import pytorch_lightning as pl import random +import pytorch_lightning as pl +import torch +from prefigure.prefigure import get_all_args, push_wandb_config + from stable_audio_tools.data.dataset import create_dataloader_from_config from stable_audio_tools.models import create_model_from_config -from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model -from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config +from stable_audio_tools.models.utils import ( + load_ckpt_state_dict, + remove_weight_norm_from_model, +) +from stable_audio_tools.training import ( + create_demo_callback_from_config, + create_training_wrapper_from_config, +) from stable_audio_tools.training.utils import copy_state_dict + class ExceptionCallback(pl.Callback): def on_exception(self, trainer, module, err): - print(f'{type(err).__name__}: {err}') + print(f"{type(err).__name__}: {err}") + class ModelConfigEmbedderCallback(pl.Callback): def __init__(self, model_config): @@ -22,6 +31,7 @@ def __init__(self, model_config): def on_save_checkpoint(self, trainer, pl_module, checkpoint): checkpoint["model_config"] = self.model_config + def main(): args = get_all_args() @@ -35,7 +45,7 @@ def main(): random.seed(seed) torch.manual_seed(seed) - #Get JSON config from args.model_config + # Get JSON config from args.model_config with open(args.model_config) as f: model_config = json.load(f) @@ -43,8 +53,8 @@ def main(): dataset_config = json.load(f) train_dl = create_dataloader_from_config( - dataset_config, - batch_size=args.batch_size, + dataset_config, + batch_size=args.batch_size, num_workers=args.num_workers, sample_rate=model_config["sample_rate"], sample_size=model_config["sample_size"], @@ -55,13 +65,13 @@ def main(): if args.pretrained_ckpt_path: copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path)) - + if args.remove_pretransform_weight_norm == "pre_load": remove_weight_norm_from_model(model.pretransform) if args.pretransform_ckpt_path: model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path)) - + # Remove weight_norm from the pretransform if specified if args.remove_pretransform_weight_norm == "post_load": remove_weight_norm_from_model(model.pretransform) @@ -72,57 +82,64 @@ def main(): wandb_logger.watch(training_wrapper) exc_callback = ExceptionCallback() - + if args.save_dir and isinstance(wandb_logger.experiment.id, str): - checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints") + checkpoint_dir = os.path.join( + args.save_dir, wandb_logger.experiment.project, wandb_logger.experiment.id, "checkpoints" + ) else: checkpoint_dir = None - ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1) + ckpt_callback = pl.callbacks.ModelCheckpoint( + every_n_train_steps=args.checkpoint_every, dirpath=checkpoint_dir, save_top_k=-1 + ) save_model_config_callback = 
ModelConfigEmbedderCallback(model_config) demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) - #Combine args and config dicts + # Combine args and config dicts args_dict = vars(args) args_dict.update({"model_config": model_config}) args_dict.update({"dataset_config": dataset_config}) push_wandb_config(wandb_logger, args_dict) - #Set multi-GPU strategy if specified + # Set multi-GPU strategy if specified if args.strategy: if args.strategy == "deepspeed": from pytorch_lightning.strategies import DeepSpeedStrategy - strategy = DeepSpeedStrategy(stage=2, - contiguous_gradients=True, - overlap_comm=True, - reduce_scatter=True, - reduce_bucket_size=5e8, - allgather_bucket_size=5e8, - load_full_weights=True - ) + + strategy = DeepSpeedStrategy( + stage=2, + contiguous_gradients=True, + overlap_comm=True, + reduce_scatter=True, + reduce_bucket_size=5e8, + allgather_bucket_size=5e8, + load_full_weights=True, + ) else: strategy = args.strategy else: - strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto" + strategy = "ddp_find_unused_parameters_true" if args.num_gpus > 1 else "auto" trainer = pl.Trainer( devices=args.num_gpus, accelerator="gpu", - num_nodes = args.num_nodes, + num_nodes=args.num_nodes, strategy=strategy, precision=args.precision, - accumulate_grad_batches=args.accum_batches, + accumulate_grad_batches=args.accum_batches, callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback], logger=wandb_logger, log_every_n_steps=1, max_epochs=10000000, default_root_dir=args.save_dir, gradient_clip_val=args.gradient_clip_val, - reload_dataloaders_every_n_epochs = 0 + reload_dataloaders_every_n_epochs=0, ) trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/unwrap_model.py b/unwrap_model.py index 4afb7bb7..0d32df66 100644 --- a/unwrap_model.py +++ b/unwrap_model.py @@ -1,39 +1,44 @@ import argparse import json + import torch from torch.nn.parameter import Parameter + from stable_audio_tools.models import create_model_from_config -if __name__ == '__main__': +if __name__ == "__main__": args = argparse.ArgumentParser() - args.add_argument('--model-config', type=str, default=None) - args.add_argument('--ckpt-path', type=str, default=None) - args.add_argument('--name', type=str, default='exported_model') - args.add_argument('--use-safetensors', action='store_true') + args.add_argument("--model-config", type=str, default=None) + args.add_argument("--ckpt-path", type=str, default=None) + args.add_argument("--name", type=str, default="exported_model") + args.add_argument("--use-safetensors", action="store_true") args = args.parse_args() with open(args.model_config) as f: model_config = json.load(f) - + model = create_model_from_config(model_config) - - model_type = model_config.get('model_type', None) - assert model_type is not None, 'model_type must be specified in model config' + model_type = model_config.get("model_type", None) - training_config = model_config.get('training', None) + assert model_type is not None, "model_type must be specified in model config" - if model_type == 'autoencoder': + training_config = model_config.get("training", None) + + if model_type == "autoencoder": from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper - + ema_copy = None if training_config.get("use_ema", False): from stable_audio_tools.models.factory import 
create_model_from_config + ema_copy = create_model_from_config(model_config) - ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once - + ema_copy = create_model_from_config( + model_config + ) # I don't know why this needs to be called twice but it broke when I called it once + # Copy each weight to the ema copy for name, param in model.state_dict().items(): if isinstance(param, Parameter): @@ -44,58 +49,72 @@ use_ema = training_config.get("use_ema", False) training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint( - args.ckpt_path, - autoencoder=model, + args.ckpt_path, + autoencoder=model, strict=False, loss_config=training_config["loss_configs"], use_ema=training_config["use_ema"], - ema_copy=ema_copy if use_ema else None + ema_copy=ema_copy if use_ema else None, ) - elif model_type == 'diffusion_uncond': + elif model_type == "diffusion_uncond": from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper - training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) - elif model_type == 'diffusion_autoencoder': - from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper + training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint( + args.ckpt_path, model=model, strict=False + ) + + elif model_type == "diffusion_autoencoder": + from stable_audio_tools.training.diffusion import ( + DiffusionAutoencoderTrainingWrapper, + ) ema_copy = create_model_from_config(model_config) - + for name, param in model.state_dict().items(): if isinstance(param, Parameter): # backwards compatibility for serialized parameters param = param.data ema_copy.state_dict()[name].copy_(param) - training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False) - elif model_type == 'diffusion_cond': + training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint( + args.ckpt_path, model=model, ema_copy=ema_copy, strict=False + ) + elif model_type == "diffusion_cond": from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper - + use_ema = training_config.get("use_ema", True) - + training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint( - args.ckpt_path, - model=model, - use_ema=use_ema, + args.ckpt_path, + model=model, + use_ema=use_ema, lr=training_config.get("learning_rate", None), optimizer_configs=training_config.get("optimizer_configs", None), - strict=False + strict=False, + ) + elif model_type == "diffusion_cond_inpaint": + from stable_audio_tools.training.diffusion import ( + DiffusionCondInpaintTrainingWrapper, + ) + + training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint( + args.ckpt_path, model=model, strict=False ) - elif model_type == 'diffusion_cond_inpaint': - from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper - training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False) - elif model_type == 'diffusion_prior': + elif model_type == "diffusion_prior": from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper ema_copy = create_model_from_config(model_config) - + for name, param in model.state_dict().items(): if isinstance(param, Parameter): # backwards compatibility for serialized parameters param = param.data ema_copy.state_dict()[name].copy_(param) - 
training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy) - elif model_type == 'lm': + training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint( + args.ckpt_path, model=model, strict=False, ema_copy=ema_copy + ) + elif model_type == "lm": from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper ema_copy = None @@ -111,16 +130,16 @@ ema_copy.state_dict()[name].copy_(param) training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint( - args.ckpt_path, - model=model, - strict=False, + args.ckpt_path, + model=model, + strict=False, ema_copy=ema_copy, - optimizer_configs=training_config.get("optimizer_configs", None) + optimizer_configs=training_config.get("optimizer_configs", None), ) else: raise ValueError(f"Unknown model type {model_type}") - + print(f"Loaded model from {args.ckpt_path}") if args.use_safetensors: @@ -130,4 +149,4 @@ training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors) - print(f"Exported model to {ckpt_path}") \ No newline at end of file + print(f"Exported model to {ckpt_path}")