Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX Runtimeとモデルのシグネチャを隔離する #675

Merged
merged 48 commits into main from split-onnxruntime-and-model-signatures
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
5ff2b59
ONNX Runtimeとモデルのシグネチャを隔離する
qryxip Nov 5, 2023
33245ff
`R: InferenceCore`を`SynthesisEngine`まで持っていく
qryxip Nov 5, 2023
4bd2281
`SupportedDevices::create`の実装を移動
qryxip Nov 5, 2023
f959911
不要なreexportを削除
qryxip Nov 5, 2023
55b04d3
`InferenceModels`の定義を`signatures`に移動
qryxip Nov 5, 2023
c38ad99
`ErrorRepr::GetSupportedDevices`の中身を`anyhow::Error`に
qryxip Nov 6, 2023
192417f
enum-map v3.0.0-beta.1を導入し、`EnumMap`駆動に
qryxip Nov 6, 2023
20db67a
Minor refactor
qryxip Nov 6, 2023
cc84068
Minor refactor
qryxip Nov 6, 2023
e0f29c6
色々再構成
qryxip Nov 8, 2023
cb1db34
Fix up
qryxip Nov 8, 2023
c3e08dd
`OnnxruntimeInferenceBuilder` → `OnnxruntimeRunContext`
qryxip Nov 8, 2023
e4b91ab
`impl SupportsInferenceOutput<_> for Onnxruntime`を拡張
qryxip Nov 8, 2023
8584d27
`SignatureKind` → `InferenceSignatureKind`
qryxip Nov 8, 2023
4795309
`LoadedModels`へのアクセスをメソッド越しにするのを徹底する
qryxip Nov 8, 2023
a5dbbdd
Minor refactor
qryxip Nov 8, 2023
525f4b1
`InferenceInput` → `InferenceInputSignature`
qryxip Nov 8, 2023
26476f5
相互参照
qryxip Nov 8, 2023
fbd7d1c
`fn input`まわりを明瞭にする
qryxip Nov 9, 2023
8b4f3b6
"signature"のkindではなく"model"のkindとする
qryxip Nov 9, 2023
c40afd5
"model"ではなく"inference"と呼ぶ
qryxip Nov 11, 2023
81b5804
ランタイムは任意次元任意個数の入出力ができると仮定する
qryxip Nov 11, 2023
120106b
voicevox_core_macrosを作り、"signatures"の実装をマクロ化
qryxip Nov 11, 2023
590ce48
`AnyTensor` → `OutputTensor`
qryxip Nov 11, 2023
c4d5ebe
`INFERENCE` → `KIND`
qryxip Nov 11, 2023
1b1b7bf
`status`を`infer`下に
qryxip Nov 11, 2023
c39f48c
`trait RunContext`を削除
qryxip Nov 11, 2023
c316209
"kind"を直接"group"と呼ぶことにする
qryxip Nov 11, 2023
2274a34
シグネチャの実行時チェック機構を入れる
qryxip Nov 12, 2023
b7d48f3
signaturesのマクロ化を完了させる
qryxip Nov 13, 2023
b6db1c0
Minor refactor
qryxip Nov 13, 2023
96a93e9
`InferenceGroup` → `InferenceDomain`
qryxip Nov 14, 2023
59d8779
Minor refactor
qryxip Nov 14, 2023
d0dc56f
`InferenceDomain::{INPUT,OUTPUT}_PARAM_INFOS`を統合
qryxip Nov 14, 2023
c654cd1
`InferenceDomain::PARAM_INFOS`にdocstring
qryxip Nov 14, 2023
868d3f6
voicevox_core_macrosにdocstring
qryxip Nov 14, 2023
0998793
`sealed::InputScalar`にFIXME
qryxip Nov 14, 2023
75fd7ac
"Domain"と"Operation"に分離
qryxip Nov 14, 2023
ad222c9
`InferenceOperationKind` → `InferenceOperationImpl`
qryxip Nov 15, 2023
7005c96
docを修正
qryxip Nov 15, 2023
9417992
Merge branch 'main' into split-onnxruntime-and-model-signatures
qryxip Nov 15, 2023
1655719
"voicevox_core内で" → "Rust APIクレート内で"
qryxip Nov 15, 2023
a73f22c
docを追記
qryxip Nov 15, 2023
f17919b
docを追記
qryxip Nov 15, 2023
48bdb1b
`InferenceDomain`のdocを書く
qryxip Nov 16, 2023
9d7d001
不要な文の削除
qryxip Nov 16, 2023
af828eb
Minor refactor
qryxip Nov 16, 2023
b6b7975
`ArrayExt`をマクロ内に押し込める
qryxip Nov 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub(crate) mod domain;
mod model_file;
pub(crate) mod runtimes;
pub(crate) mod signatures;
pub(crate) mod status;

use std::{borrow::Cow, fmt::Debug};
Expand Down Expand Up @@ -36,16 +36,25 @@ pub(crate) trait InferenceRuntime: 'static {
fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
}

pub(crate) trait InferenceGroup: Copy + Enum {
const INPUT_PARAM_INFOS: EnumMap<Self, &'static [ParamInfo<InputScalarKind>]>;
const OUTPUT_PARAM_INFOS: EnumMap<Self, &'static [ParamInfo<OutputScalarKind>]>;
pub(crate) trait InferenceDomain: Copy + Enum {
/// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。
///
/// マクロ(voicevox_core_macros)で実装される前提。
#[allow(clippy::type_complexity)]
const PARAM_INFOS: EnumMap<
Self,
(
&'static [ParamInfo<InputScalarKind>],
&'static [ParamInfo<OutputScalarKind>],
),
>;
}

pub(crate) trait InferenceSignature: Sized + Send + 'static {
type Group: InferenceGroup;
type Domain: InferenceDomain;
type Input: InferenceInputSignature<Signature = Self>;
type Output: InferenceOutputSignature;
const KIND: Self::Group;
const KIND: Self::Domain;
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
}

pub(crate) trait InferenceInputSignature: Send + 'static {
Expand Down Expand Up @@ -155,6 +164,8 @@ pub(crate) enum ExtractError {
#[error("不正なモデルファイルです")]
pub(crate) struct DecryptModelError;

// FIXME: `onnxruntime::TypeToTensorElementDataType`に依存する代わりに、`InputScalar`から`runtimes`
// まではvisitor patternでつなぐ
mod sealed {
pub(crate) trait InputScalar: OnnxruntimeInputScalar {}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
use enum_map::Enum;
use macros::{InferenceGroup, InferenceInputSignature, InferenceOutputSignature};
use macros::{InferenceDomain, InferenceInputSignature, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use super::{InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor};

#[derive(Clone, Copy, Enum, InferenceGroup)]
#[derive(Clone, Copy, Enum, InferenceDomain)]
pub(crate) enum InferenceKind {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
#[inference_group(
#[inference_domain(
type Input = PredictDurationInput;
type Output = PredictDurationOutput;
)]
PredictDuration,

#[inference_group(
#[inference_domain(
type Input = PredictIntonationInput;
type Output = PredictIntonationOutput;
)]
PredictIntonation,

#[inference_group(
#[inference_domain(
type Input = DecodeInput;
type Output = DecodeOutput;
)]
Expand Down
55 changes: 27 additions & 28 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ use crate::{
};

use super::{
model_file, InferenceGroup, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
InferenceSignature,
model_file, InferenceDomain, InferenceInputSignature, InferenceRuntime,
InferenceSessionOptions, InferenceSignature,
};

pub(crate) struct Status<R: InferenceRuntime, G: InferenceGroup> {
loaded_models: std::sync::Mutex<LoadedModels<R, G>>,
session_options: EnumMap<G, InferenceSessionOptions>,
pub(crate) struct Status<R: InferenceRuntime, D: InferenceDomain> {
loaded_models: std::sync::Mutex<LoadedModels<R, D>>,
session_options: EnumMap<D, InferenceSessionOptions>,
}

impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
pub fn new(session_options: EnumMap<G, InferenceSessionOptions>) -> Self {
impl<R: InferenceRuntime, D: InferenceDomain> Status<R, D> {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
pub fn new(session_options: EnumMap<D, InferenceSessionOptions>) -> Self {
Self {
loaded_models: Default::default(),
session_options,
Expand All @@ -40,7 +40,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
pub async fn load_model(
&self,
model: &VoiceModel,
model_bytes: &EnumMap<G, Vec<u8>>,
model_bytes: &EnumMap<D, Vec<u8>>,
) -> Result<()> {
self.loaded_models
.lock()
Expand Down Expand Up @@ -100,7 +100,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
) -> Result<<I::Signature as InferenceSignature>::Output>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<Group = G>,
I::Signature: InferenceSignature<Domain = D>,
{
let sess = self.loaded_models.lock().unwrap().get(model_id);

Expand All @@ -114,18 +114,18 @@ impl<R: InferenceRuntime, G: InferenceGroup> Status<R, G> {
///
/// この構造体のメソッドは、すべて一瞬で完了すべきである。
#[derive(Educe)]
#[educe(Default(bound = "R: InferenceRuntime, G: InferenceGroup"))]
struct LoadedModels<R: InferenceRuntime, G: InferenceGroup>(
BTreeMap<VoiceModelId, LoadedModel<R, G>>,
#[educe(Default(bound = "R: InferenceRuntime, D: InferenceDomain"))]
struct LoadedModels<R: InferenceRuntime, D: InferenceDomain>(
BTreeMap<VoiceModelId, LoadedModel<R, D>>,
);

struct LoadedModel<R: InferenceRuntime, G: InferenceGroup> {
struct LoadedModel<R: InferenceRuntime, D: InferenceDomain> {
model_inner_ids: BTreeMap<StyleId, ModelInnerId>,
metas: VoiceModelMeta,
session_set: SessionSet<R, G>,
session_set: SessionSet<R, D>,
}

impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
impl<R: InferenceRuntime, D: InferenceDomain> LoadedModels<R, D> {
fn metas(&self) -> VoiceModelMeta {
self.0
.values()
Expand Down Expand Up @@ -164,7 +164,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
fn get<I>(&self, model_id: &VoiceModelId) -> SessionCell<R, I>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<Group = G>,
I::Signature: InferenceSignature<Domain = D>,
{
self.0[model_id].session_set.get()
}
Expand Down Expand Up @@ -207,7 +207,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
Ok(())
}

fn insert(&mut self, model: &VoiceModel, session_set: SessionSet<R, G>) -> Result<()> {
fn insert(&mut self, model: &VoiceModel, session_set: SessionSet<R, D>) -> Result<()> {
self.ensure_acceptable(model)?;

let prev = self.0.insert(
Expand Down Expand Up @@ -240,20 +240,19 @@ impl<R: InferenceRuntime, G: InferenceGroup> LoadedModels<R, G> {
}
}

struct SessionSet<R: InferenceRuntime, G: InferenceGroup>(
EnumMap<G, Arc<std::sync::Mutex<R::Session>>>,
struct SessionSet<R: InferenceRuntime, D: InferenceDomain>(
EnumMap<D, Arc<std::sync::Mutex<R::Session>>>,
);

impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
fn new(
model_bytes: &EnumMap<G, Vec<u8>>,
options: &EnumMap<G, InferenceSessionOptions>,
model_bytes: &EnumMap<D, Vec<u8>>,
options: &EnumMap<D, InferenceSessionOptions>,
) -> anyhow::Result<Self> {
let mut sessions = model_bytes
.iter()
.map(|(k, m)| {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
let expected_input_param_infos = G::INPUT_PARAM_INFOS[k];
let expected_output_param_infos = G::OUTPUT_PARAM_INFOS[k];
let (expected_input_param_infos, expected_output_param_infos) = D::PARAM_INFOS[k];

let (sess, actual_input_param_infos, actual_output_param_infos) =
R::new_session(|| model_file::decrypt(m), options[k])?;
Expand All @@ -265,7 +264,7 @@ impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;

return Ok(Self(EnumMap::<G, _>::from_fn(|k| {
return Ok(Self(EnumMap::<D, _>::from_fn(|k| {
sessions.remove(&k.into_usize()).expect("should exist")
})));

Expand Down Expand Up @@ -299,11 +298,11 @@ impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
}
}

impl<R: InferenceRuntime, G: InferenceGroup> SessionSet<R, G> {
impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
fn get<I>(&self) -> SessionCell<R, I>
where
I: InferenceInputSignature,
I::Signature: InferenceSignature<Group = G>,
I::Signature: InferenceSignature<Domain = D>,
{
SessionCell {
inner: self.0[I::Signature::KIND].clone(),
Expand Down Expand Up @@ -334,7 +333,7 @@ mod tests {
use rstest::rstest;

use crate::{
infer::signatures::InferenceKind, macros::tests::assert_debug_fmt_eq,
infer::domain::InferenceKind, macros::tests::assert_debug_fmt_eq,
synthesizer::InferenceRuntimeImpl, test_util::open_default_vvm_file,
};

Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/inference_core.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use enum_map::enum_map;

use crate::infer::{
signatures::{
domain::{
DecodeInput, DecodeOutput, InferenceKind, PredictDurationInput, PredictDurationOutput,
PredictIntonationInput, PredictIntonationOutput,
},
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures::future::join3;
use serde::{de::DeserializeOwned, Deserialize};

use super::*;
use crate::infer::signatures::InferenceKind;
use crate::infer::domain::InferenceKind;
use std::{
collections::{BTreeMap, HashMap},
io,
Expand Down
Loading
Loading