-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
StyleMeta::r#type
を追加し、トークという区分を実装に導入する
#761
Changes from 41 commits
ca6ce4a
599cd13
6cca5d0
cb2637a
d0cc720
4c80605
60d2517
098b139
704b5f5
cd9ad03
dcc5ee1
e4d3eec
9251c25
0e24fb2
1e12019
051d181
fd2905b
ad47904
c272d3d
e46cbd4
0403753
1894acd
ca4be74
7ccb2a3
2521744
26f254c
5df5a07
6543aa2
22d53f9
2a2f273
2bb8b82
5eeb975
28382c8
1bb6f45
1a51947
e6493d2
ba81714
de9a919
b30abc7
bcdc1a8
de63365
430a548
66a9924
07126c6
4270a1c
734564f
43fa77d
678004b
ba3d230
7311323
32e20bc
7fd730d
275d3d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
mod talk; | ||
|
||
pub(crate) use self::talk::{ | ||
DecodeInput, DecodeOutput, PredictDurationInput, PredictDurationOutput, PredictIntonationInput, | ||
PredictIntonationOutput, TalkDomain, TalkOperation, | ||
}; | ||
|
||
pub(crate) struct InferenceDomainMap<V: InferenceDomainMapValues + ?Sized> { | ||
pub(crate) talk: V::Talk, | ||
} | ||
|
||
pub(crate) trait InferenceDomainMapValues { | ||
type Talk; | ||
} | ||
|
||
impl<T> InferenceDomainMapValues for (T,) { | ||
type Talk = T; | ||
} | ||
|
||
impl<A> InferenceDomainMapValues for [A] { | ||
type Talk = A; | ||
} |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. status.rsから分離。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. このファイルレベルのコメントがoutdateになるやつ、未だによくわからない。何やってもoutdatedになるってこと…? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
use std::{collections::HashMap, fmt::Display, marker::PhantomData, sync::Arc}; | ||
|
||
use anyhow::bail; | ||
use enum_map::{Enum as _, EnumMap}; | ||
use itertools::Itertools as _; | ||
|
||
use crate::error::ErrorRepr; | ||
|
||
use super::{ | ||
model_file, InferenceDomain, InferenceInputSignature, InferenceOperation, InferenceRuntime, | ||
InferenceSessionOptions, InferenceSignature, ParamInfo, | ||
}; | ||
|
||
pub(crate) struct SessionSet<R: InferenceRuntime, D: InferenceDomain>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. このSessionが何を指すか、初見だとわからないかもと思いました! 例えばInferenceSessionSetにするのはどうでしょう。SessionCellも同じく。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. これまではstatus.rsの中だけで使ってたので名前を簡潔にしていたのですが、 "session"については語感でなんとなく把握してもらえないかなと思っています。とはいってもTFLiteとかに手を出すときは整理をつける必要がありますし、infer/下の諸々と一緒にドキュメントは整備したい感はあります。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. たしかにinfer以下にあれば察しはつくかもですね。 あーでも1回の推論ごとに1回セッションができると勘違いしてコード読み進めちゃう可能性はありそう。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. まああった方はよいですね。TODOを残しました。 |
||
EnumMap<D::Operation, Arc<std::sync::Mutex<R::Session>>>, | ||
); | ||
|
||
impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> { | ||
pub(crate) fn new( | ||
model_bytes: &EnumMap<D::Operation, Vec<u8>>, | ||
options: &EnumMap<D::Operation, InferenceSessionOptions>, | ||
) -> anyhow::Result<Self> { | ||
let mut sessions = model_bytes | ||
.iter() | ||
.map(|(op, model_bytes)| { | ||
let (expected_input_param_infos, expected_output_param_infos) = | ||
<D::Operation as InferenceOperation>::PARAM_INFOS[op]; | ||
|
||
let (sess, actual_input_param_infos, actual_output_param_infos) = | ||
R::new_session(|| model_file::decrypt(model_bytes), options[op])?; | ||
|
||
check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; | ||
check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; | ||
|
||
Ok((op.into_usize(), std::sync::Mutex::new(sess).into())) | ||
}) | ||
.collect::<anyhow::Result<HashMap<_, _>>>()?; | ||
|
||
return Ok(Self(EnumMap::<D::Operation, _>::from_fn(|k| { | ||
sessions.remove(&k.into_usize()).expect("should exist") | ||
}))); | ||
|
||
fn check_param_infos<D: PartialEq + Display>( | ||
expected: &[ParamInfo<D>], | ||
actual: &[ParamInfo<D>], | ||
) -> anyhow::Result<()> { | ||
if !(expected.len() == actual.len() | ||
&& itertools::zip_eq(expected, actual) | ||
.all(|(expected, actual)| expected.accepts(actual))) | ||
{ | ||
let expected = display_param_infos(expected); | ||
let actual = display_param_infos(actual); | ||
bail!("expected {{{expected}}}, got {{{actual}}}") | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn display_param_infos(infos: &[ParamInfo<impl Display>]) -> impl Display { | ||
infos | ||
.iter() | ||
.map(|ParamInfo { name, dt, ndim }| { | ||
let brackets = match *ndim { | ||
Some(ndim) => "[]".repeat(ndim), | ||
None => "[]...".to_owned(), | ||
}; | ||
format!("{name}: {dt}{brackets}") | ||
}) | ||
.join(", ") | ||
} | ||
} | ||
} | ||
|
||
impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> { | ||
pub(crate) fn get<I>(&self) -> SessionCell<R, I> | ||
where | ||
I: InferenceInputSignature, | ||
I::Signature: InferenceSignature<Domain = D>, | ||
{ | ||
SessionCell { | ||
inner: self.0[I::Signature::OPERATION].clone(), | ||
marker: PhantomData, | ||
} | ||
} | ||
} | ||
|
||
pub(crate) struct SessionCell<R: InferenceRuntime, I> { | ||
inner: Arc<std::sync::Mutex<R::Session>>, | ||
marker: PhantomData<fn(I)>, | ||
} | ||
|
||
impl<R: InferenceRuntime, I: InferenceInputSignature> SessionCell<R, I> { | ||
pub(crate) fn run( | ||
self, | ||
input: I, | ||
) -> crate::Result<<I::Signature as InferenceSignature>::Output> { | ||
let inner = &mut self.inner.lock().unwrap(); | ||
let ctx = input.make_run_context::<R>(inner); | ||
R::run(ctx) | ||
.and_then(TryInto::try_into) | ||
.map_err(|e| ErrorRepr::InferenceFailed(e).into()) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
status.rsの以下の2箇所で使われる。この2箇所が、VVM内の
StyleType
に対する検査となる。ensure_acceptable
(StyleMeta::r#type
を追加し、トークという区分を実装に導入する #761 (comment))ids_for
(StyleMeta::r#type
を追加し、トークという区分を実装に導入する #761 (comment))style_types
の実装としてはStyleType
が不足しててもよいし(上記2箇所を通らなくなって全部弾かれるようになるだけなので)、InferenceDomain
間で重複していてもよい (e.g. 今後StyleType::Sing
はSingingTeacherDomain::style_types
とFrameDecodeDomain::style_types
に含まれる)。(edit)
StyleType::{SingingTeacher,FrameDecode,Sing}
だけ追加しました。実際に"type": "singing_teacher"
のようにmetas.jsonを書いた場合、ensure_acceptable
で確実に弾かれるようになっています。