Skip to content

Commit

Permalink
Added function_models
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Lam <[email protected]>
  • Loading branch information
mikelam-us-aixplain committed Jun 17, 2024
1 parent 0839f1c commit 4111640
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions aixplain/model_interfaces/interfaces/function_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,11 @@ def predict(self, request: Dict[str, str], headers: Dict[str, str] = None) -> Di
text_to_image_generation_output["predictions"][i] = text_to_image_generation_dict
return text_to_image_generation_output


class TextGenerationChatTemplatizeInput(BaseModel):
data: List[Dict]

class TextGenerationPredictInput(BaseModel):
instances: Union[List[TextGenerationInput], List[TextListInput]]
instances: Union[List[TextGenerationInput], List[TextListInput], List[TextGenerationChatTemplatizeInput]]
function: Optional[Text] = "PREDICT"

class TextGenerationRunModelOutput(BaseModel):
Expand Down Expand Up @@ -253,14 +255,14 @@ def predict(self, request: TextGenerationPredictInput, headers: Dict[str, str] =
return token_counts_output
elif request.function.upper() == "TEMPLATIZE":
templatize_output = {
"data": self.templatize(instances, headers)
"prompts": self.templatize(instances, headers)
}
return templatize_output
else:
raise ValueError("Invalid function.")

@validate_call
def templatize(self, api_input: List[TextListInput], headers: Dict[str, str] = None) -> List[TextListOutput]:
def templatize(self, api_input: List[TextGenerationChatTemplatizeInput], headers: Dict[str, str] = None) -> List[Text]:
pass

class TextSummarizationModel(AixplainModel):
Expand Down

0 comments on commit 4111640

Please sign in to comment.