From 812b91f734485d4ce9b0e4875fa3be5d64dad568 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 30 Sep 2024 11:57:43 -0700 Subject: [PATCH] Unify pretrained model loading and avoid having to build MLX from source. --- README.md | 4 +- pyproject.toml | 1 + vocos_mlx/encodec.py | 2 +- vocos_mlx/model.py | 166 ++++++++++++++++++++++++++++++------------- 4 files changed, 120 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 3e799a0..862cdec 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ audio = load_audio("audio.wav", 24_000) reconstructed_audio = vocos(audio, bandwidth_id = 3) # decode with encodec codes -codes = ... -decoded_audio = vocos.decode_from_codes(codes) +codes = vocos.feature_extractor.get_encodec_codes(audio, bandwidth_id = 3) +decoded_audio = vocos.decode_from_codes(codes, bandwidth_id = 3) ``` ## Citations diff --git a/pyproject.toml b/pyproject.toml index 6736181..20ca3f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ requires = [ "huggingface_hub", "mlx>=0.18.0", "numpy", + "pyyaml", "setuptools", ] build-backend = "setuptools.build_meta" diff --git a/vocos_mlx/encodec.py b/vocos_mlx/encodec.py index bdeb364..2b661ef 100644 --- a/vocos_mlx/encodec.py +++ b/vocos_mlx/encodec.py @@ -678,7 +678,7 @@ def chunk_stride(self): return max(1, int((1.0 - self.config.overlap) * self.chunk_length)) @classmethod - def from_pretrained(cls, path_or_repo): + def from_pretrained(cls, path_or_repo: str): """ Load the model and audo preprocessor. 
""" diff --git a/vocos_mlx/model.py b/vocos_mlx/model.py index 0544ab2..de44c1d 100644 --- a/vocos_mlx/model.py +++ b/vocos_mlx/model.py @@ -1,37 +1,21 @@ from __future__ import annotations - -import os from functools import lru_cache - -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np +import math +import os +from pathlib import Path +from typing import Any, List, Optional, Union +from types import SimpleNamespace import mlx.core as mx import mlx.nn as nn +import numpy as np +from huggingface_hub import snapshot_download import yaml -from huggingface_hub import hf_hub_download - from vocos_mlx.encodec import EncodecModel -def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: - kwargs = init.get("init_args", {}) - if not isinstance(args, tuple): - args = (args,) - - if "." not in init["class_path"]: - class_name = init["class_path"] - args_class = globals()[class_name] - else: - class_module, class_name = init["class_path"].rsplit(".", 1) - module = __import__(class_module, fromlist=[class_name]) - args_class = getattr(module, class_name) - return args_class(*args, **kwargs) - - @lru_cache(maxsize=None) def mel_filters(n_mels: int) -> mx.array: """ @@ -204,13 +188,16 @@ def __init__( def get_encodec_codes(self, audio: mx.array, bandwidth_id: int) -> mx.array: features, mask = self.preprocessor(audio) - codes, _ = self.encodec.encode(features, mask, bandwidth=self.bandwidths[bandwidth_id]) - return codes + codes, _ = self.encodec.encode( + features, mask, bandwidth=self.bandwidths[bandwidth_id] + ) + return mx.reshape(codes, (codes.shape[-2], 1, codes.shape[-1])) def get_features_from_codes(self, codes: mx.array) -> mx.array: - codes = mx.reshape(codes, (codes.shape[-2], 1, codes.shape[-1])) offsets = mx.arange( - 0, self.encodec.quantizer.codebook_size * codes.shape[0], self.encodec.quantizer.codebook_size + 0, + self.encodec.quantizer.codebook_size * codes.shape[0], + 
self.encodec.quantizer.codebook_size, ) embeddings_idxs = codes + mx.reshape(offsets, (offsets.shape[0], 1, 1)) embeddings = self.codebook_weights[embeddings_idxs] @@ -262,7 +249,7 @@ def __init__( super().__init__() # depthwise conv - self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) + self.dwconv = GroupableConv1d(dim, dim, kernel_size=7, padding=3, groups=dim) self.adanorm = adanorm_num_embeddings is not None if adanorm_num_embeddings: self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) @@ -312,6 +299,81 @@ def __call__(self, x: mx.array, cond_embedding_id: mx.array) -> mx.array: return x +class GroupableConv1d(nn.Module): + """Applies a 1-dimensional convolution over the multi-channel input sequence. + + The channels are expected to be last i.e. the input shape should be ``NLC`` where: + + * ``N`` is the batch dimension + * ``L`` is the sequence length + * ``C`` is the number of input channels + + Args: + in_channels (int): The number of input channels + out_channels (int): The number of output channels + kernel_size (int): The size of the convolution filters + stride (int, optional): The stride when applying the filter. + Default: ``1``. + padding (int, optional): How many positions to 0-pad the input with. + Default: ``0``. + dilation (int, optional): The dilation of the convolution. Default: ``1``. + groups (int, optional): The number of groups for the convolution. + Default: ``1``. + bias (bool, optional): If ``True`` add a learnable bias to the output. 
+ Default: ``True`` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super().__init__() + + if in_channels % groups != 0: + raise ValueError( + f"The number of input channels ({in_channels}) must be " + f"divisible by the number of groups ({groups})" + ) + + scale = math.sqrt(1 / (in_channels * kernel_size)) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, kernel_size, in_channels // groups), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + self.padding = padding + self.dilation = dilation + self.stride = stride + self.groups = groups + + def _extra_repr(self): + return ( + f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " + f"padding={self.padding}, dilation={self.dilation}, " + f"groups={self.groups}, " + f"bias={'bias' in self}" + ) + + def __call__(self, x): + y = mx.conv1d( + x, self.weight, self.stride, self.padding, self.dilation, self.groups + ) + if "bias" in self: + y = y + self.bias + return y + + class VocosBackbone(nn.Module): def __init__( self, @@ -376,36 +438,40 @@ def from_hparams(cls, config_path: str) -> Vocos: Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. 
""" with open(config_path, "r") as f: - config = yaml.safe_load(f) - - # remap the class paths - if "MelSpectrogramFeatures" in config["feature_extractor"]["class_path"]: - config["feature_extractor"]["class_path"] = "MelSpectrogramFeatures" - elif "EncodecFeatures" in config["feature_extractor"]["class_path"]: - config["feature_extractor"]["class_path"] = "EncodecFeatures" - config["backbone"]["class_path"] = "VocosBackbone" - config["head"]["class_path"] = "ISTFTHead" - - feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) - backbone = instantiate_class(args=(), init=config["backbone"]) - head = instantiate_class(args=(), init=config["head"]) + config = SimpleNamespace(**yaml.safe_load(f)) + + if "MelSpectrogramFeatures" in config.feature_extractor["class_path"]: + feature_extractor = MelSpectrogramFeatures( + **config.feature_extractor["init_args"] + ) + elif "EncodecFeatures" in config.feature_extractor["class_path"]: + feature_extractor = EncodecFeatures(**config.feature_extractor["init_args"]) + backbone = VocosBackbone(**config.backbone["init_args"]) + head = ISTFTHead(**config.head["init_args"]) model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) return model @classmethod - def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: + def from_pretrained(cls, path_or_repo: str) -> Vocos: """ Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. 
""" - config_path = hf_hub_download( - repo_id=repo_id, filename="config.yaml", revision=revision - ) + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.yaml", "*.safetensors"], + ) + ) + + config_path = path / "config.yaml" model = cls.from_hparams(config_path) - model_path = hf_hub_download( - repo_id=repo_id, filename="model.safetensors", revision=revision - ) - weights = mx.load(model_path) + model_path = path / "model.safetensors" + with open(model_path, "rb") as f: + weights = mx.load(f) # remove unused weights try: @@ -428,7 +494,7 @@ def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: new_weights[k] = v # use strict = False to avoid the encodec weights - model.load_weights(list(new_weights.items()), strict = False) + model.load_weights(list(new_weights.items()), strict=False) model.eval() return model