Skip to content

Commit

Permalink
PerformInferenceをeasy-extでの実装にする
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 2, 2024
1 parent 44edb59 commit c941504
Showing 1 changed file with 20 additions and 69 deletions.
89 changes: 20 additions & 69 deletions crates/voicevox_core/src/synthesizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,67 +833,6 @@ mod inner {
}
}

// TODO: `mod blocking`に移動する
pub trait PerformInference {
/// `predict_duration`を実行する。
///
/// # Performance
///
/// CPU-boundな操作であるため、非同期ランタイム上では直接実行されるべきではない。
fn predict_duration(&self, phoneme_vector: &[i64], style_id: StyleId) -> Result<Vec<f32>>;

/// `predict_intonation`を実行する。
///
/// # Performance
///
/// CPU-boundな操作であるため、非同期ランタイム上では直接実行されるべきではない。
#[expect(
clippy::too_many_arguments,
reason = "compatible_engineでの`predict_intonation`の形を考えると、ここの引数を構造体に\
まとめたりしても可読性に寄与しない"
)]
fn predict_intonation(
&self,
length: usize,
vowel_phoneme_vector: &[i64],
consonant_phoneme_vector: &[i64],
start_accent_vector: &[i64],
end_accent_vector: &[i64],
start_accent_phrase_vector: &[i64],
end_accent_phrase_vector: &[i64],
style_id: StyleId,
) -> Result<Vec<f32>>;

fn generate_full_intermediate(
&self,
length: usize,
phoneme_size: usize,
f0: &[f32],
phoneme_vector: &[f32],
style_id: StyleId,
) -> Result<ndarray::Array2<f32>>;

fn render_audio_segment(
&self,
spec: ndarray::Array2<f32>,
style_id: StyleId,
) -> Result<ndarray::Array1<f32>>;

/// `decode`を実行する。
///
/// # Performance
///
/// CPU/GPU-boundな操作であるため、非同期ランタイム上では直接実行されるべきではない。
fn decode(
&self,
length: usize,
phoneme_size: usize,
f0: &[f32],
phoneme_vector: &[f32],
style_id: StyleId,
) -> Result<Vec<f32>>;
}

impl<O, A: AsyncForOnnxruntime> Inner<O, A> {
pub(super) async fn predict_duration(
&self,
Expand Down Expand Up @@ -991,8 +930,8 @@ mod inner {
}
}

// CPU/GPU-bound
impl<R: InferenceRuntime> Status<R> {
/// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。
fn predict_duration(
&self,
phoneme_vector: ndarray::Array1<i64>,
Expand Down Expand Up @@ -1022,6 +961,7 @@ mod inner {
const PHONEME_LENGTH_MINIMAL: f32 = 0.01;
}

/// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。
#[expect(
clippy::too_many_arguments,
reason = "compatible_engineでの`predict_intonation`の形を考えると、ここの引数を構造体に\
Expand Down Expand Up @@ -1057,6 +997,7 @@ mod inner {
Ok(output.into_raw_vec())
}

/// CPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。
fn generate_full_intermediate(
&self,
length: usize,
Expand All @@ -1078,6 +1019,7 @@ mod inner {
Ok(spec)
}

/// CPU/GPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。
fn render_audio_segment(
&self,
spec: ndarray::Array2<f32>,
Expand All @@ -1089,6 +1031,7 @@ mod inner {
Ok(wave)
}

/// CPU/GPU-boundな操作なので、非同期ランタイム上では直接実行されるべきではない。
fn decode(
&self,
length: usize,
Expand Down Expand Up @@ -1246,15 +1189,22 @@ mod inner {
}
}

#[expect(
clippy::too_many_arguments,
reason = "`PerformInference::predict_intonation`用。compatible_engineでの`predict_intonation`の\
形を考えると、ここの引数を構造体にまとめたりしても可読性に寄与しない"
)]
pub(crate) mod blocking {
use easy_ext::ext;

use crate::{
asyncs::SingleTasked, future::FutureExt as _, AccentPhrase, AudioQuery,
FullcontextExtractor, StyleId, VoiceModelId, VoiceModelMeta,
};

use super::{inner::Inner, InitializeOptions, SynthesisOptions, TtsOptions};

pub use super::inner::{AudioFeature, PerformInference};
pub use super::inner::AudioFeature;

/// 音声シンセサイザ。
pub struct Synthesizer<O>(pub(super) Inner<O, SingleTasked>);
Expand Down Expand Up @@ -1564,16 +1514,17 @@ pub(crate) mod blocking {
}
}

impl<O> PerformInference for self::Synthesizer<O> {
fn predict_duration(
#[ext(PerformInference)]
impl self::Synthesizer<()> {
pub fn predict_duration(
&self,
phoneme_vector: &[i64],
style_id: StyleId,
) -> crate::Result<Vec<f32>> {
self.0.predict_duration(phoneme_vector, style_id).block_on()
}

fn predict_intonation(
pub fn predict_intonation(
&self,
length: usize,
vowel_phoneme_vector: &[i64],
Expand All @@ -1598,7 +1549,7 @@ pub(crate) mod blocking {
.block_on()
}

fn generate_full_intermediate(
pub fn generate_full_intermediate(
&self,
length: usize,
phoneme_size: usize,
Expand All @@ -1611,15 +1562,15 @@ pub(crate) mod blocking {
.block_on()
}

fn render_audio_segment(
pub fn render_audio_segment(
&self,
spec: ndarray::Array2<f32>,
style_id: StyleId,
) -> crate::Result<ndarray::Array1<f32>> {
self.0.render_audio_segment(spec, style_id).block_on()
}

fn decode(
pub fn decode(
&self,
length: usize,
phoneme_size: usize,
Expand Down

0 comments on commit c941504

Please sign in to comment.