Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust]onnxruntimeをrust版voicevox_coreに導入 #135

Merged
merged 9 commits into from
May 25, 2022
7 changes: 7 additions & 0 deletions crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ name = "voicevox_core"
version = "0.1.0"
edition = "2021"

[features]
default = []
directml = []

[lib]
name = "core"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.57"
Copy link
Contributor Author

@qwerty2501 qwerty2501 May 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

内部で発生したエラーをError内部のsourceとして保持するために導入

cfg-if = "1.0.0"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directmlでbuildするかどうかで出力されるコードを変えたかったので条件わけしやすいように導入した

derive-getters = "0.2.0"
derive-new = "0.5.9"
once_cell = "1.10.0"
onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", version = "0.0.17" }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

コミットIDの指定ってできそうでしょうか👀
指定しておくと、今後qwertyさんがご自身のonnxruntime-rsを気軽に修正できそうだなと思い。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

コミットID指定もできたはずですが、version指定のほうがやりやすいのでこっちにさせていただいて良いでしょうか?

Copy link
Contributor Author

@qwerty2501 qwerty2501 May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

コミットID指定だとバージョンアップするたびにコミットIDを調べないといけないのでちょっと面倒です
version指定だとonnxruntime-rs projectのcargo file内のversionを更新する必要がありますが、こっちのほうが使う側は特に何も考えずにインクリメントすればよいので楽かなと
あとversionの更新はきちんとやるべきだと思いますし

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

なるほどです!バージョンが良さそうに感じました!
(変更があまり発生しなさそうになったらぜひvoicevox側でもメンテしたいですね…!)

thiserror = "1.0.31"

[dev-dependencies]
Expand Down
24 changes: 23 additions & 1 deletion crates/voicevox_core/src/c_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ use std::sync::Mutex;
* これはC文脈の処理と実装をわけるためと、内部実装の変更がAPIに影響を与えにくくするためである
*/

#[repr(C)]
#[repr(i32)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enumの場合はrepr(i32)が正しいらしいので修正

/// Result codes returned across the C API boundary.
#[derive(Debug, PartialEq)]
#[allow(non_camel_case_types)]
pub enum VoicevoxResultCode {
    // Defined in upper case to match the enum definition on the C side.
    // Changing the output format would allow the UpperCamelCase usual in Rust,
    // but we keep the difference from the actually emitted C code as small as possible.
    VOICEVOX_RESULT_SUCCEED = 0,
    VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT = 1,
    VOICEVOX_RESULT_FAILED_LOAD_MODEL = 2,
    VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES = 3,
    VOICEVOX_RESULT_CANT_GPU_SUPPORT = 4,
}

fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
Expand All @@ -31,6 +34,16 @@ fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
None,
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT,
),
Error::CantGpuSupport => {
(None, VoicevoxResultCode::VOICEVOX_RESULT_CANT_GPU_SUPPORT)
}
Error::LoadModel(_) => {
(None, VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL)
}
Error::GetSupportedDevices(_) => (
None,
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES,
),
PickledChair marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -219,6 +232,7 @@ pub extern "C" fn voicevox_error_result_to_message(
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use pretty_assertions::assert_eq;

#[rstest]
Expand All @@ -227,6 +241,14 @@ mod tests {
Err(Error::NotLoadedOpenjtalkDict),
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT
)]
#[case(
Err(Error::LoadModel(anyhow!("some load model error"))),
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL
)]
#[case(
Err(Error::GetSupportedDevices(anyhow!("some get supported devices error"))),
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES
)]
fn convert_result_works(#[case] result: Result<()>, #[case] expected: VoicevoxResultCode) {
let (_, actual) = convert_result(result);
assert_eq!(expected, actual);
Expand Down
12 changes: 12 additions & 0 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ pub enum Error {
// TODO:仮実装がlinterエラーにならないようにするための属性なのでこのenumが正式に使われる際にallow(dead_code)を取り除くこと
#[allow(dead_code)]
NotLoadedOpenjtalkDict,

#[error("{}", base_error_message(VOICEVOX_RESULT_CANT_GPU_SUPPORT))]
CantGpuSupport,

#[error("{},{0}", base_error_message(VOICEVOX_RESULT_FAILED_LOAD_MODEL))]
LoadModel(#[source] anyhow::Error),

#[error(
"{},{0}",
base_error_message(VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES)
)]
GetSupportedDevices(#[source] anyhow::Error),
}

fn base_error_message(result_code: VoicevoxResultCode) -> &'static str {
Expand Down
58 changes: 53 additions & 5 deletions crates/voicevox_core/src/internal.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,51 @@
use super::*;
use c_export::VoicevoxResultCode;
use once_cell::sync::Lazy;
use std::ffi::CStr;
use std::os::raw::c_int;
use std::sync::Mutex;

use status::*;

// Set to true only after `initialize` completes successfully.
static INITIALIZED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));
// Global inference state; stays `None` until `initialize` succeeds.
static STATUS: Lazy<Mutex<Option<Status>>> = Lazy::new(|| Mutex::new(None));

/// Initializes the global inference state.
///
/// Resets the `INITIALIZED` flag, creates a new `Status` with the given
/// options, optionally loads every model, and publishes the `Status` into
/// the global `STATUS` slot.
///
/// # Errors
/// * `Error::CantGpuSupport` when `use_gpu` is requested but the required
///   execution provider (DirectML or CUDA, depending on the build) is not
///   available.
/// * `Error::LoadModel` when loading any model fails.
pub fn initialize(use_gpu: bool, cpu_num_threads: usize, load_all_models: bool) -> Result<()> {
    let mut initialized = INITIALIZED.lock().unwrap();
    *initialized = false;
    if !use_gpu || can_support_gpu_feature()? {
        let mut status_opt = STATUS.lock().unwrap();
        let mut status = Status::new(use_gpu, cpu_num_threads);

        // TODO: call status.load_metas() here
        // https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L199-L201

        if load_all_models {
            for model_index in 0..Status::MODELS_COUNT {
                status.load_model(model_index)?;
            }
            // TODO: implement GPU memory pre-allocation here
            // https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L210-L219
        }

        *status_opt = Some(status);
        *initialized = true;
        Ok(())
    } else {
        Err(Error::CantGpuSupport)
    }
}

/// Reports whether the GPU execution provider required by this build is
/// available: DirectML when the `directml` feature is enabled, CUDA otherwise.
fn can_support_gpu_feature() -> Result<bool> {
    let devices = SupportedDevices::get_supported_devices()?;

    cfg_if! {
        if #[cfg(feature = "directml")] {
            Ok(*devices.dml())
        } else {
            Ok(*devices.cuda())
        }
    }
}

//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと
Expand Down Expand Up @@ -110,11 +149,20 @@ pub fn voicevox_wav_free(wav: *mut u8) -> Result<()> {

pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -> &'static str {
// C APIのため、messageには必ず末尾にNULL文字を追加する
use VoicevoxResultCode::*;
match result_code {
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => {
VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => {
"voicevox_load_openjtalk_dict() を初めに呼んでください\0"
}
VOICEVOX_RESULT_FAILED_LOAD_MODEL => {
"modelデータ読み込み中にOnnxruntimeエラーが発生しました\0"
}

VOICEVOX_RESULT_CANT_GPU_SUPPORT => "GPU機能をサポートすることができません\0",
VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES => {
"サポートされているデバイス情報取得中にエラーが発生しました\0"
}

VoicevoxResultCode::VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0",
VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0",
}
}
5 changes: 5 additions & 0 deletions crates/voicevox_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ mod c_export;
mod error;
mod internal;
mod result;
mod status;

use error::*;
use result::*;

use derive_getters::*;
use derive_new::new;
#[cfg(test)]
use rstest::*;

use cfg_if::cfg_if;
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
184 changes: 184 additions & 0 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use super::*;
use once_cell::sync::Lazy;
use onnxruntime::{
environment::Environment, session::Session, GraphOptimizationLevel, LoggingLevel,
};
cfg_if! {
if #[cfg(not(feature="directml"))]{
use onnxruntime::CudaProviderOptions;
}
}
use std::collections::BTreeMap;

/// Holds the loaded ONNX sessions together with the options used to create them.
pub struct Status {
    models: StatusModels,
    session_options: SessionOptions,
}

/// Per-model ONNX sessions, keyed by model index.
struct StatusModels {
    yukarin_s: BTreeMap<usize, Session<'static>>,
    yukarin_sa: BTreeMap<usize, Session<'static>>,
    decode: BTreeMap<usize, Session<'static>>,
}

/// Options applied to every session created by `Status::new_session`.
#[derive(new, Getters)]
struct SessionOptions {
    cpu_num_threads: usize,
    use_gpu: bool,
}

/// The raw ONNX binaries that make up one voice model.
struct Model {
    yukarin_s_model: &'static [u8],
    yukarin_sa_model: &'static [u8],
    decode_model: &'static [u8],
}

static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
cfg_if! {
if #[cfg(debug_assertions)]{
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose;
} else{
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning;
}
}
Environment::builder()
.with_name(env!("CARGO_PKG_NAME"))
.with_log_level(LOGGING_LEVEL)
.build()
.unwrap()
});

/// Which execution devices the linked onnxruntime build reports as usable.
#[derive(Getters)]
pub struct SupportedDevices {
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cpu: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cuda: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    dml: bool,
}

impl SupportedDevices {
pub fn get_supported_devices() -> Result<Self> {
let mut cuda_support = false;
let mut dml_support = false;
for provider in onnxruntime::session::get_available_providers()
.map_err(|e| Error::GetSupportedDevices(e.into()))?
.iter()
{
match provider.as_str() {
"CUDAExecutionProvider" => cuda_support = true,
"DmlExecutionProvider" => dml_support = true,
_ => {}
}
}

Ok(SupportedDevices {
cpu: true,
cuda: cuda_support,
dml: dml_support,
})
}
}

// SAFETY: NOTE(review): these impls assert that `Status` — which owns
// `Session<'static>` values — can be moved to and shared between threads.
// Nothing in this file shows the onnxruntime sessions are thread-safe;
// confirm against the onnxruntime-rs bindings that concurrent use is sound.
unsafe impl Send for Status {}
unsafe impl Sync for Status {}

impl Status {
    // The ONNX model binaries are embedded into the library at compile time.
    const YUKARIN_S_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_s.onnx"
    ));
    const YUKARIN_SA_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_sa.onnx"
    ));

    const DECODE_MODEL: &'static [u8] =
        include_bytes!(concat!(env!("CARGO_WORKSPACE_DIR"), "/model/decode.onnx"));

    /// The set of loadable models: currently a single model made up of the
    /// three embedded ONNX binaries above.
    const MODELS: [Model; 1] = [Model {
        yukarin_s_model: Self::YUKARIN_S_MODEL,
        yukarin_sa_model: Self::YUKARIN_SA_MODEL,
        decode_model: Self::DECODE_MODEL,
    }];
    pub const MODELS_COUNT: usize = Self::MODELS.len();

    /// Creates an empty `Status`; no models are loaded yet.
    pub fn new(use_gpu: bool, cpu_num_threads: usize) -> Self {
        Self {
            models: StatusModels {
                yukarin_s: BTreeMap::new(),
                yukarin_sa: BTreeMap::new(),
                decode: BTreeMap::new(),
            },
            session_options: SessionOptions::new(cpu_num_threads, use_gpu),
        }
    }

    /// Builds sessions for the three binaries of `MODELS[model_index]` and
    /// stores them under that index.
    ///
    /// # Errors
    /// Returns `Error::LoadModel` when onnxruntime fails to create a session.
    ///
    /// # Panics
    /// Panics if `model_index` is out of bounds for `Self::MODELS`.
    pub fn load_model(&mut self, model_index: usize) -> Result<()> {
        let model = &Self::MODELS[model_index];
        let yukarin_s_session = self
            .new_session(model.yukarin_s_model)
            .map_err(Error::LoadModel)?;
        let yukarin_sa_session = self
            .new_session(model.yukarin_sa_model)
            .map_err(Error::LoadModel)?;
        let decode_model = self
            .new_session(model.decode_model)
            .map_err(Error::LoadModel)?;

        self.models.yukarin_s.insert(model_index, yukarin_s_session);
        self.models
            .yukarin_sa
            .insert(model_index, yukarin_sa_session);

        self.models.decode.insert(model_index, decode_model);

        Ok(())
    }

    /// Creates a session from in-memory model bytes using the shared
    /// `ENVIRONMENT` and this `Status`'s session options. When `use_gpu`
    /// is set, a GPU execution provider is attached: DirectML (with memory
    /// pattern disabled and sequential execution) under the `directml`
    /// feature, otherwise CUDA with default provider options.
    fn new_session<B: AsRef<[u8]>>(
        &self,
        model_bytes: B,
    ) -> std::result::Result<Session<'static>, anyhow::Error> {
        let session_builder = ENVIRONMENT
            .new_session_builder()?
            .with_optimization_level(GraphOptimizationLevel::Basic)?
            .with_intra_op_num_threads(*self.session_options.cpu_num_threads() as i32)?
            .with_inter_op_num_threads(*self.session_options.cpu_num_threads() as i32)?;

        let session_builder = if *self.session_options.use_gpu() {
            cfg_if! {
                if #[cfg(feature = "directml")]{
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)?
                } else {
                    let options = CudaProviderOptions::default();
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_append_execution_provider_cuda(options)?
                }
            }
        } else {
            session_builder
        };

        Ok(session_builder.with_model_from_memory(model_bytes)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[rstest]
    fn supported_devices_get_supported_devices_works() {
        // Available devices differ per environment, so this only verifies
        // that the query itself succeeds.
        assert!(SupportedDevices::get_supported_devices().is_ok());
    }
}