Skip to content

Commit

Permalink
[Rust]is_model_loadedを実装した (VOICEVOX#151)
Browse files Browse the repository at this point in the history
* is_model_loadedを実装した

* is_model_loadedに渡す数値をspeaker_idではなくmodel_indexにした
  • Loading branch information
qwerty2501 committed Jul 23, 2022
1 parent 705c95d commit 8c2b38e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/c_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub extern "C" fn load_model(speaker_id: i64) -> bool {

#[no_mangle]
pub extern "C" fn is_model_loaded(speaker_id: i64) -> bool {
internal::is_model_loaded(speaker_id)
internal::is_model_loaded(speaker_id as usize)
}

#[no_mangle]
Expand Down
31 changes: 27 additions & 4 deletions crates/voicevox_core/src/internal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::*;
use c_export::VoicevoxResultCode;
use once_cell::sync::Lazy;
use std::collections::BTreeMap;
use std::ffi::CStr;
use std::os::raw::c_int;
use std::sync::Mutex;
Expand All @@ -11,6 +12,8 @@ use std::ffi::CString;
static INITIALIZED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));
static STATUS: Lazy<Mutex<Option<Status>>> = Lazy::new(|| Mutex::new(None));

static SPEAKER_ID_MAP: Lazy<BTreeMap<usize, (usize, usize)>> = Lazy::new(BTreeMap::new);

pub fn initialize(use_gpu: bool, cpu_num_threads: usize, load_all_models: bool) -> Result<()> {
let mut initialized = INITIALIZED.lock().unwrap();
*initialized = false;
Expand Down Expand Up @@ -58,10 +61,13 @@ pub fn load_model(speaker_id: i64) -> Result<()> {
}
}

//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと
#[allow(unused_variables)]
pub fn is_model_loaded(speaker_id: i64) -> bool {
unimplemented!()
pub fn is_model_loaded(speaker_id: usize) -> bool {
if let Some(status) = STATUS.lock().unwrap().as_ref() {
let (model_index, _) = get_model_index_and_speaker_id(speaker_id);
status.is_model_loaded(model_index)
} else {
false
}
}

pub fn finalize() {
Expand Down Expand Up @@ -158,6 +164,10 @@ pub fn voicevox_wav_free(wav: *mut u8) -> Result<()> {
unimplemented!()
}

fn get_model_index_and_speaker_id(speaker_id: usize) -> (usize, usize) {
*SPEAKER_ID_MAP.get(&speaker_id).unwrap_or(&(0, speaker_id))
}

pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -> &'static str {
// C APIのため、messageには必ず末尾にNULL文字を追加する
use VoicevoxResultCode::*;
Expand All @@ -183,6 +193,7 @@ pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;

#[rstest]
fn supported_devices_works() {
Expand All @@ -193,4 +204,16 @@ mod tests {
serde_json::from_str(cstr_result.to_str().unwrap());
assert!(json_result.is_ok(), "{:?}", json_result);
}

#[rstest]
#[case(0,(0,0))]
#[case(1,(0,1))]
#[case(3,(0,3))]
fn get_model_index_and_speaker_id_works(
#[case] speaker_id: usize,
#[case] expected: (usize, usize),
) {
let actual = get_model_index_and_speaker_id(speaker_id);
assert_eq!(expected, actual);
}
}
22 changes: 22 additions & 0 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ impl Status {
Ok(())
}

pub fn is_model_loaded(&self, model_index: usize) -> bool {
self.models.yukarin_sa.contains_key(&model_index)
&& self.models.yukarin_s.contains_key(&model_index)
&& self.models.decode.contains_key(&model_index)
}

fn new_session<B: AsRef<[u8]>>(
&self,
model_bytes: B,
Expand Down Expand Up @@ -246,4 +252,20 @@ mod tests {
assert_eq!(1, status.models.yukarin_sa.len());
assert_eq!(1, status.models.decode.len());
}

#[rstest]
fn status_is_model_loaded_works() {
let mut status = Status::new(false, 0);
let model_index = 0;
assert!(
!status.is_model_loaded(model_index),
"model should not be loaded"
);
let result = status.load_model(model_index);
assert!(result.is_ok(), "{:?}", result);
assert!(
status.is_model_loaded(model_index),
"model should be loaded"
);
}
}

0 comments on commit 8c2b38e

Please sign in to comment.