diff --git a/setup.py b/setup.py index 7e7470d3..e330482a 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ description='Training and inference tools for generative audio models from Stability AI', packages=find_packages(), install_requires=[ + 'dill', 'aeiou==0.0.20', 'alias-free-torch==0.0.6', 'auraloss==0.4.0', @@ -41,4 +42,4 @@ 'webdataset==0.2.48', 'x-transformers<1.27.0' ], -) \ No newline at end of file +) diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py index 4bc535a1..bc7c0c88 100644 --- a/stable_audio_tools/data/dataset.py +++ b/stable_audio_tools/data/dataset.py @@ -1,4 +1,5 @@ import importlib +import dill import numpy as np import io import os @@ -155,7 +156,7 @@ def __init__( self.root_paths.append(config.path) self.filenames.extend(get_audio_filenames(config.path, keywords)) if config.custom_metadata_fn is not None: - self.custom_metadata_fns[config.path] = config.custom_metadata_fn + self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn) print(f'Found {len(self.filenames)} files') @@ -216,8 +217,8 @@ def __getitem__(self, idx): for custom_md_path in self.custom_metadata_fns.keys(): if custom_md_path in audio_filename: - custom_metadata_fn = self.custom_metadata_fns[custom_md_path] - custom_metadata = custom_metadata_fn(info, audio) + custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path]) + custom_metadata = custom_metadata_fn_deserialized(info, audio) info.update(custom_metadata) if "__reject__" in info and info["__reject__"]: @@ -651,4 +652,4 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl persistent_workers=True, force_channels=force_channels, epoch_steps=dataset_config.get("epoch_steps", 2000) - ).data_loader \ No newline at end of file + ).data_loader