Skip to content

Commit

Permalink
Merge branch 'master' into fix/prepost_sil_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
tarepan committed Jun 28, 2024
2 parents bace6b3 + 925c4d5 commit d0877ed
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 125 deletions.
File renamed without changes.
43 changes: 16 additions & 27 deletions test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -48,45 +48,34 @@ def test_tts_engines_get_engine_existing() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine(LATEST_VERSION) で最新版の TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engine3 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")

# Test
with pytest.raises(HTTPException) as _:
tts_engines.get_engine("0.0.3")


def test_tts_engines_has_engine_true() -> None:
"""TTSEngineManager.has_engine() で TTS エンジンが登録されていることを確認できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
tts_engines.register_engine(tts_engine3, "0.1.0")
# Expects
expected_has = True
true_acquired_tts_engine = tts_engine3
# Outputs
has = tts_engines.has_engine("0.0.1")
acquired_tts_engine = tts_engines.get_engine(LATEST_VERSION)

# Test
assert expected_has == has
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_has_engine_false() -> None:
"""TTSEngineManager.has_engine() TTS エンジンが登録されていないことを確認できる。"""
def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
# Expects
expected_has = False
# Outputs
has = tts_engines.has_engine("0.0.3")
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")

# Test
assert expected_has == has
with pytest.raises(HTTPException) as _:
tts_engines.get_engine("0.0.3")
82 changes: 70 additions & 12 deletions test/unit/user_dict/test_user_dict_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""UserDictWord のテスト"""

from typing import TypedDict

import pytest
Expand All @@ -7,7 +9,7 @@
from voicevox_engine.user_dict.model import UserDictWord


class _TestModel(TypedDict):
class UserDictWordInputs(TypedDict):
surface: str
priority: int
part_of_speech: str
Expand All @@ -24,7 +26,8 @@ class _TestModel(TypedDict):
accent_associative_rule: str


def generate_model() -> _TestModel:
def generate_model() -> UserDictWordInputs:
"""テスト用に UserDictWord の要素を生成する。"""
return {
"surface": "テスト",
"priority": 0,
Expand All @@ -44,19 +47,39 @@ def generate_model() -> _TestModel:


def test_valid_word() -> None:
test_value = generate_model()
UserDictWord(**test_value)
"""generate_model 関数は UserDictWord の要素を生成する。"""
# Outputs
args = generate_model()

# Test
UserDictWord(**args)


def test_convert_to_zenkaku() -> None:
"""UserDictWord は surface を全角にする。"""
# Inputs
test_value = generate_model()
test_value["surface"] = "test"
assert UserDictWord(**test_value).surface == "test"
# Expects
true_surface = "test"
# Outputs
surface = UserDictWord(**test_value).surface

# Test
assert surface == true_surface


def test_count_mora() -> None:
"""UserDictWord は mora_count=None を上書きする。"""
# Inputs
test_value = generate_model()
assert UserDictWord(**test_value).mora_count == 3
# Expects
true_mora_count = 3
# Outputs
mora_count = UserDictWord(**test_value).mora_count

# Test
assert mora_count == true_mora_count


def test_count_mora_x() -> None:
Expand All @@ -75,52 +98,87 @@ def test_count_mora_x() -> None:


def test_count_mora_xwa() -> None:
"""「ヮ」を含む発音のモーラ数が適切にカウントされる。"""
# Inputs
test_value = generate_model()
test_value["pronunciation"] = "クヮンセイ"
expected_count = 0
# Expects
true_mora_count = 0
for accent_phrase in parse_kana(
test_value["pronunciation"] + "'",
):
expected_count += len(accent_phrase.moras)
assert UserDictWord(**test_value).mora_count == expected_count
true_mora_count += len(accent_phrase.moras)
# Outputs
mora_rount = UserDictWord(**test_value).mora_count

# Test
assert mora_rount == true_mora_count


def test_invalid_pronunciation_not_katakana() -> None:
"""UserDictWord はカタカナでない pronunciation をエラーとする。"""
# Inputs
test_value = generate_model()
test_value["pronunciation"] = "ぼいぼ"

# Test
with pytest.raises(ValidationError):
UserDictWord(**test_value)


def test_invalid_pronunciation_invalid_sutegana() -> None:
"""UserDictWord は無効な pronunciation をエラーとする。"""
# Inputs
test_value = generate_model()
test_value["pronunciation"] = "アィウェォ"

# Test
with pytest.raises(ValidationError):
UserDictWord(**test_value)


def test_invalid_pronunciation_invalid_xwa() -> None:
"""UserDictWord は無効な pronunciation をエラーとする。"""
# Inputs
test_value = generate_model()
test_value["pronunciation"] = "アヮ"

# Test
with pytest.raises(ValidationError):
UserDictWord(**test_value)


def test_count_mora_voiced_sound() -> None:
"""UserDictWord はモーラ数を正しくカウントして上書きする。"""
# Inputs
test_value = generate_model()
test_value["pronunciation"] = "ボイボ"
assert UserDictWord(**test_value).mora_count == 3
# Expects
true_mora_count = 3
# Outputs
mora_count = UserDictWord(**test_value).mora_count

# Test
assert mora_count == true_mora_count

def test_invalid_accent_type() -> None:

def test_word_accent_type_too_big() -> None:
"""UserDictWord はモーラ数を超えた accent_type をエラーとする。"""
# Inputs
test_value = generate_model()
test_value["accent_type"] = 4

# Test
with pytest.raises(ValidationError):
UserDictWord(**test_value)


def test_invalid_accent_type_2() -> None:
def test_word_accent_type_negative() -> None:
"""UserDictWord は負の accent_type をエラーとする。"""
# Inputs
test_value = generate_model()
test_value["accent_type"] = -1

# Test
with pytest.raises(ValidationError):
UserDictWord(**test_value)
27 changes: 18 additions & 9 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from voicevox_engine.app.routers.tts_pipeline import generate_tts_pipeline_router
from voicevox_engine.app.routers.user_dict import generate_user_dict_router
from voicevox_engine.cancellable_engine import CancellableEngine
from voicevox_engine.core.core_adapter import CoreCharacter
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.engine_manifest import EngineManifest
from voicevox_engine.library.library_manager import LibraryManager
Expand Down Expand Up @@ -66,27 +67,35 @@ def generate_app(

resource_manager = ResourceManager(is_development())
resource_manager.register_dir(character_info_dir)
metas_store = MetasStore(character_info_dir, resource_manager)

app.include_router(
generate_tts_pipeline_router(
tts_engines, core_manager, preset_manager, cancellable_engine
)
core_version_list = core_manager.versions()

def _get_core_characters(version: str | None) -> list[CoreCharacter]:
version = version or core_manager.latest_version()
core = core_manager.get_core(version)
return core.characters

metas_store = MetasStore(
character_info_dir,
_get_core_characters,
resource_manager,
)
app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store))

app.include_router(
generate_preset_router(preset_manager, verify_mutability_allowed)
generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine)
)
app.include_router(generate_morphing_router(tts_engines, metas_store))
app.include_router(
generate_character_router(core_manager, resource_manager, metas_store)
generate_preset_router(preset_manager, verify_mutability_allowed)
)
app.include_router(generate_character_router(resource_manager, metas_store))
if engine_manifest.supported_features.manage_library:
app.include_router(
generate_library_router(library_manager, verify_mutability_allowed)
)
app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed))
app.include_router(
generate_engine_info_router(core_manager, tts_engines, engine_manifest)
generate_engine_info_router(core_version_list, tts_engines, engine_manifest)
)
app.include_router(
generate_setting_router(
Expand Down
21 changes: 5 additions & 16 deletions voicevox_engine/app/routers/character.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from fastapi.responses import FileResponse
from pydantic.json_schema import SkipJsonSchema

from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import Speaker, SpeakerInfo
from voicevox_engine.metas.MetasStore import Character, MetasStore, ResourceFormat
from voicevox_engine.resource_manager import ResourceManager, ResourceManagerError
Expand Down Expand Up @@ -35,19 +34,15 @@ def _characters_to_speakers(characters: list[Character]) -> list[Speaker]:


def generate_character_router(
core_manager: CoreManager,
resource_manager: ResourceManager,
metas_store: MetasStore,
resource_manager: ResourceManager, metas_store: MetasStore
) -> APIRouter:
"""キャラクター情報 API Router を生成する"""
router = APIRouter(tags=["その他"])

@router.get("/speakers")
def speakers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""喋れるキャラクターの情報の一覧を返します。"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
characters = metas_store.talk_characters(core.characters)
characters = metas_store.talk_characters(core_version)
return _characters_to_speakers(characters)

@router.get("/speaker_info")
Expand All @@ -61,22 +56,18 @@ def speaker_info(
UUID で指定された喋れるキャラクターの情報を返します。
画像や音声はresource_formatで指定した形式で返されます。
"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
return metas_store.character_info(
character_uuid=speaker_uuid,
talk_or_sing="talk",
core_characters=core.characters,
core_version=core_version,
resource_baseurl=resource_baseurl,
resource_format=resource_format,
)

@router.get("/singers")
def singers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""歌えるキャラクターの情報の一覧を返します。"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
characters = metas_store.sing_characters(core.characters)
characters = metas_store.sing_characters(core_version)
return _characters_to_speakers(characters)

@router.get("/singer_info")
Expand All @@ -90,12 +81,10 @@ def singer_info(
UUID で指定された歌えるキャラクターの情報を返します。
画像や音声はresource_formatで指定した形式で返されます。
"""
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
return metas_store.character_info(
character_uuid=speaker_uuid,
talk_or_sing="sing",
core_characters=core.characters,
core_version=core_version,
resource_baseurl=resource_baseurl,
resource_format=resource_format,
)
Expand Down
9 changes: 4 additions & 5 deletions voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

from voicevox_engine import __version__
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.engine_manifest import EngineManifest
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


class SupportedDevicesInfo(BaseModel):
Expand All @@ -33,7 +32,7 @@ def generate_from(cls, device_support: DeviceSupport) -> Self:


def generate_engine_info_router(
core_manager: CoreManager,
core_version_list: list[str],
tts_engine_manager: TTSEngineManager,
engine_manifest_data: EngineManifest,
) -> APIRouter:
Expand All @@ -48,14 +47,14 @@ async def version() -> str:
@router.get("/core_versions")
async def core_versions() -> list[str]:
"""利用可能なコアのバージョン一覧を取得します。"""
return core_manager.versions()
return core_version_list

@router.get("/supported_devices")
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
supported_devices = tts_engine_manager.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
Expand Down
Loading

0 comments on commit d0877ed

Please sign in to comment.