Skip to content

Commit

Permalink
新クラス設計API (#370)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryo Yamashita <[email protected]>
Co-authored-by: Nanashi. <[email protected]>
  • Loading branch information
3 people authored May 22, 2023
1 parent eb6768f commit fb24f4f
Show file tree
Hide file tree
Showing 52 changed files with 3,929 additions and 2,756 deletions.
352 changes: 244 additions & 108 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ once_cell = "1.15.0"
process_path = { git = "https://github.com/VOICEVOX/process_path.git", rev = "de226a26e8e18edbdb1d6f986afe37bbbf35fbf4" }
regex = "1.6.0"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.85"
serde_json = { version = "1.0.85", features = ["preserve_order"] }
test_util = { path = "crates/test_util" }
thiserror = "1.0.37"
tracing = { version = "0.1.37", features = ["log"] }
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
voicevox_core = { path = "crates/voicevox_core" }
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "macros", "sync"] }
derive-getters = "0.2.0"

# min-sized-rustを元にrelease buildのサイズが小さくなるようにした
# https://github.com/johnthagen/min-sized-rust
Expand Down
2 changes: 1 addition & 1 deletion crates/download/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ platforms = "3.0.2"
rayon = "1.6.1"
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls", "stream"] }
strum = { version = "0.24.1", features = ["derive"] }
tokio = { version = "1.24.1", features = ["macros", "rt-multi-thread", "sync"] }
tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
url = "2.3.0"
Expand Down
8 changes: 7 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ directml = ["onnxruntime/directml"]
[dependencies]
anyhow.workspace = true
cfg-if = "1.0.0"
derive-getters = "0.2.0"
derive-getters.workspace = true
derive-new = "0.5.9"
easy-ext.workspace = true
fs-err.workspace = true
Expand All @@ -25,10 +25,16 @@ thiserror.workspace = true
tracing.workspace = true
open_jtalk = { git = "https://github.com/VOICEVOX/open_jtalk-rs.git", rev="d766a52bad4ccafe18597e57bd6842f59dca881e" }
regex.workspace = true
async_zip = { version = "0.0.11", features = ["full"] }
futures = "0.3.26"
nanoid = "0.4.0"
tokio.workspace = true

[dev-dependencies]
rstest = "0.15.0"
pretty_assertions = "1.3.0"
flate2 = "1.0.24"
tar = "0.4.38"
heck = "0.4.0"
test_util.workspace = true

Expand Down
49 changes: 49 additions & 0 deletions crates/voicevox_core/src/devices.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use serde::{Deserialize, Serialize};

use super::*;

#[derive(Getters, Debug, Serialize, Deserialize)]
pub struct SupportedDevices {
cpu: bool,
cuda: bool,
dml: bool,
}

impl SupportedDevices {
/// サポートされているデバイス情報を取得する
pub fn get_supported_devices() -> Result<Self> {
let mut cuda_support = false;
let mut dml_support = false;
for provider in onnxruntime::session::get_available_providers()
.map_err(|e| Error::GetSupportedDevices(e.into()))?
.iter()
{
match provider.as_str() {
"CUDAExecutionProvider" => cuda_support = true,
"DmlExecutionProvider" => dml_support = true,
_ => {}
}
}

Ok(SupportedDevices {
cpu: true,
cuda: cuda_support,
dml: dml_support,
})
}

pub fn to_json(&self) -> serde_json::Value {
serde_json::to_value(self).expect("should not fail")
}
}

#[cfg(test)]
mod tests {
use super::*;
#[rstest]
fn supported_devices_get_supported_devices_works() {
let result = SupportedDevices::get_supported_devices();
// 環境によって結果が変わるので、関数呼び出しが成功するかどうかの確認のみ行う
assert!(result.is_ok(), "{result:?}");
}
}
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/engine/full_context_label.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl Utterance {
}

pub fn extract_full_context_label(
open_jtalk: &mut open_jtalk::OpenJtalk,
open_jtalk: &open_jtalk::OpenJtalk,
text: impl AsRef<str>,
) -> Result<Self> {
let labels = open_jtalk.extract_fullcontext(text)?;
Expand Down
97 changes: 62 additions & 35 deletions crates/voicevox_core/src/engine/open_jtalk.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::path::{Path, PathBuf};
use std::{
path::{Path, PathBuf},
sync::Mutex,
};

use ::open_jtalk::*;

use crate::Error;

#[derive(thiserror::Error, Debug)]
pub enum OpenJtalkError {
#[error("open_jtalk load error")]
Expand All @@ -17,55 +22,74 @@ pub enum OpenJtalkError {
pub type Result<T> = std::result::Result<T, OpenJtalkError>;

pub struct OpenJtalk {
resources: Mutex<Resources>,
dict_loaded: bool,
}

struct Resources {
mecab: ManagedResource<Mecab>,
njd: ManagedResource<Njd>,
jpcommon: ManagedResource<JpCommon>,
dict_loaded: bool,
}

#[allow(unsafe_code)]
unsafe impl Send for Resources {}

impl OpenJtalk {
pub fn initialize() -> Self {
pub fn new_without_dic() -> Self {
Self {
mecab: ManagedResource::initialize(),
njd: ManagedResource::initialize(),
jpcommon: ManagedResource::initialize(),
resources: Mutex::new(Resources {
mecab: ManagedResource::initialize(),
njd: ManagedResource::initialize(),
jpcommon: ManagedResource::initialize(),
}),
dict_loaded: false,
}
}

pub fn extract_fullcontext(&mut self, text: impl AsRef<str>) -> Result<Vec<String>> {
let result = self.extract_fullcontext_non_reflesh(text);
self.jpcommon.refresh();
self.njd.refresh();
self.mecab.refresh();
result
pub fn new_with_initialize(
open_jtalk_dict_dir: impl AsRef<Path>,
) -> crate::result::Result<Self> {
let mut s = Self::new_without_dic();
s.load(open_jtalk_dict_dir)
.map_err(|_| Error::NotLoadedOpenjtalkDict)?;
Ok(s)
}

fn extract_fullcontext_non_reflesh(&mut self, text: impl AsRef<str>) -> Result<Vec<String>> {
pub fn extract_fullcontext(&self, text: impl AsRef<str>) -> Result<Vec<String>> {
let Resources {
mecab,
njd,
jpcommon,
} = &mut *self.resources.lock().unwrap();

jpcommon.refresh();
njd.refresh();
mecab.refresh();

let mecab_text =
text2mecab(text.as_ref()).map_err(|e| OpenJtalkError::ExtractFullContext {
text: text.as_ref().into(),
source: Some(e.into()),
})?;
if self.mecab.analysis(mecab_text) {
self.njd.mecab2njd(
self.mecab
if mecab.analysis(mecab_text) {
njd.mecab2njd(
mecab
.get_feature()
.ok_or(OpenJtalkError::ExtractFullContext {
text: text.as_ref().into(),
source: None,
})?,
self.mecab.get_size(),
mecab.get_size(),
);
self.njd.set_pronunciation();
self.njd.set_digit();
self.njd.set_accent_phrase();
self.njd.set_accent_type();
self.njd.set_unvoiced_vowel();
self.njd.set_long_vowel();
self.jpcommon.njd2jpcommon(&self.njd);
self.jpcommon.make_label();
self.jpcommon
njd.set_pronunciation();
njd.set_digit();
njd.set_accent_phrase();
njd.set_accent_type();
njd.set_unvoiced_vowel();
njd.set_long_vowel();
jpcommon.njd2jpcommon(njd);
jpcommon.make_label();
jpcommon
.get_label_feature_to_iter()
.ok_or_else(|| OpenJtalkError::ExtractFullContext {
text: text.as_ref().into(),
Expand All @@ -80,15 +104,20 @@ impl OpenJtalk {
}
}

pub fn load(&mut self, mecab_dict_dir: impl AsRef<Path>) -> Result<()> {
let result = self.mecab.load(mecab_dict_dir.as_ref());
fn load(&mut self, open_jtalk_dict_dir: impl AsRef<Path>) -> Result<()> {
let result = self
.resources
.lock()
.unwrap()
.mecab
.load(open_jtalk_dict_dir.as_ref());
if result {
self.dict_loaded = true;
Ok(())
} else {
self.dict_loaded = false;
Err(OpenJtalkError::Load {
mecab_dict_dir: mecab_dict_dir.as_ref().into(),
mecab_dict_dir: open_jtalk_dict_dir.as_ref().into(),
})
}
}
Expand All @@ -101,7 +130,7 @@ impl OpenJtalk {
#[cfg(test)]
mod tests {
use super::*;
use test_util::OPEN_JTALK_DIC_DIR;
use ::test_util::OPEN_JTALK_DIC_DIR;

use crate::{macros::tests::assert_debug_fmt_eq, *};

Expand Down Expand Up @@ -196,8 +225,7 @@ mod tests {
#[case("",Err(OpenJtalkError::ExtractFullContext{text:"".into(),source:None}))]
#[case("こんにちは、ヒホです。", Ok(testdata_hello_hiho()))]
fn extract_fullcontext_works(#[case] text: &str, #[case] expected: super::Result<Vec<String>>) {
let mut open_jtalk = OpenJtalk::initialize();
open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap();
let open_jtalk = OpenJtalk::new_with_initialize(OPEN_JTALK_DIC_DIR).unwrap();
let result = open_jtalk.extract_fullcontext(text);
assert_debug_fmt_eq!(expected, result);
}
Expand All @@ -208,8 +236,7 @@ mod tests {
#[case] text: &str,
#[case] expected: super::Result<Vec<String>>,
) {
let mut open_jtalk = OpenJtalk::initialize();
open_jtalk.load(OPEN_JTALK_DIC_DIR).unwrap();
let open_jtalk = OpenJtalk::new_with_initialize(OPEN_JTALK_DIC_DIR).unwrap();
for _ in 0..10 {
let result = open_jtalk.extract_fullcontext(text);
assert_debug_fmt_eq!(expected, result);
Expand Down
Loading

0 comments on commit fb24f4f

Please sign in to comment.