diff --git a/examples/audio/process_audio.py b/examples/audio/process_audio.py index 8657d53ef957..d34461937284 100644 --- a/examples/audio/process_audio.py +++ b/examples/audio/process_audio.py @@ -176,6 +176,7 @@ def main(cfg: ProcessConfig) -> ProcessConfig: raise RuntimeError('Model does not have a sampler') if cfg.audio_dir is not None: + input_dir = cfg.audio_dir filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True)) else: # get filenames from manifest @@ -193,6 +194,15 @@ def main(cfg: ProcessConfig) -> ProcessConfig: audio_file = manifest_dir / audio_file filepaths.append(str(audio_file.absolute())) + # common path for all files + common_path = os.path.commonpath(filepaths) + if Path(common_path).is_relative_to(manifest_dir): + # if all paths are relative to the manifest, use manifest dir as input dir + input_dir = manifest_dir + else: + # use the parent of the common path as input dir + input_dir = Path(common_path).parent + if cfg.max_utts is not None: # Limit the number of utterances to process filepaths = filepaths[: cfg.max_utts] @@ -238,6 +248,7 @@ def autocast(): batch_size=cfg.batch_size, num_workers=cfg.num_workers, input_channel_selector=cfg.input_channel_selector, + input_dir=input_dir, ) logging.info(f"Finished processing {len(filepaths)} files!") diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 60c16f756f58..57b4c9d48119 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -332,6 +332,7 @@ def process( batch_size: int = 1, num_workers: Optional[int] = None, input_channel_selector: Optional[ChannelSelectorType] = None, + input_dir: Optional[str] = None, ) -> List[str]: """ Takes paths to audio files and returns a list of paths to processed @@ -344,6 +345,7 @@ def process( num_workers: Number of workers for the dataloader input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + input_dir: Optional, directory that contains the input files. If provided, the output directory will mirror the input directory structure. Returns: Paths to processed audio signals. @@ -413,9 +415,17 @@ def process( for example_idx in range(processed_batch.size(0)): # This assumes the data loader is not shuffling files - file_name = os.path.basename(paths2audio_files[file_idx]) + if input_dir is not None: + # Make sure the output has the same directory structure as the input + filepath_relative = os.path.relpath(paths2audio_files[file_idx], start=input_dir) + else: + # Input dir is not provided, save files in the output directory + filepath_relative = os.path.basename(paths2audio_files[file_idx]) # Prepare output file - output_file = os.path.join(output_dir, f'processed_{file_name}') + output_file = os.path.join(output_dir, filepath_relative) + # Create output dir if necessary + if not os.path.isdir(os.path.dirname(output_file)): + os.makedirs(os.path.dirname(output_file)) # Crop the output signal to the actual length output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() # Write audio