Skip to content

Commit

Permalink
Merge pull request #656 from ShiromiyaG/new-prepocess-extract
Browse files Browse the repository at this point in the history
New prepocess and extract
  • Loading branch information
blaisewf authored Sep 1, 2024
2 parents 5ddbc85 + 7bcc357 commit 029665a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 206 deletions.
237 changes: 72 additions & 165 deletions rvc/train/extract/extract.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import os, glob
import sys
import time
import tqdm
import torch
# Zluda
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
import torchcrepe
import numpy as np
import soundfile as sf
from multiprocessing import Pool
from functools import partial
import concurrent.futures
import torch.nn.functional as F

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))
Expand All @@ -22,56 +24,6 @@
# Load config
config = Config()


def setup_paths(exp_dir: str, version: str = None):
"""Set up input and output paths."""
wav_path = os.path.join(exp_dir, "sliced_audios_16k")
if version:
out_path = os.path.join(
exp_dir, "v1_extracted" if version == "v1" else "v2_extracted"
)
os.makedirs(out_path, exist_ok=True)
return wav_path, out_path
else:
output_root1 = os.path.join(exp_dir, "f0")
output_root2 = os.path.join(exp_dir, "f0_voiced")
os.makedirs(output_root1, exist_ok=True)
os.makedirs(output_root2, exist_ok=True)
return wav_path, output_root1, output_root2


def read_wave(wav_path: str, normalize: bool = False):
"""Read a wave file and return its features."""
wav, sr = sf.read(wav_path)
assert sr == 16000, "Sample rate must be 16000"

feats = torch.from_numpy(wav).float()
if config.is_half:
feats = feats.half()
if feats.dim() == 2:
feats = feats.mean(-1)
feats = feats.view(1, -1)

if normalize:
feats = F.layer_norm(feats, feats.shape)
return feats


def get_device(gpu_index):
"""Get the appropriate device based on GPU availability."""
if gpu_index == "cpu":
return "cpu"
try:
index = int(gpu_index)
if index < torch.cuda.device_count():
return f"cuda:{index}"
else:
print("Invalid GPU index. Switching to CPU.")
except ValueError:
print("Invalid GPU index format. Switching to CPU.")
return "cpu"


class FeatureInput:
"""Class for F0 extraction."""

Expand Down Expand Up @@ -142,105 +94,69 @@ def coarse_f0(self, f0):

def process_file(self, file_info, f0_method, hop_length):
"""Process a single audio file for F0 extraction."""
inp_path, opt_path1, opt_path2, np_arr = file_info
inp_path, opt_path1, opt_path2, _ = file_info
#print(f"Process file {inp_path}. Class on {self.device}, model is on {self.model_rmvpe.device}")

if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"):
if os.path.exists(opt_path1) and os.path.exists(opt_path2):
return

try:
np_arr = load_audio(inp_path, 16000)
feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
np.save(opt_path2, feature_pit, allow_pickle=False)
coarse_pit = self.coarse_f0(feature_pit)
np.save(opt_path1, coarse_pit, allow_pickle=False)
except Exception as error:
print(f"An error occurred extracting file {inp_path}: {error}")
print(f"An error occurred extracting file {inp_path} on {self.device}: {error}")

def process_files(self, files, f0_method, hop_length, pbar):
"""Process multiple files."""
for file_info in files:
self.process_file(file_info, f0_method, hop_length)
pbar.update()


def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus):
input_root, *output_roots = setup_paths(exp_dir)

if len(output_roots) == 2:
output_root1, output_root2 = output_roots
else:
output_root1 = output_roots[0]
output_root2 = None

paths = [
(
os.path.join(input_root, name),
os.path.join(output_root1, name) if output_root1 else None,
os.path.join(output_root2, name) if output_root2 else None,
load_audio(os.path.join(input_root, name), 16000),
)
for name in sorted(os.listdir(input_root))
if "spec" not in name
]
pbar.update(1)

def run_pitch_extraction(files, devices, f0_method, hop_length, num_processes):
print(f"Starting pitch extraction with {num_processes} cores and {f0_method}...")
start_time = time.time()

if gpus != "-":
gpus = gpus.split("-")
num_gpus = len(gpus)
process_partials = []
pbar = tqdm.tqdm(total=len(paths), desc="Pitch Extraction")

for idx, gpu in enumerate(gpus):
device = get_device(gpu)
feature_input = FeatureInput(device=device)
part_paths = paths[idx::num_gpus]
process_partials.append((feature_input, part_paths))

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(
FeatureInput.process_files,
feature_input,
part_paths,
f0_method,
hop_length,
pbar,
)
for feature_input, part_paths in process_partials
]
for future in concurrent.futures.as_completed(futures):
future.result()
pbar.close()

else:
feature_input = FeatureInput(device="cpu")
with tqdm.tqdm(total=len(paths), desc="Pitch Extraction") as pbar:
with Pool(processes=num_processes) as pool:
process_file_partial = partial(
feature_input.process_file,
f0_method=f0_method,
hop_length=hop_length,
)
for _ in pool.imap_unordered(process_file_partial, paths):
pbar.update()
pbar = tqdm.tqdm(total=len(files), desc="Pitch Extraction")
num_gpus = len(devices)
process_partials = []
for idx, gpu in enumerate(devices):
device = torch.device(gpu)
feature_input = FeatureInput(device=device)
part_paths = files[idx::num_gpus]
process_partials.append((feature_input, part_paths))

with concurrent.futures.ThreadPoolExecutor(max_workers = num_processes) as executor:
futures = [
executor.submit(
FeatureInput.process_files,
feature_input,
part_paths,
f0_method,
hop_length,
pbar,
)
for feature_input, part_paths in process_partials
]
for future in concurrent.futures.as_completed(futures):
future.result()
pbar.close()

elapsed_time = time.time() - start_time
print(f"Pitch extraction completed in {elapsed_time:.2f} seconds.")


def process_file_embedding(file, wav_path, out_path, model, device, version):
def process_file_embedding(file_info, model, device):
"""Process a single audio file for embedding extraction."""
wav_file_path = os.path.join(wav_path, file)
out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
wav_file_path, _, _, out_file_path = file_info

if os.path.exists(out_file_path):
return

feats = read_wave(wav_file_path)
dtype = torch.float16 if config.is_half else torch.float32
feats = feats.to(dtype).to(device)
dtype = torch.float16 if config.is_half and "cuda" in device else torch.float32
model = model.to(dtype).to(device)
feats = torch.from_numpy(load_audio(wav_file_path, 16000)).to(dtype).to(device)
feats = feats.view(1, -1)

with torch.no_grad():
feats = model(feats)["last_hidden_state"]
Expand All @@ -252,53 +168,27 @@ def process_file_embedding(file, wav_path, out_path, model, device, version):
else:
print(f"{file} contains NaN values and will be skipped.")


def run_embedding_extraction(
exp_dir, version, gpus, embedder_model, embedder_model_custom
):
def run_embedding_extraction(files, devices, embedder_model, embedder_model_custom):
"""Main function to orchestrate the embedding extraction process."""
wav_path, out_path = setup_paths(exp_dir, version)

print("Starting embedding extraction...")
start_time = time.time()
model = load_embedding(embedder_model, embedder_model_custom)

models = load_embedding(embedder_model, embedder_model_custom)
pbar = tqdm.tqdm(total=len(files), desc="Embedding Extraction")

# Zluda
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
print("Disabling CUDNN for Zluda")
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)

devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]

paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
if not paths:
print("No audio files found. Make sure you have provided the audios correctly.")
sys.exit(1)

pbar = tqdm.tqdm(total=len(paths) * len(devices), desc="Embedding Extraction")

tasks = [
(file, wav_path, out_path, models, device, version)
for file in paths
for device in devices
]

for task in tasks:
# add multi-threading here?
for i, file_info in enumerate(files):
device = devices[i%len(devices)]
try:
process_file_embedding(*task)
process_file_embedding(file_info, model, device)
except Exception as error:
print(f"An error occurred processing {task[0]}: {error}")
print(f"An error occurred processing {file_info[0]}: {error}")
pbar.update(1)

pbar.close()
elapsed_time = time.time() - start_time
print(f"Embedding extraction completed in {elapsed_time:.2f} seconds.")


if __name__ == "__main__":

exp_dir = sys.argv[1]
Expand All @@ -312,13 +202,30 @@ def run_embedding_extraction(
embedder_model = sys.argv[9]
embedder_model_custom = sys.argv[10] if len(sys.argv) > 10 else None

# prep
wav_path = os.path.join(exp_dir, "sliced_audios_16k")
os.makedirs(os.path.join(exp_dir, "f0"), exist_ok=True)
os.makedirs(os.path.join(exp_dir, "f0_voiced"), exist_ok=True)
os.makedirs(os.path.join(exp_dir, version + "_extracted"), exist_ok=True)

files = []
for file in glob.glob(os.path.join(wav_path, "*.wav")):
file_name = os.path.basename(file)
file_info = [
file, # full path to sliced 16k wav
os.path.join(exp_dir, "f0", file_name + ".npy"),
os.path.join(exp_dir, "f0_voiced", file_name + ".npy"),
os.path.join(exp_dir, version + "_extracted", file_name.replace("wav", "npy"))
]
files.append(file_info)

devices = ["cpu"] if gpus == "-" else [f"cuda:{idx}" for idx in gpus.split("-")]

# Run Pitch Extraction
run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus)
run_pitch_extraction(files, devices, f0_method, hop_length, num_processes)

# Run Embedding Extraction
run_embedding_extraction(
exp_dir, version, gpus, embedder_model, embedder_model_custom
)
run_embedding_extraction(files, devices, embedder_model, embedder_model_custom)

# Run Preparing Files
generate_config(version, sample_rate, exp_dir)
Expand Down
Loading

0 comments on commit 029665a

Please sign in to comment.