diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 7f63d504a..48b41e214 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -3,14 +3,21 @@ name = "voicevox_core" version = "0.1.0" edition = "2021" +[features] +default = [] +directml = [] + [lib] name = "core" crate-type = ["cdylib"] [dependencies] +anyhow = "1.0.57" +cfg-if = "1.0.0" derive-getters = "0.2.0" derive-new = "0.5.9" once_cell = "1.10.0" +onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", version = "0.0.17" } thiserror = "1.0.31" [dev-dependencies] diff --git a/crates/voicevox_core/src/c_export.rs b/crates/voicevox_core/src/c_export.rs index dd213da57..d1bc3254f 100644 --- a/crates/voicevox_core/src/c_export.rs +++ b/crates/voicevox_core/src/c_export.rs @@ -10,7 +10,7 @@ use std::sync::Mutex; * これはC文脈の処理と実装をわけるためと、内部実装の変更がAPIに影響を与えにくくするためである */ -#[repr(C)] +#[repr(i32)] #[derive(Debug, PartialEq)] #[allow(non_camel_case_types)] pub enum VoicevoxResultCode { @@ -18,6 +18,9 @@ pub enum VoicevoxResultCode { // 出力フォーマットを変更すればRustでよく使われているUpperCamelにできるが、実際に出力されるコードとの差異をできるだけ少なくするため VOICEVOX_RESULT_SUCCEED = 0, VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT = 1, + VOICEVOX_RESULT_FAILED_LOAD_MODEL = 2, + VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES = 3, + VOICEVOX_RESULT_CANT_GPU_SUPPORT = 4, } fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) { @@ -31,6 +34,16 @@ fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) { None, VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT, ), + Error::CantGpuSupport => { + (None, VoicevoxResultCode::VOICEVOX_RESULT_CANT_GPU_SUPPORT) + } + Error::LoadModel(_) => { + (None, VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL) + } + Error::GetSupportedDevices(_) => ( + None, + VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES, + ), } } } @@ -219,6 +232,7 @@ pub extern "C" fn voicevox_error_result_to_message( #[cfg(test)] mod tests { 
use super::*; + use anyhow::anyhow; use pretty_assertions::assert_eq; #[rstest] @@ -227,6 +241,14 @@ mod tests { Err(Error::NotLoadedOpenjtalkDict), VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT )] + #[case( + Err(Error::LoadModel(anyhow!("some load model error"))), + VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL + )] + #[case( + Err(Error::GetSupportedDevices(anyhow!("some get supported devices error"))), + VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES + )] fn convert_result_works(#[case] result: Result<()>, #[case] expected: VoicevoxResultCode) { let (_, actual) = convert_result(result); assert_eq!(expected, actual); diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 4a91234d0..23abc45be 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -16,6 +16,18 @@ pub enum Error { // TODO:仮実装がlinterエラーにならないようにするための属性なのでこのenumが正式に使われる際にallow(dead_code)を取り除くこと #[allow(dead_code)] NotLoadedOpenjtalkDict, + + #[error("{}", base_error_message(VOICEVOX_RESULT_CANT_GPU_SUPPORT))] + CantGpuSupport, + + #[error("{},{0}", base_error_message(VOICEVOX_RESULT_FAILED_LOAD_MODEL))] + LoadModel(#[source] anyhow::Error), + + #[error( + "{},{0}", + base_error_message(VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES) + )] + GetSupportedDevices(#[source] anyhow::Error), } fn base_error_message(result_code: VoicevoxResultCode) -> &'static str { diff --git a/crates/voicevox_core/src/internal.rs b/crates/voicevox_core/src/internal.rs index 8dcdeb2ef..06dd73499 100644 --- a/crates/voicevox_core/src/internal.rs +++ b/crates/voicevox_core/src/internal.rs @@ -1,12 +1,51 @@ use super::*; use c_export::VoicevoxResultCode; +use once_cell::sync::Lazy; use std::ffi::CStr; use std::os::raw::c_int; +use std::sync::Mutex; + +use status::*; + +static INITIALIZED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false)); +static STATUS: Lazy<Mutex<Option<Status>>> = Lazy::new(|| Mutex::new(None)); 
-//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと -#[allow(unused_variables)] pub fn initialize(use_gpu: bool, cpu_num_threads: usize, load_all_models: bool) -> Result<()> { - unimplemented!() + let mut initialized = INITIALIZED.lock().unwrap(); + *initialized = false; + if !use_gpu || can_support_gpu_feature()? { + let mut status_opt = STATUS.lock().unwrap(); + let mut status = Status::new(use_gpu, cpu_num_threads); + + // TODO: ここに status.load_metas() を呼び出すようにする + // https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L199-L201 + + if load_all_models { + for model_index in 0..Status::MODELS_COUNT { + status.load_model(model_index)?; + } + // TODO: ここにGPUメモリを確保させる処理を実装する + // https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L210-L219 + } + + *status_opt = Some(status); + *initialized = true; + Ok(()) + } else { + Err(Error::CantGpuSupport) + } +} + +fn can_support_gpu_feature() -> Result<bool> { + let supported_devices = SupportedDevices::get_supported_devices()?; + + cfg_if! 
{ + if #[cfg(feature = "directml")]{ + Ok(*supported_devices.dml()) + } else{ + Ok(*supported_devices.cuda()) + } + } } //TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと @@ -110,11 +149,20 @@ pub fn voicevox_wav_free(wav: *mut u8) -> Result<()> { pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -> &'static str { // C APIのため、messageには必ず末尾にNULL文字を追加する + use VoicevoxResultCode::*; match result_code { - VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => { + VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => { "voicevox_load_openjtalk_dict() を初めに呼んでください\0" } + VOICEVOX_RESULT_FAILED_LOAD_MODEL => { + "modelデータ読み込み中にOnnxruntimeエラーが発生しました\0" + } + + VOICEVOX_RESULT_CANT_GPU_SUPPORT => "GPU機能をサポートすることができません\0", + VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES => { + "サポートされているデバイス情報取得中にエラーが発生しました\0" + } - VoicevoxResultCode::VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0", + VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0", } } diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 7ec79b138..e0c2c0688 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -2,9 +2,14 @@ mod c_export; mod error; mod internal; mod result; +mod status; use error::*; use result::*; +use derive_getters::*; +use derive_new::new; #[cfg(test)] use rstest::*; + +use cfg_if::cfg_if; diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs new file mode 100644 index 000000000..521efb70f --- /dev/null +++ b/crates/voicevox_core/src/status.rs @@ -0,0 +1,184 @@ +use super::*; +use once_cell::sync::Lazy; +use onnxruntime::{ + environment::Environment, session::Session, GraphOptimizationLevel, LoggingLevel, +}; +cfg_if! 
{ + if #[cfg(not(feature="directml"))]{ + use onnxruntime::CudaProviderOptions; + } +} +use std::collections::BTreeMap; + +pub struct Status { + models: StatusModels, + session_options: SessionOptions, +} + +struct StatusModels { + yukarin_s: BTreeMap<usize, Session<'static>>, + yukarin_sa: BTreeMap<usize, Session<'static>>, + decode: BTreeMap<usize, Session<'static>>, +} + +#[derive(new, Getters)] +struct SessionOptions { + cpu_num_threads: usize, + use_gpu: bool, +} + +struct Model { + yukarin_s_model: &'static [u8], + yukarin_sa_model: &'static [u8], + decode_model: &'static [u8], +} + +static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| { + cfg_if! { + if #[cfg(debug_assertions)]{ + const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose; + } else{ + const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning; + } + } + Environment::builder() + .with_name(env!("CARGO_PKG_NAME")) + .with_log_level(LOGGING_LEVEL) + .build() + .unwrap() +}); + +#[derive(Getters)] +pub struct SupportedDevices { + // TODO:supported_devices関数を実装したらこのattributeをはずす + #[allow(dead_code)] + cpu: bool, + // TODO:supported_devices関数を実装したらこのattributeをはずす + #[allow(dead_code)] + cuda: bool, + // TODO:supported_devices関数を実装したらこのattributeをはずす + #[allow(dead_code)] + dml: bool, +} + +impl SupportedDevices { + pub fn get_supported_devices() -> Result<Self> { + let mut cuda_support = false; + let mut dml_support = false; + for provider in onnxruntime::session::get_available_providers() + .map_err(|e| Error::GetSupportedDevices(e.into()))? 
 + .iter() + { + match provider.as_str() { + "CUDAExecutionProvider" => cuda_support = true, + "DmlExecutionProvider" => dml_support = true, + _ => {} + } + } + + Ok(SupportedDevices { + cpu: true, + cuda: cuda_support, + dml: dml_support, + }) + } +} + +unsafe impl Send for Status {} +unsafe impl Sync for Status {} + +impl Status { + const YUKARIN_S_MODEL: &'static [u8] = include_bytes!(concat!( + env!("CARGO_WORKSPACE_DIR"), + "/model/yukarin_s.onnx" + )); + const YUKARIN_SA_MODEL: &'static [u8] = include_bytes!(concat!( + env!("CARGO_WORKSPACE_DIR"), + "/model/yukarin_sa.onnx" + )); + + const DECODE_MODEL: &'static [u8] = + include_bytes!(concat!(env!("CARGO_WORKSPACE_DIR"), "/model/decode.onnx")); + + const MODELS: [Model; 1] = [Model { + yukarin_s_model: Self::YUKARIN_S_MODEL, + yukarin_sa_model: Self::YUKARIN_SA_MODEL, + decode_model: Self::DECODE_MODEL, + }]; + pub const MODELS_COUNT: usize = Self::MODELS.len(); + + pub fn new(use_gpu: bool, cpu_num_threads: usize) -> Self { + Self { + models: StatusModels { + yukarin_s: BTreeMap::new(), + yukarin_sa: BTreeMap::new(), + decode: BTreeMap::new(), + }, + session_options: SessionOptions::new(cpu_num_threads, use_gpu), + } + } + + pub fn load_model(&mut self, model_index: usize) -> Result<()> { + let model = &Self::MODELS[model_index]; + let yukarin_s_session = self + .new_session(model.yukarin_s_model) + .map_err(Error::LoadModel)?; + let yukarin_sa_session = self + .new_session(model.yukarin_sa_model) + .map_err(Error::LoadModel)?; + let decode_model = self + .new_session(model.decode_model) + .map_err(Error::LoadModel)?; + + self.models.yukarin_s.insert(model_index, yukarin_s_session); + self.models + .yukarin_sa + .insert(model_index, yukarin_sa_session); + + self.models.decode.insert(model_index, decode_model); + + Ok(()) + } + + fn new_session<B: AsRef<[u8]>>( + &self, + model_bytes: B, + ) -> std::result::Result<Session<'static>, anyhow::Error> { + let session_builder = ENVIRONMENT + .new_session_builder()? 
+ .with_optimization_level(GraphOptimizationLevel::Basic)? + .with_intra_op_num_threads(*self.session_options.cpu_num_threads() as i32)? + .with_inter_op_num_threads(*self.session_options.cpu_num_threads() as i32)?; + + let session_builder = if *self.session_options.use_gpu() { + cfg_if! { + if #[cfg(feature = "directml")]{ + session_builder + .with_disable_mem_pattern()? + .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)? + } else { + let options = CudaProviderOptions::default(); + session_builder + .with_disable_mem_pattern()? + .with_append_execution_provider_cuda(options)? + } + } + } else { + session_builder + }; + + Ok(session_builder.with_model_from_memory(model_bytes)?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[rstest] + fn supported_devices_get_supported_devices_works() { + let result = SupportedDevices::get_supported_devices(); + // 環境によって結果が変わるので、関数呼び出しが成功するかどうかの確認のみ行う + assert!(result.is_ok()); + } +}