Skip to content

Commit

Permalink
Merge pull request #18 from aixplain/modality
Browse files Browse the repository at this point in the history
Text-to-text for text generation
  • Loading branch information
mikelam-us-aixplain authored Jul 19, 2024
2 parents b22a1ec + 4111640 commit 2c975b2
Show file tree
Hide file tree
Showing 33 changed files with 1,253 additions and 214 deletions.
45 changes: 36 additions & 9 deletions aixplain/model_interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aixplain.model_interfaces.schemas.function_input import (
from aixplain.model_interfaces.schemas.function.function_input import (
APIInput,
AudioEncoding,
AudioConfig,
Expand All @@ -8,10 +8,16 @@
ClassificationInput,
SpeechEnhancementInput,
SpeechSynthesisInput,
TextToImageGenerationInput
TextToImageGenerationInput,
TextGenerationInput,
TextSummarizationInput,
SearchInput,
TextReconstructionInput,
FillTextMaskInput,
SubtitleTranslationInput
)

from aixplain.model_interfaces.schemas.function_output import(
from aixplain.model_interfaces.schemas.function.function_output import(
APIOutput,
WordDetails,
TextSegmentDetails,
Expand All @@ -21,10 +27,16 @@
DiacritizationOutput,
ClassificationOutput,
SpeechEnhancementOutput,
TextToImageGenerationOutput
TextToImageGenerationOutput,
TextGenerationOutput,
TextSummarizationOutput,
SearchOutput,
TextReconstructionOutput,
FillTextMaskOutput,
SubtitleTranslationOutput
)

from aixplain.model_interfaces.schemas.metric_input import(
from aixplain.model_interfaces.schemas.metric.metric_input import(
MetricInput,
MetricAggregate,
TextGenerationSettings,
Expand All @@ -38,7 +50,7 @@
NamedEntityRecognitionMetricInput
)

from aixplain.model_interfaces.schemas.metric_output import(
from aixplain.model_interfaces.schemas.metric.metric_output import(
MetricOutput,
TextGenerationMetricOutput,
ReferencelessTextGenerationMetricOutput,
Expand All @@ -55,7 +67,14 @@
ClassificationModel,
SpeechEnhancementModel,
SpeechSynthesis,
TextToImageGeneration
TextToImageGeneration,
TextGenerationModel,
TextGenerationChatModel,
TextSummarizationModel,
SearchModel,
TextReconstructionModel,
FillTextMaskModel,
SubtitleTranslationModel
)

from aixplain.model_interfaces.interfaces.metric_models import(
Expand All @@ -74,7 +93,14 @@
ClassificationModel,
SpeechEnhancementModel,
SpeechSynthesis,
TextToImageGeneration
TextToImageGeneration,
TextGenerationModel,
TextGenerationChatModel,
TextSummarizationModel,
SearchModel,
TextReconstructionModel,
FillTextMaskModel,
SubtitleTranslationModel
]

function_classes_input = [
Expand All @@ -87,7 +113,8 @@
ClassificationInput,
SpeechEnhancementInput,
SpeechSynthesisInput,
TextToImageGenerationInput
TextToImageGenerationInput,
TextGenerationInput
]

metric_classes_input = [
Expand Down
2 changes: 1 addition & 1 deletion aixplain/model_interfaces/__version__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__title__ = "model-interfaces"
__description__ = "model-interfaces is the interface to host your models on aiXplain"
__url__ = "https://github.com/aixplain/aixplain-models/tree/main/docs"
__version__ = "0.0.1"
__version__ = "0.0.2rc2"
__author__ = "Duraikrishna Selvaraju and Michael Lam"
__author_email__ = "[email protected]"
__license__ = "http://www.apache.org/licenses/LICENSE-2.0"
Expand Down
4 changes: 2 additions & 2 deletions aixplain/model_interfaces/interfaces/aixplain_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import time
from typing import Dict, List

from aixplain.model_interfaces.schemas.metric_input import MetricInput, MetricAggregate
from aixplain.model_interfaces.schemas.metric_output import MetricOutput
from aixplain.model_interfaces.schemas.metric.metric_input import MetricInput, MetricAggregate
from aixplain.model_interfaces.schemas.metric.metric_output import MetricOutput

class MetricType(Enum):
SCORE = 1
Expand Down
4 changes: 2 additions & 2 deletions aixplain/model_interfaces/interfaces/aixplain_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from kserve.model import Model
from typing import Dict, List
from aixplain.model_interfaces.schemas.function_input import APIInput
from aixplain.model_interfaces.schemas.function_output import APIOutput
from aixplain.model_interfaces.schemas.function.function_input import APIInput
from aixplain.model_interfaces.schemas.function.function_output import APIOutput

class AixplainModel(Model):

Expand Down
200 changes: 195 additions & 5 deletions aixplain/model_interfaces/interfaces/function_models.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,42 @@
import tornado.web

from http import HTTPStatus
from typing import Dict, List
from typing import Dict, List, Union, Optional, Text
from enum import Enum
from pydantic import BaseModel, validate_call

from aixplain.model_interfaces.schemas.function_input import (
from aixplain.model_interfaces.schemas.function.function_input import (
TranslationInput,
SpeechRecognitionInput,
DiacritizationInput,
ClassificationInput,
SpeechEnhancementInput,
SpeechSynthesisInput,
TextToImageGenerationInput
TextToImageGenerationInput,
TextGenerationInput,
TextSummarizationInput,
SearchInput,
TextReconstructionInput,
FillTextMaskInput,
SubtitleTranslationInput
)
from aixplain.model_interfaces.schemas.function_output import (
from aixplain.model_interfaces.schemas.function.function_output import (
TranslationOutput,
SpeechRecognitionOutput,
DiacritizationOutput,
ClassificationOutput,
SpeechEnhancementOutput,
SpeechSynthesisOutput,
TextToImageGenerationOutput
TextToImageGenerationOutput,
TextGenerationOutput,
TextSummarizationOutput,
SearchOutput,
TextReconstructionOutput,
FillTextMaskOutput,
SubtitleTranslationOutput
)
from aixplain.model_interfaces.schemas.modality.modality_input import TextInput, TextListInput
from aixplain.model_interfaces.schemas.modality.modality_output import TextListOutput
from aixplain.model_interfaces.interfaces.aixplain_model import AixplainModel

class TranslationModel(AixplainModel):
Expand Down Expand Up @@ -159,6 +175,7 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di
SpeechSynthesisOutput(**speech_synthesis_dict)
speech_synthesis_output["instances"][i] = speech_synthesis_dict
return speech_synthesis_output

class TextToImageGeneration(AixplainModel):
def run_model(self, api_input: Dict[str, List[TextToImageGenerationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextToImageGenerationOutput]]:
pass
Expand All @@ -179,3 +196,176 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di
TextToImageGenerationOutput(**text_to_image_generation_dict)
text_to_image_generation_output["predictions"][i] = text_to_image_generation_dict
return text_to_image_generation_output

class TextGenerationChatTemplatizeInput(BaseModel):
data: List[Dict]

class TextGenerationPredictInput(BaseModel):
instances: Union[List[TextGenerationInput], List[TextListInput], List[TextGenerationChatTemplatizeInput]]
function: Optional[Text] = "PREDICT"

class TextGenerationRunModelOutput(BaseModel):
predictions: List[TextGenerationOutput]

class TextGenerationTokenizeOutput(BaseModel):
token_counts: List[List[int]]

class TextGenerationModel(AixplainModel):
@validate_call
def predict(self, request: TextGenerationPredictInput, headers: Dict[str, str] = None) -> Dict:
instances = request.instances
if request.function.upper() == "PREDICT":
predict_output = {
"predictions": self.run_model(instances, headers)
}
return predict_output
elif request.function.upper() == "TOKENIZE":
token_counts_output = {
"token_counts": self.tokenize(instances, headers)
}
return token_counts_output
else:
raise ValueError("Invalid function.")

@validate_call
def run_model(self, api_input: List[TextGenerationInput], headers: Dict[str, str] = None) -> List[TextGenerationOutput]:
raise NotImplementedError

@validate_call
def tokenize(self, api_input: List[TextListInput], headers: Dict[str, str] = None) -> List[List[int]]:
raise NotImplementedError

class TextGenerationChatModel(TextGenerationModel):
@validate_call
def run_model(self, api_input: List[TextInput], headers: Dict[str, str] = None) -> List[TextGenerationOutput]:
raise NotImplementedError

@validate_call
def predict(self, request: TextGenerationPredictInput, headers: Dict[str, str] = None) -> Dict:
instances = request.instances
if request.function.upper() == "PREDICT":
predict_output = {
"predictions": self.run_model(instances, headers)
}
return predict_output
elif request.function.upper() == "TOKENIZE":
token_counts_output = {
"token_counts": self.tokenize(instances, headers)
}
return token_counts_output
elif request.function.upper() == "TEMPLATIZE":
templatize_output = {
"prompts": self.templatize(instances, headers)
}
return templatize_output
else:
raise ValueError("Invalid function.")

@validate_call
def templatize(self, api_input: List[TextGenerationChatTemplatizeInput], headers: Dict[str, str] = None) -> List[Text]:
pass

class TextSummarizationModel(AixplainModel):
def run_model(self, api_input: Dict[str, List[TextSummarizationInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextSummarizationOutput]]:
pass

def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict:
instances = request['instances']
text_summarization_input_list = []
# Convert JSON serializables into TextSummarizationInputs
for instance in instances:
text_summarization_input = TextSummarizationInput(**instance)
text_summarization_input_list.append(text_summarization_input)

text_summarization_output = self.run_model({"instances": text_summarization_input_list})

# Convert JSON serializables into TextSummarizationOutputs
for i in range(len(text_summarization_output["predictions"])):
text_summarization_dict = text_summarization_output["predictions"][i].dict()
TextSummarizationOutput(**text_summarization_dict)
text_summarization_output["predictions"][i] = text_summarization_dict
return text_summarization_output

class SearchModel(AixplainModel):
def run_model(self, api_input: Dict[str, List[SearchInput]], headers: Dict[str, str] = None) -> Dict[str, List[SearchOutput]]:
pass

def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict:
instances = request['instances']
search_input_list = []
# Convert JSON serializables into SearchInputs
for instance in instances:
search_input = SearchInput(**instance)
search_input_list.append(search_input)

search_output = self.run_model({"instances": search_input_list})

# Convert JSON serializables into SearchOutputs
for i in range(len(search_output["predictions"])):
search_dict = search_output["predictions"][i].dict()
SearchOutput(**search_dict)
search_output["predictions"][i] = search_dict
return search_output

class TextReconstructionModel(AixplainModel):
def run_model(self, api_input: Dict[str, List[TextReconstructionInput]], headers: Dict[str, str] = None) -> Dict[str, List[TextReconstructionInput]]:
pass

def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict:
instances = request['instances']
text_reconstruction_input_list = []
# Convert JSON serializables into TextReconstructionInputs
for instance in instances:
text_reconstruction_input = TextReconstructionInput(**instance)
text_reconstruction_input_list.append(text_reconstruction_input)

text_reconstruction_output = self.run_model({"instances": text_reconstruction_input_list})

# Convert JSON serializables into TextReconstructionOutputs
for i in range(len(text_reconstruction_output["predictions"])):
text_reconstruction_dict = text_reconstruction_output["predictions"][i].dict()
TextReconstructionOutput(**text_reconstruction_dict)
text_reconstruction_output["predictions"][i] = text_reconstruction_dict
return text_reconstruction_output

class FillTextMaskModel(AixplainModel):
def run_model(self, api_input: Dict[str, List[FillTextMaskInput]], headers: Dict[str, str] = None) -> Dict[str, List[FillTextMaskOutput]]:
pass

def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict:
instances = request['instances']
fill_text_mask_input_list = []
# Convert JSON serializables into FillTextMaskInputs
for instance in instances:
fill_text_mask_input = FillTextMaskInput(**instance)
fill_text_mask_input_list.append(fill_text_mask_input)

fill_text_mask_output = self.run_model({"instances": fill_text_mask_input_list})

# Convert JSON serializables into FillTextMaskOutputs
for i in range(len(fill_text_mask_output["predictions"])):
fill_text_mask_dict = fill_text_mask_output["predictions"][i].dict()
FillTextMaskOutput(**fill_text_mask_dict)
fill_text_mask_output["predictions"][i] = fill_text_mask_dict
return fill_text_mask_output

class SubtitleTranslationModel(AixplainModel):
def run_model(self, api_input: Dict[str, List[SubtitleTranslationInput]], headers: Dict[str, str] = None) -> Dict[str, List[SubtitleTranslationOutput]]:
pass

def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Dict:
instances = request['instances']
subtitle_translation_input_list = []
# Convert JSON serializables into SubtitleTranslationInputs
for instance in instances:
subtitle_translation_input = SubtitleTranslationInput(**instance)
subtitle_translation_input_list.append(subtitle_translation_input)

subtitle_translation_output = self.run_model({"instances": subtitle_translation_input_list})

# Convert JSON serializables into SubtitleTranslationOutput
for i in range(len(subtitle_translation_output["predictions"])):
subtitle_translation_dict = subtitle_translation_output["predictions"][i].dict()
SubtitleTranslationOutput(**subtitle_translation_dict)
subtitle_translation_output["predictions"][i] = subtitle_translation_dict
return subtitle_translation_output
4 changes: 2 additions & 2 deletions aixplain/model_interfaces/interfaces/metric_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tornado.web
from fastapi.responses import JSONResponse

from aixplain.model_interfaces.schemas.metric_input import (
from aixplain.model_interfaces.schemas.metric.metric_input import (
TextGenerationMetricInput,
ReferencelessTextGenerationMetricInput,
ClassificationMetricInput,
Expand All @@ -14,7 +14,7 @@
NamedEntityRecognitionMetricInput,
MetricAggregate,
)
from aixplain.model_interfaces.schemas.metric_output import (
from aixplain.model_interfaces.schemas.metric.metric_output import (
TextGenerationMetricOutput,
ReferencelessTextGenerationMetricOutput,
ClassificationMetricOutput,
Expand Down
Loading

0 comments on commit 2c975b2

Please sign in to comment.