Skip to content

Commit

Permalink
[audio] Keep input directory structure when saving processed files (N…
Browse files Browse the repository at this point in the history
…VIDIA#11403)

Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju authored Dec 2, 2024
1 parent 8c921dc commit f17c418
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
11 changes: 11 additions & 0 deletions examples/audio/process_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
raise RuntimeError('Model does not have a sampler')

if cfg.audio_dir is not None:
input_dir = cfg.audio_dir
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
else:
# get filenames from manifest
Expand All @@ -193,6 +194,15 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
audio_file = manifest_dir / audio_file
filepaths.append(str(audio_file.absolute()))

# common path for all files
common_path = os.path.commonpath(filepaths)
if Path(common_path).is_relative_to(manifest_dir):
# if all paths are relative to the manifest, use manifest dir as input dir
input_dir = manifest_dir
else:
# use the parent of the common path as input dir
input_dir = Path(common_path).parent

if cfg.max_utts is not None:
# Limit the number of utterances to process
filepaths = filepaths[: cfg.max_utts]
Expand Down Expand Up @@ -238,6 +248,7 @@ def autocast():
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
input_channel_selector=cfg.input_channel_selector,
input_dir=input_dir,
)

logging.info(f"Finished processing {len(filepaths)} files!")
Expand Down
14 changes: 12 additions & 2 deletions nemo/collections/audio/models/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def process(
batch_size: int = 1,
num_workers: Optional[int] = None,
input_channel_selector: Optional[ChannelSelectorType] = None,
input_dir: Optional[str] = None,
) -> List[str]:
"""
Takes paths to audio files and returns a list of paths to processed
Expand All @@ -344,6 +345,7 @@ def process(
num_workers: Number of workers for the dataloader
input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio.
If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
input_dir: Optional, directory that contains the input files. If provided, the output directory will mirror the input directory structure.
Returns:
Paths to processed audio signals.
Expand Down Expand Up @@ -413,9 +415,17 @@ def process(

for example_idx in range(processed_batch.size(0)):
# This assumes the data loader is not shuffling files
file_name = os.path.basename(paths2audio_files[file_idx])
if input_dir is not None:
# Make sure the output has the same directory structure as the input
filepath_relative = os.path.relpath(paths2audio_files[file_idx], start=input_dir)
else:
# Input dir is not provided, save files in the output directory
filepath_relative = os.path.basename(paths2audio_files[file_idx])
# Prepare output file
output_file = os.path.join(output_dir, f'processed_{file_name}')
output_file = os.path.join(output_dir, filepath_relative)
# Create output dir if necessary
if not os.path.isdir(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file))
# Crop the output signal to the actual length
output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy()
# Write audio
Expand Down

0 comments on commit f17c418

Please sign in to comment.