Skip to content

Commit

Permalink
Merge pull request #22 from aixplain/project-node-3
Browse files Browse the repository at this point in the history
Project node 3
  • Loading branch information
mikelam-us-aixplain authored Jul 26, 2024
2 parents ba99f6f + d829ff6 commit be4bbeb
Show file tree
Hide file tree
Showing 12 changed files with 310 additions and 50 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Change Log
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## [v0.0.2] - 2024-07-25

### Added
- Added support for script nodes.
### Changed

### Fixed
39 changes: 33 additions & 6 deletions aixplain/model_interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
SearchInput,
TextReconstructionInput,
FillTextMaskInput,
SubtitleTranslationInput
SubtitleTranslationInput,
SegmentationInput
)

from aixplain.model_interfaces.schemas.function.function_output import(
Expand All @@ -33,7 +34,8 @@
SearchOutput,
TextReconstructionOutput,
FillTextMaskOutput,
SubtitleTranslationOutput
SubtitleTranslationOutput,
SegmentationOutput
)

from aixplain.model_interfaces.schemas.metric.metric_input import(
Expand All @@ -60,6 +62,14 @@
NamedEntityRecognitionMetricOutput
)

from aixplain.model_interfaces.schemas.script.script_input import(
ScriptInput
)

from aixplain.model_interfaces.schemas.script.script_output import(
ScriptOutput
)

from aixplain.model_interfaces.interfaces.function_models import(
TranslationModel,
SpeechRecognitionModel,
Expand All @@ -74,7 +84,8 @@
SearchModel,
TextReconstructionModel,
FillTextMaskModel,
SubtitleTranslationModel
SubtitleTranslationModel,
SegmentationModel
)

from aixplain.model_interfaces.interfaces.metric_models import(
Expand All @@ -86,6 +97,10 @@
NamedEntityRecognitionMetric
)

from aixplain.model_interfaces.interfaces.project_node import(
ProjectNode
)

function_classes = [
TranslationModel,
SpeechRecognitionModel,
Expand All @@ -100,7 +115,8 @@
SearchModel,
TextReconstructionModel,
FillTextMaskModel,
SubtitleTranslationModel
SubtitleTranslationModel,
SegmentationModel
]

function_classes_input = [
Expand All @@ -114,7 +130,8 @@
SpeechEnhancementInput,
SpeechSynthesisInput,
TextToImageGenerationInput,
TextGenerationInput
TextGenerationInput,
SegmentationInput
]

metric_classes_input = [
Expand All @@ -140,7 +157,17 @@
NamedEntityRecognitionMetric
]

script_classes_input = [
ScriptInput
]

script_classes = [
ProjectNode
]

function_input_interface_map = {clazz.__name__.replace("Input", ""): clazz for clazz in function_classes_input}
metric_input_interface_map = {clazz.__name__.replace("Input", ""): clazz for clazz in metric_classes_input}
script_input_interface_map = {clazz.__name__.replace("Input", ""): clazz for clazz in script_classes_input}
function_interface_map = {clazz.__name__: clazz for clazz in function_classes}
metric_interface_map = {clazz.__name__: clazz for clazz in metric_classes}
metric_interface_map = {clazz.__name__: clazz for clazz in metric_classes}
script_interface_map = {clazz.__name__: clazz for clazz in script_classes}
2 changes: 1 addition & 1 deletion aixplain/model_interfaces/__version__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__title__ = "model-interfaces"
__description__ = "model-interfaces is the interface to host your models on aiXplain"
__url__ = "https://github.com/aixplain/aixplain-models/tree/main/docs"
__version__ = "0.0.2rc2"
__version__ = "0.0.2rc3"
__author__ = "Duraikrishna Selvaraju and Michael Lam"
__author_email__ = "[email protected]"
__license__ = "http://www.apache.org/licenses/LICENSE-2.0"
Expand Down
27 changes: 26 additions & 1 deletion aixplain/model_interfaces/interfaces/function_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, validate_call

from aixplain.model_interfaces.schemas.function.function_input import (
SegmentationInput,
TranslationInput,
SpeechRecognitionInput,
DiacritizationInput,
Expand All @@ -21,6 +22,7 @@
SubtitleTranslationInput
)
from aixplain.model_interfaces.schemas.function.function_output import (
SegmentationOutput,
TranslationOutput,
SpeechRecognitionOutput,
DiacritizationOutput,
Expand Down Expand Up @@ -368,4 +370,27 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di
subtitle_translation_dict = subtitle_translation_output["predictions"][i].dict()
SubtitleTranslationOutput(**subtitle_translation_dict)
subtitle_translation_output["predictions"][i] = subtitle_translation_dict
return subtitle_translation_output
return subtitle_translation_output

class SegmentationModel(AixplainModel):
def run_model(
self,
api_input: Dict[str, List[SegmentationInput]],
headers: Dict[str, str] = None
) -> Dict[str, List[SegmentationOutput]]:
pass

def predict(self, request: Dict[str, str],
headers: Dict[str, str] = None) -> dict:
instances = []

for instance in request['instances']:
segmentation_input = SegmentationInput(**instance)
instances.append(segmentation_input)

output = self.run_model({"instances": instances}, headers)

for i, prediction in enumerate(output["predictions"]):
output["predictions"][i] = prediction.dict()

return output
45 changes: 45 additions & 0 deletions aixplain/model_interfaces/interfaces/project_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
__author__='aiXplain'

from abc import abstractmethod
from kserve.model import Model
from typing import Dict, List, Any
from pydantic import validate_call

from aixplain.model_interfaces.schemas.script.script_input import ScriptInput
from aixplain.model_interfaces.schemas.script.script_output import ScriptOutput

class ProjectNode(Model):
def __init__(self, name, *args, **kwargs):
super().__init__(name)
self.name = name
self.ready = False
ready = self.load()
print(f"Readiness: {ready}")
assert ready

@validate_call
def predict(self, request: Dict[str, List[ScriptInput]], headers: Dict[str, str] = None) -> Dict[str, List[ScriptOutput]]:
instances = request["instances"]
results = []
for instance in instances:
result = self.run_script(instance)
results.append(result)
predictions = {
"predictions": results
}
return predictions

@validate_call
def run_script(self, input: Any) -> Any:
raise NotImplementedError

def load(self) -> bool:
"""Load handler can be overridden to load the metric from storage
``self.ready`` flag is used for metric health check
:return:
True if metric is ready, False otherwise
:rtype:
Bool
"""
self.ready = True
return self.ready
25 changes: 24 additions & 1 deletion aixplain/model_interfaces/schemas/function/function_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from http import HTTPStatus
from typing import Optional, Any, List
from typing import Optional, Any, List, Union, Tuple

from pydantic import BaseModel, validator
import tornado.web
Expand Down Expand Up @@ -262,6 +262,29 @@ def __init__(self, **input):
status_code=HTTPStatus.BAD_REQUEST,
reason="Incorrect type passed into TextToImageGenerationInput."
)

class Segment(BaseModel):
"""Segment information with optional url field. If the url field populated
this means the segmentation performed and the regarding segment uploaded
to the defined location under `url` field.
"""
segment_id: int
start: Union[float, int, Tuple[int, int]]
end: Union[float, int, Tuple[int, int]]
url: Optional[str]

class SegmentationInputSchema(APIInput):
"""The standardized schema of the aiXplain's Segmenation API input.
:param details:
List of segments in Segment type.
:type data:
Segment
"""
details: Optional[Union[List[Segment], str]]

class SegmentationInput(SegmentationInputSchema):
pass

class TextGenerationInputSchema(TextInput):
"""The standardized schema of aiXplains text generation API Input
Expand Down
34 changes: 22 additions & 12 deletions aixplain/model_interfaces/schemas/function/function_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tornado.web
from http import HTTPStatus

from aixplain.model_interfaces.schemas.function.function_input import AudioConfig, AudioEncoding
from aixplain.model_interfaces.schemas.function.function_input import AudioConfig, AudioEncoding, SegmentationInputSchema
from aixplain.model_interfaces.utils import serialize
from aixplain.model_interfaces.schemas.api.basic_api_output import APIOutput
from aixplain.model_interfaces.schemas.modality.modality_output import TextOutput
Expand All @@ -27,8 +27,8 @@ class WordDetails(BaseModel):
Dict
"""
word: str
confidence: Optional[float]
details: Optional[Dict[str, Any]]
confidence: Optional[float] = None
details: Optional[Dict[str, Any]] = None

class TextSegmentDetails(BaseModel):
"""The standardized schema of the aiXplain's representation of text
Expand All @@ -48,8 +48,8 @@ class TextSegmentDetails(BaseModel):
WordDetails
"""
text: str
confidence: Optional[float]
word_details: Optional[List[WordDetails]]
confidence: Optional[float] = None
word_details: Optional[List[WordDetails]] = None

class Label(BaseModel):
"""The standardized schema of the aiXplain's representation of label
Expand All @@ -65,7 +65,7 @@ class Label(BaseModel):
float
"""
label: str
confidence: Optional[float]
confidence: Optional[float] = None

class TranslationOutputSchema(APIOutput):
"""The standardized schema of the aiXplain's Translation Output.
Expand Down Expand Up @@ -127,7 +127,7 @@ class ClassificationOutput(APIOutput):
List[Label]
"""
predicted_labels: List[Label]
all_labels: Optional[List[Label]]
all_labels: Optional[List[Label]] = None

class SpeechEnhancementOutputSchema(APIOutput):
"""The standardized schema of the aiXplain's Speech Enhancement Output.
Expand Down Expand Up @@ -245,7 +245,7 @@ class TextSummarizationOutputSchema(TextOutput):
:type details:
Any. Optional.
"""
details: Optional[Any]
details: Optional[Any] = None

class TextSummarizationOutput(TextSummarizationOutputSchema):
def __init__(self, **input):
Expand All @@ -269,7 +269,7 @@ class SearchOutputSchema(TextOutput):
:type details:
Any. Optional.
"""
details: Optional[Any]
details: Optional[Any] = None

class SearchOutput(SearchOutputSchema):
def __init__(self, **input):
Expand Down Expand Up @@ -315,7 +315,7 @@ class TextReconstructionOutputSchema(TextOutput):
:type details:
TextSegmentDetails
"""
details: Optional[TextSegmentDetails]
details: Optional[TextSegmentDetails] = None

class TextReconstructionOutput(TextReconstructionOutputSchema):
def __init__(self, **input):
Expand All @@ -338,7 +338,7 @@ class FillTextMaskOutputSchema(TextOutput):
:type details:
TextSegmentDetails
"""
details: Optional[TextSegmentDetails]
details: Optional[TextSegmentDetails] = None

class FillTextMaskOutput(FillTextMaskOutputSchema):
def __init__(self, **input):
Expand All @@ -361,7 +361,7 @@ class SubtitleTranslationOutputSchema(TextOutput):
:type details:
TextSegmentDetails
"""
details: Optional[TextSegmentDetails]
details: Optional[TextSegmentDetails] = None

class SubtitleTranslationOutput(SubtitleTranslationOutputSchema):
def __init__(self, **input):
Expand All @@ -372,3 +372,13 @@ def __init__(self, **input):
status_code=HTTPStatus.BAD_REQUEST,
reason="Incorrect types passed into SubtitleTranslationOutput"
)

class SegmentationOutput(SegmentationInputSchema):
def __init__(self, **input):
try:
super().__init__(**input)
except ValueError:
raise tornado.web.HTTPError(
status_code=HTTPStatus.BAD_REQUEST,
reason="Incorrect type passed into SegmentationInputSchema."
)
12 changes: 12 additions & 0 deletions aixplain/model_interfaces/schemas/script/script_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
__author__ = "aiXplain"

from pydantic import BaseModel
from typing import Dict

class ScriptInput(BaseModel):
"""The standardized schema of the aiXplain's Script API input.
:param inputs:
Input values to script.
"""
inputs: Dict
12 changes: 12 additions & 0 deletions aixplain/model_interfaces/schemas/script/script_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
__author__ = "aiXplain"

from pydantic import BaseModel
from typing import Dict

class ScriptOutput(BaseModel):
"""The standardized schema of the aiXplain's Script API input.
:param inputs:
Input values to script.
"""
outputs: Dict
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ includes = ["aixplain"]

[project]
name = "model-interfaces"
version = "0.0.2rc2"
version = "0.0.2rc3"
description = "A package specifying the model interfaces supported by aiXplain"
license = { text = "Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0" }
dependencies = [
Expand Down
Loading

0 comments on commit be4bbeb

Please sign in to comment.