Skip to content

Commit

Permalink
Merge pull request #7 from tellmewyatt/main
Browse files Browse the repository at this point in the history
Fixed JSON Schema Serialization for Components
  • Loading branch information
silvanocerza authored May 9, 2024
2 parents 4b22762 + 96a27e1 commit 75f91c3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 33 deletions.
52 changes: 19 additions & 33 deletions src/hayhooks/server/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import get_args, get_origin, List
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, create_model

from pydantic import BaseModel, create_model, ConfigDict
from haystack.dataclasses import Document


class HaystackDocument(BaseModel):
id: str
content: str
from hayhooks.server.utils.create_valid_type import handle_unsupported_types


class PipelineDefinition(BaseModel):
Expand All @@ -29,13 +24,16 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
config = ConfigDict(arbitrary_types_allowed=True)

for component_name, inputs in pipeline_inputs.items():

component_model = {}
for name, typedef in inputs.items():
component_model[name] = (typedef["type"], typedef.get("default_value", ...))
request_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)
request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f'{pipeline_name.capitalize()}RunRequest', **request_model, __config__=config)
return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)


def get_response_model(pipeline_name: str, pipeline_outputs):
Expand All @@ -49,44 +47,32 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
"""
response_model = {}
config = ConfigDict(arbitrary_types_allowed=True)

for component_name, outputs in pipeline_outputs.items():
component_model = {}
for name, typedef in outputs.items():
output_type = typedef["type"]
if get_origin(output_type) == list and get_args(output_type)[0] == Document:
component_model[name] = (List[HaystackDocument], ...)
else:
component_model[name] = (typedef["type"], ...)
response_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...)
response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f'{pipeline_name.capitalize()}RunResponse', **response_model, __config__=config)
return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)


def convert_component_output(component_output):
"""
Converts outputs from a component as a dict so that it can be validated against response model
Component output has this form:
"documents":[
{"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."}
]
We inspect the output and convert haystack.Document into the HaystackDocument pydantic model as needed
"""
result = {}
for output_name, data in component_output.items():
# Empty containers, None values, empty strings and the likes: do nothing
if not data:
result[output_name] = data

# Output contains a list of Document
if type(data) is list and type(data[0]) is Document:
result[output_name] = [HaystackDocument(id=d.id, content=d.content) for d in data]
# Output is a single Document
elif type(data) is Document:
result[output_name] = HaystackDocument(id=data.id, content=data.content or "")
# Any other type: do nothing
get_value = lambda data: data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
if type(data) is list:
result[output_name] = [get_value(d) for d in data]
else:
result[output_name] = data

result[output_name] = get_value(data)
return result
43 changes: 43 additions & 0 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from inspect import isclass
from types import GenericAlias
from typing import Dict, Union, get_args, get_origin, get_type_hints

from typing_extensions import TypedDict


def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
:param type_: Type to replace if not supported
:param types_mapping: Mapping of types to replace
"""

def _handle_generics(t_) -> GenericAlias:
"""
Handle generics recursively
"""
child_typing = []
for t in get_args(t_):
if t in types_mapping:
result = types_mapping[t]
elif isclass(t):
result = handle_unsupported_types(t, types_mapping)
else:
result = t
child_typing.append(result)
return GenericAlias(get_origin(t_), tuple(child_typing))

if isclass(type_):
new_type = {}
for arg_name, arg_type in get_type_hints(type_).items():
if get_args(arg_type):
new_type[arg_name] = _handle_generics(arg_type)
else:
new_type[arg_name] = arg_type
if new_type:
return TypedDict(type_.__name__, new_type)

return type_

return _handle_generics(type_)

0 comments on commit 75f91c3

Please sign in to comment.