Skip to content

Commit

Permalink
refactor: InferenceDomainMapValuesのインスタンスをマクロで作る
Browse files Browse the repository at this point in the history
 #737 に向け。また #851 の後にdecode.onnx入りのVVMに対応するときも同様に
役に立つはず。
  • Loading branch information
qryxip committed Oct 9, 2024
1 parent f2e6b60 commit db228b1
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 81 deletions.
82 changes: 14 additions & 68 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const_format = "0.2.33"
cstr = "0.2.12" # https://github.com/dtolnay/syn/issues/1502
derive-getters = "0.2.0"
derive-new = "0.5.9"
derive-syn-parse = "0.2.0"
derive_more = "0.99.17"
duct = "0.13.7"
duplicate = "1.0.0"
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod session_set;

use std::{borrow::Cow, collections::BTreeSet, fmt::Debug};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc};

use derive_new::new;
use duplicate::duplicate_item;
Expand Down Expand Up @@ -51,6 +51,7 @@ pub(crate) trait InferenceRuntime: 'static {
/// 共に扱われるべき推論操作の集合を示す。
pub(crate) trait InferenceDomain: Sized {
type Operation: InferenceOperation;
type Manifest: Index<Self::Operation, Output = Arc<str>>;

/// 対応する`StyleType`。
///
Expand Down
9 changes: 9 additions & 0 deletions crates/voicevox_core/src/infer/domains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ pub(crate) trait InferenceDomainMapValues {
impl<T> InferenceDomainMapValues for (T,) {
type Talk = T;
}

macro_rules! inference_domain_map_values {
(for<$arg:ident> $body:ty) => {
(::macros::substitute_type!(
$body where $arg = crate::infer::domains::TalkDomain as crate::infer::InferenceDomain
),)
};
}
pub(crate) use inference_domain_map_values;
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/infer/domains/talk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use enum_map::Enum;
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use crate::StyleType;
use crate::{manifest::TalkManifest, StyleType};

use super::super::{
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
Expand All @@ -14,6 +14,7 @@ pub(crate) enum TalkDomain {}

impl InferenceDomain for TalkDomain {
type Operation = TalkOperation;
type Manifest = TalkManifest;

fn style_types() -> &'static BTreeSet<StyleType> {
static STYLE_TYPES: LazyLock<BTreeSet<StyleType>> =
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use serde::{de, Deserialize, Deserializer, Serialize};
use serde_with::{serde_as, DisplayFromStr};

use crate::{
infer::domains::{InferenceDomainMap, TalkOperation},
infer::domains::{inference_domain_map_values, InferenceDomainMap, TalkOperation},
StyleId, VoiceModelId,
};

Expand Down Expand Up @@ -79,7 +79,7 @@ pub struct Manifest {
domains: InferenceDomainMap<ManifestDomains>,
}

pub(crate) type ManifestDomains = (Option<TalkManifest>,);
pub(crate) type ManifestDomains = inference_domain_map_values!(for<D> Option<D::Manifest>);

#[derive(Deserialize, IndexForFields)]
#[cfg_attr(test, derive(Default))]
Expand Down
7 changes: 4 additions & 3 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use itertools::iproduct;
use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain},
session_set::{InferenceSessionCell, InferenceSessionSet},
InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
InferenceSignature,
Expand Down Expand Up @@ -338,10 +338,11 @@ impl InferenceDomainMap<ModelBytesWithInnerVoiceIdsByDomain> {
}
}

type SessionOptionsByDomain = (EnumMap<TalkOperation, InferenceSessionOptions>,);
type SessionOptionsByDomain =
inference_domain_map_values!(for<D> EnumMap<D::Operation, InferenceSessionOptions>);

type SessionSetsWithInnerVoiceIdsByDomain<R> =
(Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, TalkDomain>)>,);
inference_domain_map_values!(for<D> Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, D>)>);

#[cfg(test)]
mod tests {
Expand Down
11 changes: 6 additions & 5 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ use crate::{
asyncs::{Async, Mutex as _},
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation},
InferenceDomain,
},
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId, TalkManifest},
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId},
SpeakerMeta, StyleMeta, StyleType, VoiceModelMeta,
};

Expand All @@ -35,8 +35,9 @@ use crate::{
/// [`VoiceModelId`]: VoiceModelId
pub type RawVoiceModelId = Uuid;

pub(crate) type ModelBytesWithInnerVoiceIdsByDomain =
(Option<(StyleIdToInnerVoiceId, EnumMap<TalkOperation, Vec<u8>>)>,);
pub(crate) type ModelBytesWithInnerVoiceIdsByDomain = inference_domain_map_values!(
for<D> Option<(StyleIdToInnerVoiceId, EnumMap<D::Operation, Vec<u8>>)>
);

/// 音声モデルID。
#[derive(
Expand Down Expand Up @@ -251,7 +252,7 @@ impl<A: Async> Inner<A> {
}

type InferenceModelEntries<'manifest> =
(Option<InferenceModelEntry<TalkDomain, &'manifest TalkManifest>>,);
inference_domain_map_values!(for<D> Option<InferenceModelEntry<D, &'manifest D::Manifest>>);

struct InferenceModelEntry<D: InferenceDomain, M> {
indices: EnumMap<D::Operation, usize>,
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ name = "macros"
proc-macro = true

[dependencies]
derive-syn-parse.workspace = true
indexmap.workspace = true
proc-macro2.workspace = true
quote.workspace = true
syn = { workspace = true, features = ["extra-traits", "full"] }
syn = { workspace = true, features = ["extra-traits", "full", "visit-mut"] }

[lints.rust]
unsafe_code = "forbid"
Expand Down
83 changes: 83 additions & 0 deletions crates/voicevox_core_macros/src/inference_domains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use derive_syn_parse::Parse;
use quote::ToTokens as _;
use syn::{
parse_quote,
visit_mut::{self, VisitMut},
Path, PathArguments, PathSegment, Token, Type, TypePath,
};

pub(crate) fn substitute_type(input: Substitution) -> syn::Result<proc_macro2::TokenStream> {
let Substitution {
mut body,
arg,
replacement,
replacement_as,
..
} = input;

Substitute {
arg,
replacement,
replacement_as,
}
.visit_type_mut(&mut body);

return Ok(body.to_token_stream());

struct Substitute {
arg: syn::Ident,
replacement: Path,
replacement_as: Path,
}

impl VisitMut for Substitute {
fn visit_type_mut(&mut self, i: &mut Type) {
visit_mut::visit_type_mut(self, i);

let Type::Path(TypePath {
qself: None,
path:
Path {
leading_colon: None,
segments,
},
}) = i
else {
return;
};

match &mut *segments.iter_mut().collect::<Vec<_>>() {
[PathSegment {
ident,
arguments: PathArguments::None,
}] if *ident == self.arg => {
let replacement = self.replacement.clone();
*i = parse_quote!(#replacement);
}
[PathSegment {
ident: ident1,
arguments: PathArguments::None,
}, seg]
if *ident1 == self.arg =>
{
let replacement = self.replacement.clone();
let replacement_as = self.replacement_as.clone();
*i = parse_quote!(<#replacement as #replacement_as>::#seg);
}
_ => {}
}
}
}
}

/// `$body:ty where $arg:ident = $replacement:path as $replacement_as:path`
#[derive(Parse)]
pub(crate) struct Substitution {
body: Type,
_where_token: Token![where],
arg: syn::Ident,
_eq_token: Token![=],
replacement: Path,
_as_token: Token![as],
replacement_as: Path,
}
Loading

0 comments on commit db228b1

Please sign in to comment.