diff --git a/aixplain/model_interfaces/__init__.py b/aixplain/model_interfaces/__init__.py index 81ddefc..b23129c 100644 --- a/aixplain/model_interfaces/__init__.py +++ b/aixplain/model_interfaces/__init__.py @@ -1,4 +1,4 @@ -from aixplain.model_interfaces.schemas.function_input import ( +from aixplain.model_interfaces.schemas.function.function_input import ( APIInput, AudioEncoding, AudioConfig, @@ -8,10 +8,16 @@ ClassificationInput, SpeechEnhancementInput, SpeechSynthesisInput, - TextToImageGenerationInput + TextToImageGenerationInput, + TextGenerationInput, + TextSummarizationInput, + SearchInput, + TextReconstructionInput, + FillTextMaskInput, + SubtitleTranslationInput ) -from aixplain.model_interfaces.schemas.function_output import( +from aixplain.model_interfaces.schemas.function.function_output import( APIOutput, WordDetails, TextSegmentDetails, @@ -21,10 +27,16 @@ DiacritizationOutput, ClassificationOutput, SpeechEnhancementOutput, - TextToImageGenerationOutput + TextToImageGenerationOutput, + TextGenerationOutput, + TextSummarizationOutput, + SearchOutput, + TextReconstructionOutput, + FillTextMaskOutput, + SubtitleTranslationOutput ) -from aixplain.model_interfaces.schemas.metric_input import( +from aixplain.model_interfaces.schemas.metric.metric_input import( MetricInput, MetricAggregate, TextGenerationSettings, @@ -38,7 +50,7 @@ NamedEntityRecognitionMetricInput ) -from aixplain.model_interfaces.schemas.metric_output import( +from aixplain.model_interfaces.schemas.metric.metric_output import( MetricOutput, TextGenerationMetricOutput, ReferencelessTextGenerationMetricOutput, @@ -55,7 +67,14 @@ ClassificationModel, SpeechEnhancementModel, SpeechSynthesis, - TextToImageGeneration + TextToImageGeneration, + TextGenerationModel, + TextGenerationChatModel, + TextSummarizationModel, + SearchModel, + TextReconstructionModel, + FillTextMaskModel, + SubtitleTranslationModel ) from aixplain.model_interfaces.interfaces.metric_models import( @@ -74,7 +93,14 @@ 
ClassificationModel, SpeechEnhancementModel, SpeechSynthesis, - TextToImageGeneration + TextToImageGeneration, + TextGenerationModel, + TextGenerationChatModel, + TextSummarizationModel, + SearchModel, + TextReconstructionModel, + FillTextMaskModel, + SubtitleTranslationModel ] function_classes_input = [ @@ -87,7 +113,8 @@ ClassificationInput, SpeechEnhancementInput, SpeechSynthesisInput, - TextToImageGenerationInput + TextToImageGenerationInput, + TextGenerationInput ] metric_classes_input = [ diff --git a/aixplain/model_interfaces/__version__.py b/aixplain/model_interfaces/__version__.py index 0ef9437..9dd3e0a 100644 --- a/aixplain/model_interfaces/__version__.py +++ b/aixplain/model_interfaces/__version__.py @@ -1,7 +1,7 @@ __title__ = "model-interfaces" __description__ = "model-interfaces is the interface to host your models on aiXplain" __url__ = "https://github.com/aixplain/aixplain-models/tree/main/docs" -__version__ = "0.0.1" +__version__ = "0.0.2rc2" __author__ = "Duraikrishna Selvaraju and Michael Lam" __author_email__ = "krishna.durai@aixplain.com" __license__ = "http://www.apache.org/licenses/LICENSE-2.0" diff --git a/aixplain/model_interfaces/interfaces/aixplain_metric.py b/aixplain/model_interfaces/interfaces/aixplain_metric.py index e62e3a9..d838865 100644 --- a/aixplain/model_interfaces/interfaces/aixplain_metric.py +++ b/aixplain/model_interfaces/interfaces/aixplain_metric.py @@ -8,8 +8,8 @@ import time from typing import Dict, List -from aixplain.model_interfaces.schemas.metric_input import MetricInput, MetricAggregate -from aixplain.model_interfaces.schemas.metric_output import MetricOutput +from aixplain.model_interfaces.schemas.metric.metric_input import MetricInput, MetricAggregate +from aixplain.model_interfaces.schemas.metric.metric_output import MetricOutput class MetricType(Enum): SCORE = 1 diff --git a/aixplain/model_interfaces/interfaces/aixplain_model.py b/aixplain/model_interfaces/interfaces/aixplain_model.py index 28420ad..80b0910 
100644 --- a/aixplain/model_interfaces/interfaces/aixplain_model.py +++ b/aixplain/model_interfaces/interfaces/aixplain_model.py @@ -1,7 +1,7 @@ from kserve.model import Model from typing import Dict, List -from aixplain.model_interfaces.schemas.function_input import APIInput -from aixplain.model_interfaces.schemas.function_output import APIOutput +from aixplain.model_interfaces.schemas.function.function_input import APIInput +from aixplain.model_interfaces.schemas.function.function_output import APIOutput class AixplainModel(Model): diff --git a/aixplain/model_interfaces/interfaces/function_models.py b/aixplain/model_interfaces/interfaces/function_models.py index fdb3b14..83e55d4 100644 --- a/aixplain/model_interfaces/interfaces/function_models.py +++ b/aixplain/model_interfaces/interfaces/function_models.py @@ -1,26 +1,42 @@ import tornado.web from http import HTTPStatus -from typing import Dict, List +from typing import Dict, List, Union, Optional, Text +from enum import Enum +from pydantic import BaseModel, validate_call -from aixplain.model_interfaces.schemas.function_input import ( +from aixplain.model_interfaces.schemas.function.function_input import ( TranslationInput, SpeechRecognitionInput, DiacritizationInput, ClassificationInput, SpeechEnhancementInput, SpeechSynthesisInput, - TextToImageGenerationInput + TextToImageGenerationInput, + TextGenerationInput, + TextSummarizationInput, + SearchInput, + TextReconstructionInput, + FillTextMaskInput, + SubtitleTranslationInput ) -from aixplain.model_interfaces.schemas.function_output import ( +from aixplain.model_interfaces.schemas.function.function_output import ( TranslationOutput, SpeechRecognitionOutput, DiacritizationOutput, ClassificationOutput, SpeechEnhancementOutput, SpeechSynthesisOutput, - TextToImageGenerationOutput + TextToImageGenerationOutput, + TextGenerationOutput, + TextSummarizationOutput, + SearchOutput, + TextReconstructionOutput, + FillTextMaskOutput, + SubtitleTranslationOutput ) +from 
aixplain.model_interfaces.schemas.modality.modality_input import TextInput, TextListInput +from aixplain.model_interfaces.schemas.modality.modality_output import TextListOutput from aixplain.model_interfaces.interfaces.aixplain_model import AixplainModel class TranslationModel(AixplainModel): @@ -159,6 +175,7 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di SpeechSynthesisOutput(**speech_synthesis_dict) speech_synthesis_output["instances"][i] = speech_synthesis_dict return speech_synthesis_output + class TextToImageGeneration(AixplainModel): def run_model(self, api_input: Dict[str, List[TextToImageGenerationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextToImageGenerationOutput]]: pass @@ -179,3 +196,176 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di TextToImageGenerationOutput(**text_to_image_generation_dict) text_to_image_generation_output["predictions"][i] = text_to_image_generation_dict return text_to_image_generation_output + +class TextGenerationChatTemplatizeInput(BaseModel): + data: List[Dict] + +class TextGenerationPredictInput(BaseModel): + instances: Union[List[TextGenerationInput], List[TextListInput], List[TextGenerationChatTemplatizeInput]] + function: Optional[Text] = "PREDICT" + +class TextGenerationRunModelOutput(BaseModel): + predictions: List[TextGenerationOutput] + +class TextGenerationTokenizeOutput(BaseModel): + token_counts: List[List[int]] + +class TextGenerationModel(AixplainModel): + @validate_call + def predict(self, request: TextGenerationPredictInput, headers: Dict[str, str] = None) -> Dict: + instances = request.instances + if request.function.upper() == "PREDICT": + predict_output = { + "predictions": self.run_model(instances, headers) + } + return predict_output + elif request.function.upper() == "TOKENIZE": + token_counts_output = { + "token_counts": self.tokenize(instances, headers) + } + return token_counts_output + else: + raise 
ValueError("Invalid function.") + + @validate_call + def run_model(self, api_input: List[TextGenerationInput], headers: Dict[str, str] = None) -> List[TextGenerationOutput]: + raise NotImplementedError + + @validate_call + def tokenize(self, api_input: List[TextListInput], headers: Dict[str, str] = None) -> List[List[int]]: + raise NotImplementedError + +class TextGenerationChatModel(TextGenerationModel): + @validate_call + def run_model(self, api_input: List[TextInput], headers: Dict[str, str] = None) -> List[TextGenerationOutput]: + raise NotImplementedError + + @validate_call + def predict(self, request: TextGenerationPredictInput, headers: Dict[str, str] = None) -> Dict: + instances = request.instances + if request.function.upper() == "PREDICT": + predict_output = { + "predictions": self.run_model(instances, headers) + } + return predict_output + elif request.function.upper() == "TOKENIZE": + token_counts_output = { + "token_counts": self.tokenize(instances, headers) + } + return token_counts_output + elif request.function.upper() == "TEMPLATIZE": + templatize_output = { + "prompts": self.templatize(instances, headers) + } + return templatize_output + else: + raise ValueError("Invalid function.") + + @validate_call + def templatize(self, api_input: List[TextGenerationChatTemplatizeInput], headers: Dict[str, str] = None) -> List[Text]: + pass + +class TextSummarizationModel(AixplainModel): + def run_model(self, api_input: Dict[str, List[TextSummarizationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextSummarizationOutput]]: + pass + + def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict: + instances = request['instances'] + text_summarization_input_list = [] + # Convert JSON serializables into TextSummarizationInputs + for instance in instances: + text_summarization_input = TextSummarizationInput(**instance) + text_summarization_input_list.append(text_summarization_input) + + text_summarization_output = 
self.run_model({"instances": text_summarization_input_list}) + + # Convert JSON serializables into TextSummarizationOutputs + for i in range(len(text_summarization_output["predictions"])): + text_summarization_dict = text_summarization_output["predictions"][i].dict() + TextSummarizationOutput(**text_summarization_dict) + text_summarization_output["predictions"][i] = text_summarization_dict + return text_summarization_output + +class SearchModel(AixplainModel): + def run_model(self, api_input: Dict[str, List[SearchInput]], headers: Dict[str, str] = None) -> Dict[str, List[SearchOutput]]: + pass + + def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict: + instances = request['instances'] + search_input_list = [] + # Convert JSON serializables into SearchInputs + for instance in instances: + search_input = SearchInput(**instance) + search_input_list.append(search_input) + + search_output = self.run_model({"instances": search_input_list}) + + # Convert JSON serializables into SearchOutputs + for i in range(len(search_output["predictions"])): + search_dict = search_output["predictions"][i].dict() + SearchOutput(**search_dict) + search_output["predictions"][i] = search_dict + return search_output + +class TextReconstructionModel(AixplainModel): + def run_model(self, api_input: Dict[str, List[TextReconstructionInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextReconstructionInput]]: + pass + + def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict: + instances = request['instances'] + text_reconstruction_input_list = [] + # Convert JSON serializables into TextReconstructionInputs + for instance in instances: + text_reconstruction_input = TextReconstructionInput(**instance) + text_reconstruction_input_list.append(text_reconstruction_input) + + text_reconstruction_output = self.run_model({"instances": text_reconstruction_input_list}) + + # Convert JSON serializables into TextReconstructionOutputs + for i 
in range(len(text_reconstruction_output["predictions"])): + text_reconstruction_dict = text_reconstruction_output["predictions"][i].dict() + TextReconstructionOutput(**text_reconstruction_dict) + text_reconstruction_output["predictions"][i] = text_reconstruction_dict + return text_reconstruction_output + +class FillTextMaskModel(AixplainModel): + def run_model(self, api_input: Dict[str, List[FillTextMaskInput]], headers: Dict[str, str] = None) -> Dict[str, List[FillTextMaskOutput]]: + pass + + def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict: + instances = request['instances'] + fill_text_mask_input_list = [] + # Convert JSON serializables into FillTextMaskInputs + for instance in instances: + fill_text_mask_input = FillTextMaskInput(**instance) + fill_text_mask_input_list.append(fill_text_mask_input) + + fill_text_mask_output = self.run_model({"instances": fill_text_mask_input_list}) + + # Convert JSON serializables into FillTextMaskOutputs + for i in range(len(fill_text_mask_output["predictions"])): + fill_text_mask_dict = fill_text_mask_output["predictions"][i].dict() + FillTextMaskOutput(**fill_text_mask_dict) + fill_text_mask_output["predictions"][i] = fill_text_mask_dict + return fill_text_mask_output + +class SubtitleTranslationModel(AixplainModel): + def run_model(self, api_input: Dict[str, List[SubtitleTranslationInput]], headers: Dict[str, str] = None) -> Dict[str, List[SubtitleTranslationOutput]]: + pass + + def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict: + instances = request['instances'] + subtitle_translation_input_list = [] + # Convert JSON serializables into SubtitleTranslationInputs + for instance in instances: + subtitle_translation_input = SubtitleTranslationInput(**instance) + subtitle_translation_input_list.append(subtitle_translation_input) + + subtitle_translation_output = self.run_model({"instances": subtitle_translation_input_list}) + + # Convert JSON serializables into 
SubtitleTranslationOutput + for i in range(len(subtitle_translation_output["predictions"])): + subtitle_translation_dict = subtitle_translation_output["predictions"][i].dict() + SubtitleTranslationOutput(**subtitle_translation_dict) + subtitle_translation_output["predictions"][i] = subtitle_translation_dict + return subtitle_translation_output \ No newline at end of file diff --git a/aixplain/model_interfaces/interfaces/metric_models.py b/aixplain/model_interfaces/interfaces/metric_models.py index a34a9e0..9e6d5ec 100644 --- a/aixplain/model_interfaces/interfaces/metric_models.py +++ b/aixplain/model_interfaces/interfaces/metric_models.py @@ -5,7 +5,7 @@ import tornado.web from fastapi.responses import JSONResponse -from aixplain.model_interfaces.schemas.metric_input import ( +from aixplain.model_interfaces.schemas.metric.metric_input import ( TextGenerationMetricInput, ReferencelessTextGenerationMetricInput, ClassificationMetricInput, @@ -14,7 +14,7 @@ NamedEntityRecognitionMetricInput, MetricAggregate, ) -from aixplain.model_interfaces.schemas.metric_output import ( +from aixplain.model_interfaces.schemas.metric.metric_output import ( TextGenerationMetricOutput, ReferencelessTextGenerationMetricOutput, ClassificationMetricOutput, diff --git a/aixplain/model_interfaces/schemas/api/basic_api_input.py b/aixplain/model_interfaces/schemas/api/basic_api_input.py new file mode 100644 index 0000000..076d38a --- /dev/null +++ b/aixplain/model_interfaces/schemas/api/basic_api_input.py @@ -0,0 +1,36 @@ +""" +Basic class for API inputs +""" +from pydantic import BaseModel +from typing import Any, Optional + +class APIInput(BaseModel): + """The standardized schema of the aiXplain's API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. + :type supplier: + str + :param function: + The aixplain function name for the model. 
+ :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :type version: + str + :param language: + The language the model processes (if relevant). Optional. + :type language: + str + """ + data: Any + supplier: Optional[str] = "" + function: Optional[str] = "" + version: Optional[str] = "" + language: Optional[str] = "" \ No newline at end of file diff --git a/aixplain/model_interfaces/schemas/api/basic_api_output.py b/aixplain/model_interfaces/schemas/api/basic_api_output.py new file mode 100644 index 0000000..94619b8 --- /dev/null +++ b/aixplain/model_interfaces/schemas/api/basic_api_output.py @@ -0,0 +1,20 @@ +""" +Basic class for API outputs +""" +from pydantic import BaseModel +from typing import Any, Optional, Union, List, Dict + +class APIOutput(BaseModel): + """The standardized schema of the aiXplain's API Output. + + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the output data. Optional. 
+ :type details: + List[str] or Dict[str, str] + """ + data: Any + details: Optional[Union[List[str], Dict[str, str]]] = [] \ No newline at end of file diff --git a/aixplain/model_interfaces/schemas/function_input.py b/aixplain/model_interfaces/schemas/function/function_input.py similarity index 60% rename from aixplain/model_interfaces/schemas/function_input.py rename to aixplain/model_interfaces/schemas/function/function_input.py index 428bdf1..fa2c5e9 100644 --- a/aixplain/model_interfaces/schemas/function_input.py +++ b/aixplain/model_interfaces/schemas/function/function_input.py @@ -1,42 +1,13 @@ from enum import Enum from http import HTTPStatus -from typing import Optional, Any +from typing import Optional, Any, List from pydantic import BaseModel, validator import tornado.web from aixplain.model_interfaces.utils import serialize - -class APIInput(BaseModel): - """The standardized schema of the aiXplain's API input. - - :param data: - Input data to the model. - :type data: - Any - :param supplier: - Supplier name. - :type supplier: - str - :param function: - The aixplain function name for the model. - :type function: - str - :param version: - The version number of the model if the supplier has multiple - models with the same function. Optional. - :type version: - str - :param language: - The language the model processes (if relevant). Optional. - :type language: - str - """ - data: Any - supplier: Optional[str] = "" - function: Optional[str] = "" - version: Optional[str] = "" - language: Optional[str] = "" +from aixplain.model_interfaces.schemas.api.basic_api_input import APIInput +from aixplain.model_interfaces.schemas.modality.modality_input import TextInput class AudioEncoding(Enum): """ @@ -62,56 +33,6 @@ class AudioConfig(BaseModel): audio_encoding: AudioEncoding sampling_rate: Optional[int] -class TranslationInputSchema(APIInput): - """The standardized schema of the aiXplain's Translation API input. - - :param data: - Input data to the model. 
- :type data: - Any - :param supplier: - Supplier name. - :type supplier: - str - :param function: - The aixplain function name for the model. - :type function: - str - :param version: - The version number of the model if the supplier has multiple - models with the same function. Optional. - :type version: - str - :param source_language: - The source language the model processes for translation. - :type source_language: - str - :param source_dialect: - The source dialect the model processes (if specified) for translation. - Optional. - :type source_dialect: - str - :param target_language: - The target language the model processes for translation. - :type target_language: - str - """ - source_language: str - source_dialect: Optional[str] = "" - target_language: str - target_dialect: Optional[str] = "" - -class TranslationInput(TranslationInputSchema): - def __init__(self, **input): - super().__init__(**input) - try: - super().__init__(**input) - except ValueError: - raise tornado.web.HTTPError( - status_code=HTTPStatus.BAD_REQUEST, - reason="Incorrect types passed into TranslationInput." - ) - class SpeechRecognitionInputSchema(APIInput): """The standardized schema of the aiXplain's Speech Recognition API input. @@ -170,50 +91,6 @@ def __init__(self, **input): reason="Incorrect types passed into SpeechRecognitionInput." ) -class DiacritizationInputSchema(APIInput): - """The standardized schema of the aiXplain's diacritization API input. - - :param data: - Input data to the model. - :type data: - Any - :param supplier: - Supplier name. - :type supplier: - str - :param function: - The aixplain function name for the model. - :type function: - str - :param version: - The version number of the model if the supplier has multiple - models with the same function. Optional. - :type version: - str - :param language: - The source language the model processes for diarization. 
- :type language: - str - :param dialect: - The source dialect the model processes (if specified) for diarization. - Optional. - :type dialect: - str - """ - language: str - dialect: Optional[str] = "" - -class DiacritizationInput(DiacritizationInputSchema): - def __init__(self, **input): - super().__init__(**input) - try: - super().__init__(**input) - except ValueError: - raise tornado.web.HTTPError( - status_code=HTTPStatus.BAD_REQUEST, - reason="Incorrect types passed into DiacritizationInput." - ) - class ClassificationInputSchema(APIInput): """The standardized schema of the aiXplain's classification API input. @@ -344,7 +221,7 @@ class SpeechSynthesisInputSchema(BaseModel): audio: str = "" text: str text_language: str = "en" - audio_config: AudioConfig = AudioConfig(audio_encoding = AudioEncoding.WAV) + audio_config: AudioConfig class SpeechSynthesisInput(SpeechSynthesisInputSchema): def __init__(self, **input): @@ -384,4 +261,352 @@ def __init__(self, **input): raise tornado.web.HTTPError( status_code=HTTPStatus.BAD_REQUEST, reason="Incorrect type passed into TextToImageGenerationInput." + ) + +class TextGenerationInputSchema(TextInput): + """The standardized schema of aiXplains text generation API Input + + + """ + temperature: Optional[float] = 1.0 + max_new_tokens: Optional[int] = 200 + top_p: Optional[float] = 0.8 + top_k: Optional[int] = 40 + num_return_sequences: Optional[int] = 1 + script: Optional[str] = "" + +class TextGenerationInput(TextGenerationInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect type passed into TextGenerationInput." + ) + + +class TranslationInputSchema(TextInput): + """The standardized schema of the aiXplain's Translation API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. 
+ :type supplier: + str + :param function: + The aixplain function name for the model. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :type version: + str + :param source_language: + The source language the model processes for translation. + :type source_language: + str + :param source_dialect: + The source dialect the model processes (if specified) for translation. + Optional. + :type source_dialect: + str + :param target_language: + The target language the model processes for translation. + :type target_language: + str + """ + source_language: str + source_dialect: Optional[str] = "" + target_language: str + target_dialect: Optional[str] = "" + +class TranslationInput(TranslationInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TranslationInput." + ) + +class TextSummarizationInputSchema(TextInput): + """The standardized schema of the aiXplain's text summarization API input. + + :param data: + Input data to the model. + :type data: + str + :param supplier: + Supplier name. Optional. + :type supplier: + str + :param function: + The aixplain function name for the model. Optional. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :param language: + The model's language. Optional. + :type language: + str + :param script: + # TODO What is this? + :type script: + str + :param dialect: + The language's dialect. Optional. 
+ :type dialect: + str + """ + language: Optional[str] = "" + script: Optional[str] = "" + dialect: Optional[str] = "" + +class TextSummarizationInput(TextSummarizationInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TextSummarizationInput." + ) + +class SearchInputSchema(TextInput): + """The standardized schema of the aiXplain's text summarization API input. + + :param data: + Input data to the model. + :type data: + str + :param supplier: + Supplier name. Optional. + :type supplier: + str + :param function: + The aixplain function name for the model. Optional. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :param language: + The model's language. Optional. + :type language: + str + :param script: + # TODO What is this? + :type script: + str + :param supplier_model_id: + The model ID from the supplier. Optional. + :type supplier_model_id: + str + """ + script: Optional[str] = "" + supplier_model_id: Optional[str] = "" + +class SearchInput(SearchInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into SearchInput." + ) + +class DiacritizationInputSchema(TextInput): + """The standardized schema of the aiXplain's diacritization API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. + :type supplier: + str + :param function: + The aixplain function name for the model. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. 
+ :type version: + str + :param language: + The source language the model processes for diarization. + :type language: + str + :param dialect: + The source dialect the model processes (if specified) for diarization. + Optional. + :type dialect: + str + """ + dialect: Optional[str] = "" + +class DiacritizationInput(DiacritizationInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into DiacritizationInput." + ) + +class TextReconstructionInputSchema(TextInput): + """The standardized schema of the aiXplain's text reconstruction API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. + :type supplier: + str + :param function: + The aixplain function name for the model. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :type version: + str + :param language: + The source language the model processes for diarization. + :type language: + str + """ + pass + +class TextReconstructionInput(TextReconstructionInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TextReconstructionInput." + ) + +class FillTextMaskInputSchema(TextInput): + """The standardized schema of the aiXplain's fill-text-mask API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. Optional. + :type supplier: + str + :param function: + The aixplain function name for the model. Optional. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. 
+ :type version: + str + :param language: + The source language the model processes for diarization. + :type language: + str + :param dialect: + The source dialect the model processes (if specified) for diarization. + Optional. + :type dialect: + str + :param script: + # TODO What is this? Optional. + :type script: + str + """ + language: str + dialect: Optional[str] + script: Optional[str] + +class FillTextMaskInput(FillTextMaskInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into FillTextMaskInput." + ) + +class SubtitleTranslationInputSchema(TextInput): + """The standardized schema of the aiXplain's subtitle translation API input. + + :param data: + Input data to the model. + :type data: + Any + :param supplier: + Supplier name. Optional. + :type supplier: + str + :param function: + The aixplain function name for the model. Optional. + :type function: + str + :param version: + The version number of the model if the supplier has multiple + models with the same function. Optional. + :type version: + str + :param source_language: + The subtitle's source language. + :type source_language: + str + :param dialect_in: + The dialect of the source language. Optional. + :type dialect_in: + str + :param target_supplier: + TODO What is this? + :type target_supplier: + str + :param target_languages: + Languages to which to translate the subtitle. 
+ :type target_languages: + List[str] + """ + source_language: str + dialect_in: Optional[str] + target_supplier: Optional[str] + target_languages: Optional[List[str]] + +class SubtitleTranslationInput(SubtitleTranslationInputSchema): + def __init__(self, **input): + super().__init__(**input) + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into SubtitleTranslationInput." ) \ No newline at end of file diff --git a/aixplain/model_interfaces/schemas/function_output.py b/aixplain/model_interfaces/schemas/function/function_output.py similarity index 56% rename from aixplain/model_interfaces/schemas/function_output.py rename to aixplain/model_interfaces/schemas/function/function_output.py index c4f35fe..af03289 100644 --- a/aixplain/model_interfaces/schemas/function_output.py +++ b/aixplain/model_interfaces/schemas/function/function_output.py @@ -3,23 +3,10 @@ import tornado.web from http import HTTPStatus -from aixplain.model_interfaces.schemas.function_input import AudioConfig, AudioEncoding +from aixplain.model_interfaces.schemas.function.function_input import AudioConfig, AudioEncoding from aixplain.model_interfaces.utils import serialize - -class APIOutput(BaseModel): - """The standardized schema of the aiXplain's API Output. - - :param data: - Processed output data from supplier model. - :type data: - Any - :param details: - Details of the output data. Optional. 
- :type details: - List[str] or Dict[str, str] - """ - data: Any - details: Optional[Union[List[str], Dict[str, str]]] = [] +from aixplain.model_interfaces.schemas.api.basic_api_output import APIOutput +from aixplain.model_interfaces.schemas.modality.modality_output import TextOutput class WordDetails(BaseModel): """The standardized schema of the aiXplain's representation of word @@ -127,29 +114,6 @@ def __init__(self, **input): reason="Incorrect types passed into SpeechRecognitionOutput" ) -class DiacritizationOutputSchema(APIOutput): - """The standardized schema of the aiXplain's Diacritization Output. - :param data: - Processed output data from supplier model. - :type data: - Any - :param details: - Details of the text segments generated. - :type details: - TextSegmentDetails - """ - details: TextSegmentDetails - -class DiacritizationOutput(DiacritizationOutputSchema): - def __init__(self, **input): - try: - super().__init__(**input) - except ValueError: - raise tornado.web.HTTPError( - status_code=HTTPStatus.BAD_REQUEST, - reason="Incorrect types passed into DiacritizationOutput" - ) - class ClassificationOutput(APIOutput): """The standardized schema of the aiXplain's Classification Output. :param predicted_labels: @@ -229,4 +193,182 @@ def __init__(self, **input): raise tornado.web.HTTPError( status_code=HTTPStatus.BAD_REQUEST, reason="Incorrect types passed into TextToImageGenerationOutput" - ) \ No newline at end of file + ) + +class TextGenerationOutputSchema(TextOutput): + """The standardized schema of the aiXplain's text generation output. 
+ """ + details: Optional[Any] = "" + +class TextGenerationOutput(TextGenerationOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TextGenerationOutput" + ) + +class TranslationOutputSchema(TextOutput): + """The standardized schema of the aiXplain's Translation Output. + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the text segments generated. + :type details: + TextSegmentDetails + """ + details: TextSegmentDetails + +class TranslationOutput(TranslationOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TranslationOutput" + ) + +class TextSummarizationOutputSchema(TextOutput): + """The standardized schema of the aiXplain's Translation Output. + + :param data: + Processed output data from supplier model. + :type data: + str + :param details: + Details of the summary generated. + :type details: + Any. Optional. + """ + details: Optional[Any] + +class TextSummarizationOutput(TextSummarizationOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TextSummarizationOutput" + ) + +class SearchOutputSchema(TextOutput): + """The standardized schema of the aiXplain's search output. + + :param data: + Processed output data from supplier model. + :type data: + str + :param details: + Details of the summary generated. + :type details: + Any. Optional. 
+ """ + details: Optional[Any] + +class SearchOutput(SearchOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into SearchOutput" + ) + +class DiacritizationOutputSchema(TextOutput): + """The standardized schema of the aiXplain's Diacritization Output. + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the text segments generated. + :type details: + TextSegmentDetails + """ + details: TextSegmentDetails + +class DiacritizationOutput(DiacritizationOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into DiacritizationOutput" + ) + +class TextReconstructionOutputSchema(TextOutput): + """The standardized schema of the aiXplain's text reconstruction output. + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the text segments generated. + :type details: + TextSegmentDetails + """ + details: Optional[TextSegmentDetails] + +class TextReconstructionOutput(TextReconstructionOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into TextReconstructionOutput" + ) + +class FillTextMaskOutputSchema(TextOutput): + """The standardized schema of the aiXplain's fill-text-mask output. + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the text segments generated. 
+ :type details: + TextSegmentDetails + """ + details: Optional[TextSegmentDetails] + +class FillTextMaskOutput(FillTextMaskOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into FillTextMaskOutput" + ) + +class SubtitleTranslationOutputSchema(TextOutput): + """The standardized schema of the aiXplain's subtitle translation output. + :param data: + Processed output data from supplier model. + :type data: + Any + :param details: + Details of the text segments generated. + :type details: + TextSegmentDetails + """ + details: Optional[TextSegmentDetails] + +class SubtitleTranslationOutput(SubtitleTranslationOutputSchema): + def __init__(self, **input): + try: + super().__init__(**input) + except ValueError: + raise tornado.web.HTTPError( + status_code=HTTPStatus.BAD_REQUEST, + reason="Incorrect types passed into SubtitleTranslationOutput" + ) diff --git a/aixplain/model_interfaces/schemas/metric_input.py b/aixplain/model_interfaces/schemas/metric/metric_input.py similarity index 100% rename from aixplain/model_interfaces/schemas/metric_input.py rename to aixplain/model_interfaces/schemas/metric/metric_input.py diff --git a/aixplain/model_interfaces/schemas/metric_output.py b/aixplain/model_interfaces/schemas/metric/metric_output.py similarity index 98% rename from aixplain/model_interfaces/schemas/metric_output.py rename to aixplain/model_interfaces/schemas/metric/metric_output.py index a237026..5e33acf 100644 --- a/aixplain/model_interfaces/schemas/metric_output.py +++ b/aixplain/model_interfaces/schemas/metric/metric_output.py @@ -3,7 +3,7 @@ import tornado.web from http import HTTPStatus -from aixplain.model_interfaces.schemas.metric_input import MetricAggregate +from aixplain.model_interfaces.schemas.metric.metric_input import MetricAggregate class MetricOutput(BaseModel): """The standardized schema of the 
aiXplain's Metric API Output.
diff --git a/aixplain/model_interfaces/schemas/modality/modality_input.py b/aixplain/model_interfaces/schemas/modality/modality_input.py
new file mode 100644
index 0000000..82eac4e
--- /dev/null
+++ b/aixplain/model_interfaces/schemas/modality/modality_input.py
@@ -0,0 +1,27 @@
+"""
+Modality input classes for modality-based model classification
+"""
+from http import HTTPStatus
+from typing import Optional, Any, List
+
+from aixplain.model_interfaces.schemas.api.basic_api_input import APIInput
+
+class TextInput(APIInput):
+    """The standardized schema of the aiXplain's text API inputs.
+
+    :param data:
+        Input data to the model.
+    :type data:
+        str
+    """
+    data: str
+
+class TextListInput(APIInput):
+    """The standardized schema of the aiXplain's text list API inputs.
+
+    :param data:
+        Input data to the model.
+    :type data:
+        List[str]
+    """
+    data: List[str]
\ No newline at end of file
diff --git a/aixplain/model_interfaces/schemas/modality/modality_output.py b/aixplain/model_interfaces/schemas/modality/modality_output.py
new file mode 100644
index 0000000..39a2fe3
--- /dev/null
+++ b/aixplain/model_interfaces/schemas/modality/modality_output.py
@@ -0,0 +1,27 @@
+"""
+Modality output classes for modality-based model classification
+"""
+from http import HTTPStatus
+from typing import Optional, Any, List
+
+from aixplain.model_interfaces.schemas.api.basic_api_output import APIOutput
+
+class TextOutput(APIOutput):
+    """The standardized schema of the aiXplain's text API outputs.
+
+    :param data:
+        Output data from the model.
+    :type data:
+        str
+    """
+    data: str
+
+class TextListOutput(APIOutput):
+    """The standardized schema of the aiXplain's text list API outputs.
+
+    :param data:
+        Output data from the model.
+ :type data: + List[str] + """ + data: List[str] \ No newline at end of file diff --git a/docs/user/samples/speech-enhancement/src/model.py b/docs/user/samples/speech-enhancement/src/model.py index 014fef3..804aa0a 100644 --- a/docs/user/samples/speech-enhancement/src/model.py +++ b/docs/user/samples/speech-enhancement/src/model.py @@ -3,8 +3,8 @@ from aixplain.model_interfaces.interfaces.aixplain_model_server import AixplainModelServer from aixplain.model_interfaces.interfaces.asset_resolver import AssetResolver -from aixplain.model_interfaces.schemas.function_input import AudioEncoding, SpeechEnhancementInput -from aixplain.model_interfaces.schemas.function_output import SpeechEnhancementOutput +from aixplain.model_interfaces.schemas.function.function_input import AudioEncoding, SpeechEnhancementInput +from aixplain.model_interfaces.schemas.function.function_output import SpeechEnhancementOutput from aixplain.model_interfaces.interfaces.function_models import SpeechEnhancementModel from aixplain.model_interfaces.utils import serialize diff --git a/docs/user/samples/speech-enhancement/src/test_dtln_speech_enhancement.py b/docs/user/samples/speech-enhancement/src/test_dtln_speech_enhancement.py index 31fcb61..f302fa3 100644 --- a/docs/user/samples/speech-enhancement/src/test_dtln_speech_enhancement.py +++ b/docs/user/samples/speech-enhancement/src/test_dtln_speech_enhancement.py @@ -1,5 +1,5 @@ from aixplain.model_interfaces.interfaces.asset_resolver import AssetResolver -from aixplain.model_interfaces.schemas.function_input import AudioEncoding +from aixplain.model_interfaces.schemas.function.function_input import AudioEncoding from aixplain.model_interfaces.utils.serialize import ( audio_file_handle, diff --git a/docs/user/samples/translation/src/model.py b/docs/user/samples/translation/src/model.py index cf75717..596169a 100644 --- a/docs/user/samples/translation/src/model.py +++ b/docs/user/samples/translation/src/model.py @@ -5,8 +5,8 @@ from 
aixplain.model_interfaces.interfaces.aixplain_model_server import AixplainModelServer from aixplain.model_interfaces.interfaces.asset_resolver import AssetResolver -from aixplain.model_interfaces.schemas.function_input import TranslationInput -from aixplain.model_interfaces.schemas.function_output import TextSegmentDetails, TranslationOutput +from aixplain.model_interfaces.schemas.function.function_input import TranslationInput +from aixplain.model_interfaces.schemas.function.function_output import TextSegmentDetails, TranslationOutput from aixplain.model_interfaces.interfaces.function_models import TranslationModel MODEL_NOT_FOUND_ERROR = """ diff --git a/pyproject.toml b/pyproject.toml index cdd0d1e..d375397 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,14 @@ includes = ["aixplain"] [project] name = "model-interfaces" -version = "0.0.1" +version = "0.0.2rc2" description = "A package specifying the model interfaces supported by aiXplain" license = { text = "Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0" } dependencies = [ "kserve>=0.10.0", "multiprocess==0.70.14", "protobuf>=3.19.4", + "pyarrow==15.0.0", "pydantic>=1.9.1", "pydub>=0.25.1", "requests>=2.28.1", diff --git a/tests/unit_tests/metrics/audio_generation/test_audio_generation_metrics.py b/tests/unit_tests/metrics/audio_generation/test_audio_generation_metrics.py index 7af5108..bfcb030 100644 --- a/tests/unit_tests/metrics/audio_generation/test_audio_generation_metrics.py +++ b/tests/unit_tests/metrics/audio_generation/test_audio_generation_metrics.py @@ -5,7 +5,7 @@ import tempfile from aixplain.model_interfaces.interfaces.metric_models import AudioGenerationMetric -from aixplain.model_interfaces.schemas.metric_input import AudioGenerationMetricInput +from aixplain.model_interfaces.schemas.metric.metric_input import AudioGenerationMetricInput INPUTS_PATH = "tests/unit_tests/metrics/audio_generation/inputs.json" OUTPUTS_PATH = 
"tests/unit_tests/metrics/audio_generation/outputs.json" diff --git a/tests/unit_tests/metrics/referenceless_audio_generation/test_referenceless_audio_generation_metrics.py b/tests/unit_tests/metrics/referenceless_audio_generation/test_referenceless_audio_generation_metrics.py index 36df51e..37b2f11 100644 --- a/tests/unit_tests/metrics/referenceless_audio_generation/test_referenceless_audio_generation_metrics.py +++ b/tests/unit_tests/metrics/referenceless_audio_generation/test_referenceless_audio_generation_metrics.py @@ -7,7 +7,7 @@ from fastapi import status import os from aixplain.model_interfaces.interfaces.metric_models import ReferencelessAudioGenerationMetric -from aixplain.model_interfaces.schemas.metric_input import ReferencelessAudioGenerationMetricInput +from aixplain.model_interfaces.schemas.metric.metric_input import ReferencelessAudioGenerationMetricInput INPUTS_PATH = "tests/unit_tests/metrics/referenceless_audio_generation/inputs.json" OUTPUTS_PATH = "tests/unit_tests/metrics/referenceless_audio_generation/outputs.json" diff --git a/tests/unit_tests/metrics/referenceless_text_generation/test_referenceless_text_generation_metrics.py b/tests/unit_tests/metrics/referenceless_text_generation/test_referenceless_text_generation_metrics.py index d8a8e90..166f696 100644 --- a/tests/unit_tests/metrics/referenceless_text_generation/test_referenceless_text_generation_metrics.py +++ b/tests/unit_tests/metrics/referenceless_text_generation/test_referenceless_text_generation_metrics.py @@ -8,7 +8,7 @@ from aixplain.model_interfaces.interfaces.metric_models import ( ReferencelessTextGenerationMetric ) -from aixplain.model_interfaces.schemas.metric_input import ReferencelessTextGenerationMetricInput +from aixplain.model_interfaces.schemas.metric.metric_input import ReferencelessTextGenerationMetricInput import os INPUTS_PATH="tests/unit_tests/metrics/referenceless_text_generation/inputs.json" diff --git 
a/tests/unit_tests/metrics/text_generation/test_text_generation_metrics.py b/tests/unit_tests/metrics/text_generation/test_text_generation_metrics.py index 6001b7c..bb30bfe 100644 --- a/tests/unit_tests/metrics/text_generation/test_text_generation_metrics.py +++ b/tests/unit_tests/metrics/text_generation/test_text_generation_metrics.py @@ -8,7 +8,7 @@ from aixplain.model_interfaces.interfaces.metric_models import ( TextGenerationMetric ) -from aixplain.model_interfaces.schemas.metric_input import TextGenerationMetricInput +from aixplain.model_interfaces.schemas.metric.metric_input import TextGenerationMetricInput INPUTS_PATH="tests/unit_tests/metrics/text_generation/inputs.json" OUTPUTS_PATH="tests/unit_tests/metrics/text_generation/outputs.json" diff --git a/tests/unit_tests/models/test_mock_classification.py b/tests/unit_tests/models/test_mock_classification.py index 1d841d2..3367108 100644 --- a/tests/unit_tests/models/test_mock_classification.py +++ b/tests/unit_tests/models/test_mock_classification.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from aixplain.model_interfaces.schemas.function_input import ClassificationInput -from aixplain.model_interfaces.schemas.function_output import Label, ClassificationOutput +from aixplain.model_interfaces.schemas.function.function_input import ClassificationInput +from aixplain.model_interfaces.schemas.function.function_output import Label, ClassificationOutput from aixplain.model_interfaces.interfaces.function_models import ClassificationModel from typing import Dict, List @@ -31,7 +31,7 @@ def test_predict(self): assert output_dict["predicted_labels"][0]["confidence"] == 0.7 class MockModel(ClassificationModel): - def run_model(self, api_input: Dict[str, List[ClassificationInput]]) -> Dict[str, List[ClassificationOutput]]: + def run_model(self, api_input: Dict[str, List[ClassificationInput]], headers: Dict[str, str] = None) -> Dict[str, List[ClassificationOutput]]: instances = api_input["instances"] 
predictions_list = [] # There's only 1 instance in this case. diff --git a/tests/unit_tests/models/test_mock_diacritization.py b/tests/unit_tests/models/test_mock_diacritization.py index a20494f..c9986bf 100644 --- a/tests/unit_tests/models/test_mock_diacritization.py +++ b/tests/unit_tests/models/test_mock_diacritization.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from aixplain.model_interfaces.schemas.function_input import DiacritizationInput -from aixplain.model_interfaces.schemas.function_output import TextSegmentDetails, DiacritizationOutput +from aixplain.model_interfaces.schemas.function.function_input import DiacritizationInput +from aixplain.model_interfaces.schemas.function.function_output import TextSegmentDetails, DiacritizationOutput from aixplain.model_interfaces.interfaces.function_models import DiacritizationModel from typing import Dict, List @@ -31,7 +31,7 @@ def test_predict(self): assert output_dict["details"]["confidence"] == 0.7 class MockModel(DiacritizationModel): - def run_model(self, api_input: Dict[str, List[DiacritizationInput]]) -> Dict[str, List[DiacritizationOutput]]: + def run_model(self, api_input: Dict[str, List[DiacritizationInput]], headers: Dict[str, str] = None) -> Dict[str, List[DiacritizationOutput]]: instances = api_input["instances"] predictions_list = [] # There's only 1 instance in this case. 
diff --git a/tests/unit_tests/models/test_mock_fill_text_mask.py b/tests/unit_tests/models/test_mock_fill_text_mask.py new file mode 100644 index 0000000..bf77535 --- /dev/null +++ b/tests/unit_tests/models/test_mock_fill_text_mask.py @@ -0,0 +1,57 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import FillTextMaskInput +from aixplain.model_interfaces.schemas.function.function_output import FillTextMaskOutput +from aixplain.model_interfaces.interfaces.function_models import FillTextMaskModel +from typing import Dict, List + +class TestMockFillTextMask(): + def test_predict(self): + data = "Text to reconstruct." + supplier = "mockGpt" + function = "fill-text-mask" + version = "" + language = "en" + dialect = "American" + script = "mock script" + + + input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "language": language, + "dialect": dialect, + "script": script + } + + predict_input = {"instances": [input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + output_dict = predict_output["predictions"][0] + + assert output_dict["data"] == "We are filling a text mask." + +class MockModel(FillTextMaskModel): + def run_model(self, api_input: Dict[str, List[FillTextMaskInput]], headers: Dict[str, str] = None) -> Dict[str, List[FillTextMaskOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "We are filling a text mask." 
+ result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto FillTextMaskOutputs + data = result + + output_dict = { + "data": data, + } + output = FillTextMaskOutput(**output_dict) + predictions_list.append(output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file diff --git a/tests/unit_tests/models/test_mock_speech_enhancement.py b/tests/unit_tests/models/test_mock_speech_enhancement.py index 71436f4..b383c88 100644 --- a/tests/unit_tests/models/test_mock_speech_enhancement.py +++ b/tests/unit_tests/models/test_mock_speech_enhancement.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from aixplain.model_interfaces.schemas.function_input import SpeechEnhancementInput, AudioEncoding -from aixplain.model_interfaces.schemas.function_output import SpeechEnhancementOutput +from aixplain.model_interfaces.schemas.function.function_input import SpeechEnhancementInput, AudioEncoding +from aixplain.model_interfaces.schemas.function.function_output import SpeechEnhancementOutput from aixplain.model_interfaces.interfaces.function_models import SpeechEnhancementModel from typing import Dict, List @@ -38,7 +38,7 @@ def test_predict(self): assert output_dict["data"] == "VGhpcyBpcyBhbiBhdWRpbyBvdXRwdXQ=" class MockModel(SpeechEnhancementModel): - def run_model(self, api_input: Dict[str, List[SpeechEnhancementInput]]) -> Dict[str, List[SpeechEnhancementOutput]]: + def run_model(self, api_input: Dict[str, List[SpeechEnhancementInput]], headers: Dict[str, str] = None) -> Dict[str, List[SpeechEnhancementOutput]]: instances = api_input["instances"] predictions_list = [] # There's only 1 instance in this case. 
diff --git a/tests/unit_tests/models/test_mock_speech_recognition.py b/tests/unit_tests/models/test_mock_speech_recognition.py index a58f229..5aaa599 100644 --- a/tests/unit_tests/models/test_mock_speech_recognition.py +++ b/tests/unit_tests/models/test_mock_speech_recognition.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from aixplain.model_interfaces.schemas.function_input import SpeechRecognitionInput, AudioEncoding -from aixplain.model_interfaces.schemas.function_output import TextSegmentDetails, SpeechRecognitionOutput +from aixplain.model_interfaces.schemas.function.function_input import SpeechRecognitionInput, AudioEncoding +from aixplain.model_interfaces.schemas.function.function_output import TextSegmentDetails, SpeechRecognitionOutput from aixplain.model_interfaces.interfaces.function_models import SpeechRecognitionModel from typing import Dict, List @@ -40,7 +40,7 @@ def test_predict(self): assert output_dict["details"]["confidence"] == 0.7 class MockModel(SpeechRecognitionModel): - def run_model(self, api_input: Dict[str, List[SpeechRecognitionInput]]) -> Dict[str, List[SpeechRecognitionOutput]]: + def run_model(self, api_input: Dict[str, List[SpeechRecognitionInput]], headers: Dict[str, str] = None) -> Dict[str, List[SpeechRecognitionOutput]]: instances = api_input["instances"] predictions_list = [] # There's only 1 instance in this case. 
diff --git a/tests/unit_tests/models/test_mock_subtitle_translation.py b/tests/unit_tests/models/test_mock_subtitle_translation.py new file mode 100644 index 0000000..2303d56 --- /dev/null +++ b/tests/unit_tests/models/test_mock_subtitle_translation.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import SubtitleTranslationInput +from aixplain.model_interfaces.schemas.function.function_output import SubtitleTranslationOutput +from aixplain.model_interfaces.interfaces.function_models import SubtitleTranslationModel +from typing import Dict, List + +class TestMockSearch(): + def test_predict(self): + data = "Text to be searched." + supplier = "mockGpt" + function = "search" + version = "" + source_language = "en" + dialect_in = "American" + target_supplier = "mock supplier" + target_languages = ["fr", "de"] + + + input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "source_language": source_language, + "dialect_in": dialect_in, + "target_supplier": target_supplier, + "target_languages": target_languages + } + + predict_input = {"instances": [input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + output_dict = predict_output["predictions"][0] + + assert output_dict["data"] == "This is a subtitle translation." + +class MockModel(SubtitleTranslationModel): + def run_model(self, api_input: Dict[str, List[SubtitleTranslationInput]], headers: Dict[str, str] = None) -> Dict[str, List[SubtitleTranslationOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "This is a subtitle translation." 
+ result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto SubtitleTranslationOutputs + data = result + + output_dict = { + "data": data, + } + search_output = SubtitleTranslationOutput(**output_dict) + predictions_list.append(search_output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file diff --git a/tests/unit_tests/models/test_mock_text_generation.py b/tests/unit_tests/models/test_mock_text_generation.py new file mode 100644 index 0000000..b65aaa4 --- /dev/null +++ b/tests/unit_tests/models/test_mock_text_generation.py @@ -0,0 +1,61 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import TextGenerationInput +from aixplain.model_interfaces.schemas.function.function_output import TextGenerationOutput +from aixplain.model_interfaces.interfaces.function_models import TextGenerationModel +from typing import Dict, List + +class TestMockTextGeneration(): + def test_predict(self): + data = "Hello, how are you?" + supplier = "mockGpt" + function = "text-generation" + version = "" + language = "" + temperature = 1.0 + max_new_tokens = 200 + top_p = 0.8 + top_k = 40 + num_return_sequences = 1 + + text_generation_input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "language": language, + "temperature": temperature, + "max_new_tokens": max_new_tokens, + "top_p": top_p, + "top_k": top_k, + "num_return_sequences": num_return_sequences + } + predict_input = {"instances": [text_generation_input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + text_generation_output_dict = predict_output["predictions"][0] + + assert text_generation_output_dict["data"] == "I am a text generation model." 
+ +class MockModel(TextGenerationModel): + def run_model(self, api_input: Dict[str, List[TextGenerationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextGenerationOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "I am a text generation model." + result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto TextGenerationOutput + data = result + + output_dict = { + "data": data, + } + text_generation_output = TextGenerationOutput(**output_dict) + predictions_list.append(text_generation_output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file diff --git a/tests/unit_tests/models/test_mock_text_reconstruction.py b/tests/unit_tests/models/test_mock_text_reconstruction.py new file mode 100644 index 0000000..536283e --- /dev/null +++ b/tests/unit_tests/models/test_mock_text_reconstruction.py @@ -0,0 +1,53 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import TextReconstructionInput +from aixplain.model_interfaces.schemas.function.function_output import TextReconstructionOutput +from aixplain.model_interfaces.interfaces.function_models import TextReconstructionModel +from typing import Dict, List + +class TestMockTextReconstruction(): + def test_predict(self): + data = "Text to reconstruct." 
+ supplier = "mockGpt" + function = "text-reconstruction" + version = "" + language = "en" + + + input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "language": language + } + + predict_input = {"instances": [input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + output_dict = predict_output["predictions"][0] + + assert output_dict["data"] == "This is a text reconstruction." + +class MockModel(TextReconstructionModel): + def run_model(self, api_input: Dict[str, List[TextReconstructionInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextReconstructionOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "This is a text reconstruction." + result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto TextReconstructionOutputs + data = result + + output_dict = { + "data": data, + } + output = TextReconstructionOutput(**output_dict) + predictions_list.append(output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file diff --git a/tests/unit_tests/models/test_mock_text_summarization b/tests/unit_tests/models/test_mock_text_summarization new file mode 100644 index 0000000..d9e91cf --- /dev/null +++ b/tests/unit_tests/models/test_mock_text_summarization @@ -0,0 +1,57 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import TextSummarizationInput +from aixplain.model_interfaces.schemas.function.function_output import TextSummarizationOutput +from aixplain.model_interfaces.interfaces.function_models import TextSummarizationModel +from typing import Dict, List + +class TestMockTextSummarization(): + def test_predict(self): + data = 
"Text to be summarized." + supplier = "mockGpt" + function = "text-generation" + version = "" + language = "en" + script = "" + dialect = "American" + + + input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "language": language, + "script": script, + "dialect": dialect + } + + predict_input = {"instances": [input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + summarization_output_dict = predict_output["predictions"][0] + + assert summarization_output_dict["data"] == "This is a summary" + +class MockModel(TextSummarizationModel): + def run_model(self, api_input: Dict[str, List[TextSummarizationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextSummarizationOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "This is a summary" + result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto TextSummarizationOutput + data = result + + output_dict = { + "data": data, + } + text_summarization_output = TextSummarizationOutput(**output_dict) + predictions_list.append(text_summarization_output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file diff --git a/tests/unit_tests/models/test_mock_translation.py b/tests/unit_tests/models/test_mock_translation.py index eb35bed..7f60da1 100644 --- a/tests/unit_tests/models/test_mock_translation.py +++ b/tests/unit_tests/models/test_mock_translation.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from aixplain.model_interfaces.schemas.function_input import TranslationInput -from aixplain.model_interfaces.schemas.function_output import TextSegmentDetails, TranslationOutput +from 
aixplain.model_interfaces.schemas.function.function_input import TranslationInput +from aixplain.model_interfaces.schemas.function.function_output import TextSegmentDetails, TranslationOutput from aixplain.model_interfaces.interfaces.function_models import TranslationModel from typing import Dict, List @@ -39,7 +39,7 @@ def test_predict(self): assert translation_output_dict["details"]["confidence"] == 0.7 class MockModel(TranslationModel): - def run_model(self, api_input: Dict[str, List[TranslationInput]]) -> Dict[str, List[TranslationOutput]]: + def run_model(self, api_input: Dict[str, List[TranslationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TranslationOutput]]: instances = api_input["instances"] predictions_list = [] # There's only 1 instance in this case. diff --git a/tests/unit_tests/models/text_mock_search.py b/tests/unit_tests/models/text_mock_search.py new file mode 100644 index 0000000..55a4687 --- /dev/null +++ b/tests/unit_tests/models/text_mock_search.py @@ -0,0 +1,57 @@ +from unittest.mock import Mock +from aixplain.model_interfaces.schemas.function.function_input import SearchInput +from aixplain.model_interfaces.schemas.function.function_output import SearchOutput +from aixplain.model_interfaces.interfaces.function_models import SearchModel +from typing import Dict, List + +class TestMockSearch(): + def test_predict(self): + data = "Text to be searched." + supplier = "mockGpt" + function = "search" + version = "" + language = "en" + script = "" + supplier_model_id = "mockID" + + + input_dict = { + "data": data, + "supplier": supplier, + "function": function, + "version": version, + "language": language, + "script": script, + "supplier_model_id": supplier_model_id + } + + predict_input = {"instances": [input_dict]} + + mock_model = MockModel("Mock") + predict_output = mock_model.predict(predict_input) + output_dict = predict_output["predictions"][0] + + assert output_dict["data"] == "This is a search output." 
+ +class MockModel(SearchModel): + def run_model(self, api_input: Dict[str, List[SearchInput]], headers: Dict[str, str] = None) -> Dict[str, List[SearchOutput]]: + instances = api_input["instances"] + predictions_list = [] + # There's only 1 instance in this case. + for instance in instances: + instance_data = instance.dict() + model_instance = Mock() + model_instance.process_data.return_value = "This is a search output." + result = model_instance.process_data(instance_data["data"]) + model_instance.delete() + + # Map back onto SearchOutputs + data = result + + output_dict = { + "data": data, + } + search_output = SearchOutput(**output_dict) + predictions_list.append(search_output) + predict_output = {"predictions": predictions_list} + return predict_output \ No newline at end of file