[Rust] Introduce onnxruntime into the Rust version of voicevox_core #135
Changes from all commits: 6a986d0, 6bbe029, 7edc8c7, 3563c1a, fb51a18, 3c8b004, 93062cf, 5258c05, 6d11919
@@ -3,14 +3,21 @@ name = "voicevox_core"
version = "0.1.0"
edition = "2021"

[features]
default = []
directml = []

[lib]
name = "core"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.57"
cfg-if = "1.0.0"
Review comment: I wanted the generated code to differ depending on whether the build uses directml, so I introduced cfg-if to make that conditional branching easier.
derive-getters = "0.2.0"
derive-new = "0.5.9"
once_cell = "1.10.0"
onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", version = "0.0.17" }
Review comment: Would it be possible to pin this to a commit ID? 👀
Reply: Pinning to a commit ID should also be possible, but specifying a version is easier to work with, so may I keep it this way?
Reply: With a commit ID we would have to look up the commit ID every time the version is bumped, which is a bit of a hassle.
Reply: I see! The version-based pin sounds good!
thiserror = "1.0.31"

[dev-dependencies]
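On the rev-versus-version exchange in the review comments above: Cargo can pin a git dependency to an exact commit, to a tag, or by a version requirement, which is the form this PR settles on. A minimal sketch of the three alternatives (the commit hash and tag name below are placeholders, not values taken from this PR):

```toml
# Alternative ways to pin the same git dependency (pick one):

# 1) Pin to an exact commit; the hash is a placeholder and would have to be
#    looked up and updated by hand on every upgrade.
onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", rev = "<commit-sha>" }

# 2) Pin to a tag, if the fork publishes tags (placeholder tag name).
# onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", tag = "<tag-name>" }

# 3) Pin by version requirement, the form this PR uses; Cargo checks the crate
#    version found in the fetched repository against this requirement.
# onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", version = "0.0.17" }
```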
@@ -10,14 +10,17 @@ use std::sync::Mutex;
 * This separates C-context handling from the implementation, and makes changes to the internal implementation less likely to affect the API.
 */

-#[repr(C)]
+#[repr(i32)]
Review comment: For enums, #[repr(i32)] is apparently the correct representation, so I changed it.
#[derive(Debug, PartialEq)]
#[allow(non_camel_case_types)]
pub enum VoicevoxResultCode {
    // Defined in uppercase to match the enum definition on the C side.
    // Changing the output format would allow the UpperCamelCase usual in Rust, but we keep the difference from the code that is actually emitted as small as possible.
    VOICEVOX_RESULT_SUCCEED = 0,
    VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT = 1,
    VOICEVOX_RESULT_FAILED_LOAD_MODEL = 2,
    VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES = 3,
    VOICEVOX_RESULT_CANT_GPU_SUPPORT = 4,
}

fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
@@ -31,6 +34,16 @@ fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
                None,
                VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT,
            ),
            Error::CantGpuSupport => {
                (None, VoicevoxResultCode::VOICEVOX_RESULT_CANT_GPU_SUPPORT)
            }
            Error::LoadModel(_) => {
                (None, VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL)
            }
            Error::GetSupportedDevices(_) => (
                None,
                VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES,
            ),
        }
    }
}
@@ -219,6 +232,7 @@ pub extern "C" fn voicevox_error_result_to_message(
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use pretty_assertions::assert_eq;

    #[rstest]
@@ -227,6 +241,14 @@ mod tests {
        Err(Error::NotLoadedOpenjtalkDict),
        VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT
    )]
    #[case(
        Err(Error::LoadModel(anyhow!("some load model error"))),
        VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL
    )]
    #[case(
        Err(Error::GetSupportedDevices(anyhow!("some get supported devices error"))),
        VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES
    )]
    fn convert_result_works(#[case] result: Result<()>, #[case] expected: VoicevoxResultCode) {
        let (_, actual) = convert_result(result);
        assert_eq!(expected, actual);
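On the #[repr(C)] → #[repr(i32)] change above: for a fieldless enum, #[repr(C)] leaves the discriminant's integer type up to the target's C ABI, while #[repr(i32)] pins it to a 32-bit integer, so the value crosses the FFI boundary as a plain int32_t. A small self-contained sketch of the pattern; the enum and function names here are illustrative and not part of this PR:

```rust
// Fieldless enum with a fixed 32-bit discriminant, safe to hand across the C boundary.
#[repr(i32)]
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExampleResultCode {
    EXAMPLE_RESULT_SUCCEED = 0,
    EXAMPLE_RESULT_FAILED = 1,
}

// An exported C function can then return the enum directly as a 32-bit integer.
#[no_mangle]
pub extern "C" fn example_do_something() -> ExampleResultCode {
    ExampleResultCode::EXAMPLE_RESULT_SUCCEED
}
```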
@@ -0,0 +1,184 @@
use super::*;
use once_cell::sync::Lazy;
use onnxruntime::{
    environment::Environment, session::Session, GraphOptimizationLevel, LoggingLevel,
};
cfg_if! {
    if #[cfg(not(feature="directml"))]{
        use onnxruntime::CudaProviderOptions;
    }
}
use std::collections::BTreeMap;

pub struct Status {
    models: StatusModels,
    session_options: SessionOptions,
}

struct StatusModels {
    yukarin_s: BTreeMap<usize, Session<'static>>,
    yukarin_sa: BTreeMap<usize, Session<'static>>,
    decode: BTreeMap<usize, Session<'static>>,
}

#[derive(new, Getters)]
struct SessionOptions {
    cpu_num_threads: usize,
    use_gpu: bool,
}

struct Model {
    yukarin_s_model: &'static [u8],
    yukarin_sa_model: &'static [u8],
    decode_model: &'static [u8],
}

static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
    cfg_if! {
        if #[cfg(debug_assertions)]{
            const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose;
        } else{
            const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning;
        }
    }
    Environment::builder()
        .with_name(env!("CARGO_PKG_NAME"))
        .with_log_level(LOGGING_LEVEL)
        .build()
        .unwrap()
});

#[derive(Getters)]
pub struct SupportedDevices {
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cpu: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cuda: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    dml: bool,
}

impl SupportedDevices {
    pub fn get_supported_devices() -> Result<Self> {
        let mut cuda_support = false;
        let mut dml_support = false;
        for provider in onnxruntime::session::get_available_providers()
            .map_err(|e| Error::GetSupportedDevices(e.into()))?
            .iter()
        {
            match provider.as_str() {
                "CUDAExecutionProvider" => cuda_support = true,
                "DmlExecutionProvider" => dml_support = true,
                _ => {}
            }
        }

        Ok(SupportedDevices {
            cpu: true,
            cuda: cuda_support,
            dml: dml_support,
        })
    }
}

unsafe impl Send for Status {}
unsafe impl Sync for Status {}

impl Status {
    const YUKARIN_S_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_s.onnx"
    ));
    const YUKARIN_SA_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_sa.onnx"
    ));

    const DECODE_MODEL: &'static [u8] =
        include_bytes!(concat!(env!("CARGO_WORKSPACE_DIR"), "/model/decode.onnx"));

    const MODELS: [Model; 1] = [Model {
        yukarin_s_model: Self::YUKARIN_S_MODEL,
        yukarin_sa_model: Self::YUKARIN_SA_MODEL,
        decode_model: Self::DECODE_MODEL,
    }];
    pub const MODELS_COUNT: usize = Self::MODELS.len();

    pub fn new(use_gpu: bool, cpu_num_threads: usize) -> Self {
        Self {
            models: StatusModels {
                yukarin_s: BTreeMap::new(),
                yukarin_sa: BTreeMap::new(),
                decode: BTreeMap::new(),
            },
            session_options: SessionOptions::new(cpu_num_threads, use_gpu),
        }
    }

    pub fn load_model(&mut self, model_index: usize) -> Result<()> {
        let model = &Self::MODELS[model_index];
        let yukarin_s_session = self
            .new_session(model.yukarin_s_model)
            .map_err(Error::LoadModel)?;
        let yukarin_sa_session = self
            .new_session(model.yukarin_sa_model)
            .map_err(Error::LoadModel)?;
        let decode_model = self
            .new_session(model.decode_model)
            .map_err(Error::LoadModel)?;

        self.models.yukarin_s.insert(model_index, yukarin_s_session);
        self.models
            .yukarin_sa
            .insert(model_index, yukarin_sa_session);

        self.models.decode.insert(model_index, decode_model);

        Ok(())
    }

    fn new_session<B: AsRef<[u8]>>(
        &self,
        model_bytes: B,
    ) -> std::result::Result<Session<'static>, anyhow::Error> {
        let session_builder = ENVIRONMENT
            .new_session_builder()?
            .with_optimization_level(GraphOptimizationLevel::Basic)?
            .with_intra_op_num_threads(*self.session_options.cpu_num_threads() as i32)?
            .with_inter_op_num_threads(*self.session_options.cpu_num_threads() as i32)?;

        let session_builder = if *self.session_options.use_gpu() {
            cfg_if! {
                if #[cfg(feature = "directml")]{
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)?
                } else {
                    let options = CudaProviderOptions::default();
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_append_execution_provider_cuda(options)?
                }
            }
        } else {
            session_builder
        };

        Ok(session_builder.with_model_from_memory(model_bytes)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[rstest]
    fn supported_devices_get_supported_devices_works() {
        let result = SupportedDevices::get_supported_devices();
        // The result depends on the environment, so we only check that the call succeeds.
        assert!(result.is_ok());
    }
}
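Putting the pieces above together, a minimal usage sketch based only on the signatures shown in this diff (Result is the crate's own alias; the function name is illustrative). Which GPU path new_session takes is decided at compile time: building with the directml feature selects the DirectML branch, otherwise the CUDA provider is appended whenever use_gpu is true.

```rust
// Illustrative usage of the Status type introduced in this diff.
fn load_everything() -> Result<()> {
    // use_gpu = false, 4 intra-/inter-op threads per session.
    let mut status = Status::new(false, 4);

    // Load every bundled model set (currently just index 0: the yukarin_s,
    // yukarin_sa and decode sessions are created and cached).
    for model_index in 0..Status::MODELS_COUNT {
        status.load_model(model_index)?;
    }
    Ok(())
}
```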
Review comment: Introduced this so that errors raised internally can be kept as the source inside Error.
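The comment above refers to holding the underlying anyhow::Error inside the crate's Error type so that the original cause stays reachable via std::error::Error::source(); the Error enum itself is defined elsewhere and is not part of this diff. A minimal sketch of that pattern with thiserror, using an illustrative enum name:

```rust
use thiserror::Error;

// Sketch: each failure variant wraps the anyhow::Error that caused it, so the
// underlying error remains reachable through source().
#[derive(Error, Debug)]
pub enum ExampleError {
    #[error("failed to load model")]
    LoadModel(#[source] anyhow::Error),

    #[error("failed to get supported devices")]
    GetSupportedDevices(#[source] anyhow::Error),
}
```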