Commit 3b30ee0

Updates.
lucasnewman committed Oct 6, 2024
1 parent 19aa238 commit 3b30ee0
Showing 4 changed files with 129 additions and 1 deletion.
.github/workflows/python-publish.yml (36 additions, 0 deletions)
@@ -0,0 +1,36 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
README.md (10 additions, 1 deletion)
@@ -4,6 +4,12 @@
Implementation of E2-TTS, [Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS](https://arxiv.org/abs/2406.18009), in MLX.

This implementation is based on the [lucidrains implementation](https://github.com/lucidrains/e2-tts-pytorch) in PyTorch, which differs from the paper in that it uses a [multistream transformer](https://arxiv.org/abs/2107.10342) for text and audio, with conditioning done at every transformer block.
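
For intuition, here is a minimal MLX sketch of that idea: two parallel streams, with the text stream projected into the audio stream at every block. The class, parameter names, and dimensions below are illustrative assumptions, not the actual internals of `e2_tts_mlx`:

```python
import mlx.core as mx
import mlx.nn as nn

class MultistreamBlock(nn.Module):
    """One hypothetical multistream block: text and audio each run their own
    self-attention, and the text stream conditions the audio stream."""

    def __init__(self, dim: int, text_dim: int, heads: int):
        super().__init__()
        self.audio_norm = nn.LayerNorm(dim)
        self.text_norm = nn.LayerNorm(text_dim)
        self.audio_attn = nn.MultiHeadAttention(dim, heads)
        self.text_attn = nn.MultiHeadAttention(text_dim, heads)
        self.to_cond = nn.Linear(text_dim, dim)  # inject text features into the audio stream

    def __call__(self, audio: mx.array, text: mx.array):
        t = self.text_norm(text)
        text = text + self.text_attn(t, t, t)     # text stream self-attention
        audio = audio + self.to_cond(text)        # per-block conditioning on text
        a = self.audio_norm(audio)
        audio = audio + self.audio_attn(a, a, a)  # audio stream self-attention
        return audio, text

# toy shapes: batch 1, sequence length 256
audio = mx.random.normal((1, 256, 384))
text = mx.random.normal((1, 256, 192))
audio, text = MultistreamBlock(dim=384, text_dim=192, heads=8)(audio, text)
```

Equal sequence lengths are assumed here because E2-TTS pads the character sequence with filler tokens to match the audio length.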

## Installation

```bash
pip install mlx-e2-tts
```

## Usage

```python
@@ -12,6 +18,7 @@
import mlx.core as mx

from e2_tts_mlx.model import E2TTS
from e2_tts_mlx.trainer import E2Trainer
from e2_tts_mlx.data import load_libritts_r

e2tts = E2TTS(
    tokenizer="char-utf8", # or "phoneme_en" for phoneme-based tokenization
@@ -33,13 +40,15 @@
mx.eval(e2tts.parameters())
batch_size = 128
max_duration = 30

dataset = load_libritts_r(split="dev-clean", max_duration = max_duration) # or any other audio/caption dataset

trainer = E2Trainer(model = e2tts, num_warmup_steps = 1000)
trainer.train(train_dataset = dataset, learning_rate = 7.5e-5, batch_size = batch_size)

```

Note that the model size specified above (from the paper) is very large. See `train_example.py` for a more practically sized model you can train on your local device.

## Appreciation

[lucidrains](https://github.com/lucidrains) for the original implementation in PyTorch.
pyproject.toml (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "mlx-e2-tts"
version = "0.0.1"
authors = [{name = "Lucas Newman", email = "[email protected]"}]
license = {text = "MIT"}
description = "E2-TTS - MLX"
readme = "README.md"
keywords = [
    "artificial intelligence",
    "asr",
    "audio-generation",
    "deep learning",
    "transformers",
    "text-to-speech",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.9",
]
requires-python = ">=3.9"
dependencies = [
    "einops",
    "g2p-en",
    "huggingface_hub",
    "matplotlib",
    "mlx",
    "numpy",
    "pyyaml",
]

[project.urls]
Homepage = "https://github.com/lucasnewman/e2-tts-mlx"

[tool.setuptools]
packages = ["e2_tts_mlx"]
train_example.py (40 additions, 0 deletions)
@@ -0,0 +1,40 @@
# This sample will download the "dev-clean" split of the LibriTTS dataset
# and train the model for 100k steps with 1k steps of warmup.

import mlx.core as mx
from mlx.utils import tree_flatten, tree_map

from e2_tts_mlx.model import E2TTS
from e2_tts_mlx.trainer import E2Trainer
from e2_tts_mlx.data import load_libritts_r

e2tts = E2TTS(
    tokenizer="phoneme_en",
    cond_drop_prob=0.0,
    frac_lengths_mask=(0.7, 0.9),
    transformer=dict(
        dim=384,
        depth=12,
        heads=8,
        text_depth=4,
        text_heads=8,
        text_ff_mult=2,
        max_seq_len=1024,
        dropout=0.1,
    ),
)

# cast parameters to float16 (a variant that keeps norms in float32 is sketched after this file)
e2tts.update(tree_map(lambda p: p.astype(mx.float16), e2tts.parameters()))

mx.eval(e2tts.parameters())

num_trainable_params = sum(p.size for _, p in tree_flatten(e2tts.trainable_parameters()))
print(f"Using {num_trainable_params:,} trainable parameters.")

batch_size = 4 # adjust based on available memory
max_duration = 10
dataset = load_libritts_r(split="dev-clean", max_duration=max_duration)

trainer = E2Trainer(model=e2tts, num_warmup_steps=1000, max_grad_norm=1)
trainer.train(train_dataset=dataset, learning_rate=1e-4, log_every=10, plot_every=100, total_steps=100_000, batch_size=batch_size)
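
If the blanket float16 cast above ever proves numerically unstable, a common variant keeps normalization parameters in float32. A minimal sketch as a drop-in replacement for the cast in this script, assuming the normalization parameters carry "norm" in their dotted names:

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

# Cast everything to float16 except parameters whose dotted name suggests a
# normalization layer; those stay in float32 for numerical stability.
# The `"norm" in name` filter is an assumption about this model's naming.
casted = [
    (name, p if "norm" in name else p.astype(mx.float16))
    for name, p in tree_flatten(e2tts.parameters())
]
e2tts.update(tree_unflatten(casted))
```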
