Unify pretrained model loading and avoid having to build MLX from source.
lucasnewman committed Sep 30, 2024
1 parent bc78eff commit 812b91f
Showing 4 changed files with 120 additions and 53 deletions.
README.md (2 additions, 2 deletions)
@@ -43,8 +43,8 @@ audio = load_audio("audio.wav", 24_000)
reconstructed_audio = vocos(audio, bandwidth_id = 3)

# decode with encodec codes
codes = ...
decoded_audio = vocos.decode_from_codes(codes)
codes = vocos.feature_extractor.get_encodec_codes(audio, bandwidth_id = 3)
decoded_audio = vocos.decode_from_codes(codes, bandwidth_id = 3)
```
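
For context, a minimal end-to-end sketch of the updated snippet, assuming a checkpoint whose feature extractor is Encodec-based and that `load_audio` is exported alongside `Vocos` as in the README above (the repo id is illustrative):

```python
from vocos_mlx import Vocos, load_audio

# Illustrative repo id; any checkpoint with an Encodec feature extractor behaves the same way.
vocos = Vocos.from_pretrained("lucasnewman/vocos-encodec-24khz")

audio = load_audio("audio.wav", 24_000)

# encode to Encodec codes, then decode back to a waveform
codes = vocos.feature_extractor.get_encodec_codes(audio, bandwidth_id=3)
decoded_audio = vocos.decode_from_codes(codes, bandwidth_id=3)
```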

## Citations
pyproject.toml (1 addition, 0 deletions)
@@ -3,6 +3,7 @@ requires = [
"huggingface_hub",
"mlx>=0.18.0",
"numpy",
"pyyaml",
"setuptools",
]
build-backend = "setuptools.build_meta"
vocos_mlx/encodec.py (1 addition, 1 deletion)
@@ -678,7 +678,7 @@ def chunk_stride(self):
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

@classmethod
def from_pretrained(cls, path_or_repo):
def from_pretrained(cls, path_or_repo: str):
"""
Load the model and audio preprocessor.
"""
vocos_mlx/model.py (116 additions, 50 deletions)
@@ -1,37 +1,21 @@
from __future__ import annotations

import os
from functools import lru_cache

from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Union
from types import SimpleNamespace

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from huggingface_hub import snapshot_download
import yaml

from huggingface_hub import hf_hub_download

from vocos_mlx.encodec import EncodecModel


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
args = (args,)

if "." not in init["class_path"]:
class_name = init["class_path"]
args_class = globals()[class_name]
else:
class_module, class_name = init["class_path"].rsplit(".", 1)
module = __import__(class_module, fromlist=[class_name])
args_class = getattr(module, class_name)
return args_class(*args, **kwargs)


@lru_cache(maxsize=None)
def mel_filters(n_mels: int) -> mx.array:
"""
@@ -204,13 +188,16 @@ def __init__(

def get_encodec_codes(self, audio: mx.array, bandwidth_id: int) -> mx.array:
features, mask = self.preprocessor(audio)
codes, _ = self.encodec.encode(features, mask, bandwidth=self.bandwidths[bandwidth_id])
return codes
codes, _ = self.encodec.encode(
features, mask, bandwidth=self.bandwidths[bandwidth_id]
)
return mx.reshape(codes, (codes.shape[-2], 1, codes.shape[-1]))

def get_features_from_codes(self, codes: mx.array) -> mx.array:
codes = mx.reshape(codes, (codes.shape[-2], 1, codes.shape[-1]))
offsets = mx.arange(
0, self.encodec.quantizer.codebook_size * codes.shape[0], self.encodec.quantizer.codebook_size
0,
self.encodec.quantizer.codebook_size * codes.shape[0],
self.encodec.quantizer.codebook_size,
)
embeddings_idxs = codes + mx.reshape(offsets, (offsets.shape[0], 1, 1))
embeddings = self.codebook_weights[embeddings_idxs]
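
To make the index arithmetic above concrete, here is a small sketch with assumed toy shapes (n_q codebooks, T frames, codebook size 1024); it shows how the reshaped codes plus per-codebook offsets address a single flattened codebook table:

```python
import mlx.core as mx

n_q, T, codebook_size = 2, 5, 1024                 # assumed toy dimensions
codes = mx.zeros((1, n_q, T), dtype=mx.int32)      # placeholder codes

# reshape to (n_q, 1, T) so each codebook occupies its own leading row
codes = mx.reshape(codes, (codes.shape[-2], 1, codes.shape[-1]))

# shift codebook i by i * codebook_size so all codebooks index one concatenated table
offsets = mx.arange(0, codebook_size * codes.shape[0], codebook_size)
embeddings_idxs = codes + mx.reshape(offsets, (offsets.shape[0], 1, 1))
print(embeddings_idxs.shape)  # (2, 1, 5); row i holds indices in [i * 1024, (i + 1) * 1024)
```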
@@ -262,7 +249,7 @@ def __init__(
super().__init__()

# depthwise conv
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
self.dwconv = GroupableConv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
@@ -312,6 +299,81 @@ def __call__(self, x: mx.array, cond_embedding_id: mx.array) -> mx.array:
return x


class GroupableConv1d(nn.Module):
"""Applies a 1-dimensional convolution over the multi-channel input sequence.
The channels are expected to be last i.e. the input shape should be ``NLC`` where:
* ``N`` is the batch dimension
* ``L`` is the sequence length
* ``C`` is the number of input channels
Args:
in_channels (int): The number of input channels
out_channels (int): The number of output channels
kernel_size (int): The size of the convolution filters
stride (int, optional): The stride when applying the filter.
Default: ``1``.
padding (int, optional): How many positions to 0-pad the input with.
Default: ``0``.
dilation (int, optional): The dilation of the convolution.
groups (int, optional): The number of groups for the convolution.
Default: ``1``.
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()

if in_channels % groups != 0:
raise ValueError(
f"The number of input channels ({in_channels}) must be "
f"divisible by the number of groups ({groups})"
)

scale = math.sqrt(1 / (in_channels * kernel_size))
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size, in_channels // groups),
)
if bias:
self.bias = mx.zeros((out_channels,))

self.padding = padding
self.dilation = dilation
self.stride = stride
self.groups = groups

def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
f"padding={self.padding}, dilation={self.dilation}, "
f"groups={self.groups}, "
f"bias={'bias' in self}"
)

def __call__(self, x):
y = mx.conv1d(
x, self.weight, self.stride, self.padding, self.dilation, self.groups
)
if "bias" in self:
y = y + self.bias
return y
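
As a quick sanity check of the class above (dimensions are illustrative), a depthwise configuration with groups equal to the channel count on channels-last input; this is the case the dwconv layer relies on, and it uses the groups argument of mx.conv1d, presumably the reason for the mlx>=0.18 floor in pyproject.toml rather than a source build:

```python
import mlx.core as mx

dim = 8
conv = GroupableConv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise: one filter per channel
x = mx.random.normal((1, 100, dim))  # NLC: batch=1, length=100, channels=8
y = conv(x)
print(y.shape)  # (1, 100, 8): kernel_size=7 with padding=3 preserves the sequence length
```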


class VocosBackbone(nn.Module):
def __init__(
self,
@@ -376,36 +438,40 @@ def from_hparams(cls, config_path: str) -> Vocos:
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)

# remap the class paths
if "MelSpectrogramFeatures" in config["feature_extractor"]["class_path"]:
config["feature_extractor"]["class_path"] = "MelSpectrogramFeatures"
elif "EncodecFeatures" in config["feature_extractor"]["class_path"]:
config["feature_extractor"]["class_path"] = "EncodecFeatures"
config["backbone"]["class_path"] = "VocosBackbone"
config["head"]["class_path"] = "ISTFTHead"

feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
backbone = instantiate_class(args=(), init=config["backbone"])
head = instantiate_class(args=(), init=config["head"])
config = SimpleNamespace(**yaml.safe_load(f))

if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]:
feature_extractor = MelSpectrogramFeatures(
**config.feature_extractor["init_args"]
)
elif "EncodecFeatures" in config.feature_extractor["class_path"]:
feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"])
backbone = VocosBackbone(**config.backbone["init_args"])
head = ISTFTHead(**config.head["init_args"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model
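
A minimal sketch of calling from_hparams directly; the config layout mirrors the upstream Vocos mel-spectrogram configuration, and every field value below is illustrative rather than taken from this repository:

```python
import yaml

from vocos_mlx import Vocos

config = {
    "feature_extractor": {
        "class_path": "vocos.feature_extractors.MelSpectrogramFeatures",
        "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100},
    },
    "backbone": {
        "class_path": "vocos.models.VocosBackbone",
        "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8},
    },
    "head": {
        "class_path": "vocos.heads.ISTFTHead",
        "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256},
    },
}

# write the config to disk, then let from_hparams remap the class paths and build the model
with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)

model = Vocos.from_hparams("config.yaml")
```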

@classmethod
def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
def from_pretrained(cls, path_or_repo: str) -> Vocos:
"""
Class method to create a new Vocos model instance from a pre-trained model stored locally or in the Hugging Face model hub.
"""
config_path = hf_hub_download(
repo_id=repo_id, filename="config.yaml", revision=revision
)

path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.yaml", "*.safetensors"],
)
)

config_path = path / "config.yaml"
model = cls.from_hparams(config_path)

model_path = hf_hub_download(
repo_id=repo_id, filename="model.safetensors", revision=revision
)
weights = mx.load(model_path)
model_path = path / "model.safetensors"
with open(model_path, "rb") as f:
weights = mx.load(f)

# remove unused weights
try:
@@ -428,7 +494,7 @@ def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
new_weights[k] = v

# use strict = False to avoid the encodec weights
model.load_weights(list(new_weights.items()), strict = False)
model.load_weights(list(new_weights.items()), strict=False)
model.eval()

return model
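
Tying back to the commit message, both call forms below now go through the same loader; the repo id and the local path are illustrative:

```python
from vocos_mlx import Vocos

# From the Hugging Face Hub: config.yaml and model.safetensors are fetched via snapshot_download.
vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

# From a local directory that already contains config.yaml and model.safetensors.
vocos = Vocos.from_pretrained("/path/to/vocos-checkpoint")
```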
