Add pre-commit config for black and isort #124

Open · wants to merge 4 commits into base: main
16 changes: 16 additions & 0 deletions .github/workflows/pre-commit.yaml
@@ -0,0 +1,16 @@
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          cache: pip
      - uses: pre-commit/[email protected]
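This workflow runs the repository's pre-commit hooks on every pull request and on pushes to `main`; the `pre-commit/action` step effectively runs `pre-commit run --all-files`. As a minimal sketch (assuming `pip` and a working Python environment), a contributor can reproduce the same check locally before pushing:

```bash
# Install the pre-commit tool into the current environment
pip install pre-commit

# Run every configured hook against the whole repository,
# which is roughly what the CI job above does
pre-commit run --all-files
```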
2 changes: 1 addition & 1 deletion .gitignore
@@ -161,4 +161,4 @@ cython_debug/

*.ckpt
*.wav
wandb/*
wandb/*
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,20 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black

  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
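With this config merged, contributors would typically register the hooks so they run on each commit. A short sketch using standard pre-commit commands and the hook ids defined above:

```bash
# Register the hooks with git so they run on every `git commit`
pre-commit install

# Run individual hooks on demand, e.g. just the formatters from this config
pre-commit run black --all-files
pre-commit run isort --all-files

# Bump the pinned `rev` values to the latest tagged releases when desired
pre-commit autoupdate
```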
2 changes: 1 addition & 1 deletion LICENSES/LICENSE_ADP.txt
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
2 changes: 1 addition & 1 deletion LICENSES/LICENSE_AURALOSS.txt
@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
2 changes: 1 addition & 1 deletion LICENSES/LICENSE_DESCRIPT.txt
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
2 changes: 1 addition & 1 deletion LICENSES/LICENSE_META.txt
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
4 changes: 2 additions & 2 deletions LICENSES/LICENSE_NVIDIA.txt
@@ -10,12 +10,12 @@ copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
2 changes: 1 addition & 1 deletion LICENSES/LICENSE_XTRANSFORMERS.txt
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
SOFTWARE.
10 changes: 5 additions & 5 deletions README.md
@@ -20,7 +20,7 @@ Development for the repo is done in Python 3.8.10

# Interface

A basic Gradio interface is provided to test out trained models.
A basic Gradio interface is provided to test out trained models.

For example, to create an interface for the [`stable-audio-open-1.0`](https://huggingface.co/stabilityai/stable-audio-open-1.0) model, once you've accepted the terms for the model on Hugging Face, you can run:
```bash
@@ -37,7 +37,7 @@ The `run_gradio.py` script accepts the following command line arguments:
- Path to the model config file for a local model
- `--ckpt-path`
- Path to unwrapped model checkpoint file for a local model
- `--pretransform-ckpt-path`
- `--pretransform-ckpt-path`
- Path to an unwrapped pretransform checkpoint, replaces the pretransform in the model, useful for testing out fine-tuned decoders
- Optional
- `--share`
@@ -69,7 +69,7 @@ $ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /pa
The `--name` parameter will set the project name for your Weights and Biases run.

## Training wrappers and model unwrapping
`stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training.
`stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training.

When a model is being trained, it is wrapped in a "training wrapper", which is a `pl.LightningModule` that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states.

@@ -88,7 +88,7 @@ Unwrapped model checkpoints are required for:
- Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization)

## Fine-tuning
Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.
Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.

To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to `train.py` with the `--ckpt-path` flag.

@@ -154,4 +154,4 @@ The following properties are defined in the top level of the model configuration

# Todo
- [ ] Add troubleshooting section
- [ ] Add contribution guidelines
- [ ] Add contribution guidelines
12 changes: 6 additions & 6 deletions defaults.ini
@@ -5,13 +5,13 @@
name = stable_audio_tools

# the batch size
batch_size = 8
batch_size = 8

# number of GPUs to use for training
num_gpus = 1
num_gpus = 1

# number of nodes to use for training
num_nodes = 1
num_nodes = 1

# Multi-GPU strategy for PyTorch Lightning
strategy = ""
@@ -29,8 +29,8 @@ seed = 42
accum_batches = 1

# Number of steps between checkpoints
checkpoint_every = 10000
checkpoint_every = 10000

# trainer checkpoint file to restart training from
ckpt_path = ''

@@ -53,4 +53,4 @@ save_dir = ''
gradient_clip_val = 0.0

# remove the weight norm from the pretransform model
remove_pretransform_weight_norm = ''
remove_pretransform_weight_norm = ''
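The values in `defaults.ini` act as training defaults. A hedged sketch of overriding a few of them on the command line, assuming each ini key maps to an equivalently named flag (only `--dataset-config`, `--model-config`, and `--name` are confirmed by the README in this diff; the other flag names are assumptions):

```bash
# Hypothetical run overriding a few defaults.ini values; --batch-size and
# --checkpoint-every are assumed to mirror the ini keys and may differ.
python3 ./train.py \
  --dataset-config /path/to/dataset/config \
  --model-config /path/to/model/config \
  --name my_run \
  --batch-size 16 \
  --checkpoint-every 5000
```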
18 changes: 9 additions & 9 deletions docs/autoencoders.md
@@ -1,5 +1,5 @@
# Autoencoders
At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*.
At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*.

The *encoder* takes in a sequence (such as mono or stereo audio) and outputs a compressed representation of that sequence as a d-channel "latent sequence", usually heavily downsampled by a constant factor.

@@ -41,7 +41,7 @@ The `training` config in the autoencoder model config file should have the follo
- `learning_rate`
- The learning rate to use during training
- `use_ema`
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
- Optional. Default: `false`
- `warmup_steps`
- The number of training steps before turning on adversarial losses
@@ -62,11 +62,11 @@ There are a few different types of losses that are used for autoencoder training,
Hyperparameters for these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.

### Spectral losses
Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.
Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.

For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module.
For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module.

For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module.
For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module.

#### Example config
```json
@@ -130,7 +130,7 @@ The only property to set for autoencoder training demos is the `demo_every` prop
```

# Encoder and decoder types
Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably.
Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably.

## Oobleck
Oobleck is Harmonai's in-house autoencoder architecture, implementing features from a variety of other autoencoder architectures.
@@ -229,7 +229,7 @@ In our terminology, the "bottleneck" of an autoencoder is a module placed betwee

Bottlenecks have a similar interface to the autoencoder with `encode()` and `decode()` functions defined. Some bottlenecks return extra information in addition to the output latent series, such as quantized token indices, or additional losses to be considered during training.

To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with the following
To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with the following

## VAE

@@ -311,7 +311,7 @@ Residual vector quantization (RVQ) is currently the leading method for learning

This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidrains/vector-quantize-pytorch/tree/master) from the `vector-quantize-pytorch` repo, which provides a lot of different quantizer options. The bottleneck config is passed through to the `ResidualVQ` [constructor](https://github.com/lucidrains/vector-quantize-pytorch/blob/0c6cea24ce68510b607f2c9997e766d9d55c085b/vector_quantize_pytorch/residual_vq.py#L26).

**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.**
**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.**

### Example config
```json
@@ -327,7 +327,7 @@ This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidra
```

## DAC RVQ
This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training.
This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training.

The bottleneck config is passed directly into the `ResidualVectorQuantize` [constructor](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L97).

8 changes: 4 additions & 4 deletions docs/conditioning.md
@@ -27,15 +27,15 @@ Input concatenation applies a spatial conditioning signal to the model that corr
Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input.

# Conditioners and conditioning configs
`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input.
`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input.

Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask.

The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. `{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`).

To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type.

The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals.
The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals.

Each item in `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner.

@@ -74,7 +74,7 @@ If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved wit

T5 encodings are only compatible with cross attention conditioning.

#### Example config
#### Example config
```json
{
"id": "prompt",
@@ -155,4 +155,4 @@ Number embeddings are compatible with global conditioning and cross attention co
"max_val": 512
}
}
```
```
8 changes: 4 additions & 4 deletions docs/datasets.md
@@ -11,7 +11,7 @@ To use a local directory of audio samples, set the `dataset_type` property in yo

This will load all of the compatible audio files from the provided directory and all subdirectories.

### Example config
### Example config
```json
{
"dataset_type": "audio_dir",
@@ -43,9 +43,9 @@ To load audio files and related metadata from .tar files in the WebDataset forma
```

# Custom metadata
To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary.
To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary.

For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files.
For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files.

The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals.

@@ -72,4 +72,4 @@ def get_custom_metadata(info, audio):

# Pass in the relative path of the audio file as the prompt
return {"prompt": info["relpath"]}
```
```
8 changes: 4 additions & 4 deletions docs/diffusion.md
@@ -48,7 +48,7 @@ The `training` config in the diffusion model config file should have the followi
- The learning rate to use during training
- Defaults to constant learning rate, can be overridden with `optimizer_configs`
- `use_ema`
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
- Optional. Default: `true`
- `log_loss_info`
- If true, additional diffusion loss info will be gathered across all GPUs and displayed during training
@@ -138,16 +138,16 @@ This is our custom implementation of a transformer model, based on the `x-transf
This model type uses the `ContinuousTransformerWrapper` class from the https://github.com/lucidrains/x-transformers repository as the diffusion transformer backbone.

`x-transformers` is a great baseline transformer implementation with lots of options for various experimental settings.
It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning.
It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning.

## Diffusion U-Net

U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution.
U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution.

### audio-diffusion-pytorch U-Net (ADP)

This model type uses a modified implementation of the `UNetCFG1D` class from version 0.0.94 of the `https://github.com/archinetai/audio-diffusion-pytorch` repo, with added Flash Attention support.

### Dance Diffusion U-Net

This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models.
This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models.