diff --git a/src/extraction.jl b/src/extraction.jl
index bd77aa1f4..4d816a5b7 100644
--- a/src/extraction.jl
+++ b/src/extraction.jl
@@ -6,6 +6,24 @@
 # There are potential formats: 1) JSON-based for OpenAI compatible APIs, 2) XML-based for Anthropic compatible APIs (used also by Hermes-2-Pro model).
 #
 
+"""
+    JSON_PRIMITIVE_TYPES
+
+A set of primitive types that are supported by JSON. If a type
+is not in this set, the JSON typer [`to_json_type`](@ref) will
+assume that the type is a `struct` and will attempt to recursively
+unpack the fields of the struct.
+"""
+const JSON_PRIMITIVE_TYPES = Union{
+    Integer,
+    Real,
+    AbstractString,
+    Bool,
+    Nothing,
+    Missing,
+    AbstractArray
+}
+
 ######################
 # 1) OpenAI / JSON format
 ######################
@@ -15,7 +33,14 @@ to_json_type(n::Type{<:Real}) = "number"
 to_json_type(n::Type{<:Integer}) = "integer"
 to_json_type(b::Type{Bool}) = "boolean"
 to_json_type(t::Type{<:Union{Missing, Nothing}}) = "null"
-to_json_type(t::Type{<:Any}) = "string" # object?
+to_json_type(t::Type{T}) where {T <: AbstractArray} = to_json_type(eltype(t)) * "[]"
+to_json_type(t::Type{Any}) = throw(ArgumentError("""
+Type $t is not a valid type for to_json_type. Please provide a valid type found in:
+
+$JSON_PRIMITIVE_TYPES
+
+You may be using to_json_schema but forgot to properly type the fields of your struct.
+"""))
 
 has_null_type(T::Type{Missing}) = true
 has_null_type(T::Type{Nothing}) = true
@@ -236,3 +261,103 @@ Extract zero, one or more specified items from the provided data.
 struct ItemsExtract{T <: Any}
     items::Vector{T}
 end
+
+"""
+    typed_json_schema(x::Type{T}) where {T}
+
+Convert a Julia type to a JSON schema that lists keys as field names and values as
+the types of those field names.
+
+WARNING! Every field in your struct, and all nested structs, must be typed using a subtype of values in [`JSON_PRIMITIVE_TYPES`](@ref)
+before calling this function. A type without fields that is not a JSON primitive raises an `ArgumentError`.
+
+## Example
+
+```julia
+# Simple flat structure where each field is a primitive type
+struct SimpleSingleton
+    singleton_value::Int
+end
+
+typed_json_schema(SimpleSingleton)
+```
+
+```
+Dict{Any, Any} with 1 entry:
+  :singleton_value => "integer"
+```
+
+Or using nested structs
+
+```julia
+# Test a struct that contains another struct.
+struct Nested
+    inside_element::SimpleSingleton
+end
+
+typed_json_schema(Nested)
+```
+
+```
+Dict{Any, Any} with 1 entry:
+  :inside_element => Dict{Any, Any}(:singleton_value => "integer")
+```
+
+Lists of created Julia types will be specified as `list[Object]` with the value being the JSON-encoded schema of the elements,
+i.e.
+
+```julia
+# Test a struct with a vector of primitives
+struct ABunchOfVectors
+    strings::Vector{String}
+    ints::Vector{Int}
+    floats::Vector{Float64}
+    nested_vector::Vector{Nested}
+end
+
+typed_json_schema(ABunchOfVectors)
+```
+
+```
+Dict{Any, Any} with 4 entries:
+  :strings => "string[]"
+  :ints => "integer[]"
+  :nested_vector => Dict("list[Object]"=>"{\"inside_element\":{\"singleton_value\":\"integer\"}}")
+  :floats => "number[]"
+```
+
+## Resources
+- the [original issue](https://github.com/svilupp/PromptingTools.jl/issues/143)
+- the [motivation](https://www.boundaryml.com/blog/type-definition-prompting-baml)
+"""
+function typed_json_schema(x::Type{T}) where {T}
+    # Arrays are handled before any field inspection: whether an array type exposes
+    # fields is an implementation detail of Julia (`Array` has fields as of v1.11).
+    if T <: AbstractArray
+        # Arrays of primitives map to "<type>[]"; arrays of structs embed the schema
+        # of the element type as a JSON string under the "list[Object]" key.
+        if eltype(T) <: JSON_PRIMITIVE_TYPES
+            return to_json_type(T)
+        end
+        return Dict("list[Object]" => JSON3.write(typed_json_schema(eltype(T))))
+    end
+
+    # Non-array primitives map directly to their JSON type name
+    if T <: JSON_PRIMITIVE_TYPES
+        return to_json_type(T)
+    end
+
+    # Anything else must be a struct with at least one field; recursing on a
+    # fieldless type would never terminate, so fail loudly instead.
+    if isempty(fieldnames(T))
+        throw(ArgumentError("Type $T has no fields and is not a supported JSON primitive type."))
+    end
+
+    # Map each field name to the schema of its declared type
+    mapping = Dict()
+    for (field, type) in zip(fieldnames(T), fieldtypes(T))
+        mapping[field] = typed_json_schema(type)
+    end
+
+    return mapping
+end
diff --git a/test/extraction.jl b/test/extraction.jl
index a10b8478b..d2c1804ea 100644
--- a/test/extraction.jl
+++ b/test/extraction.jl
@@ -225,9 +225,11 @@ end
     end
     output = function_call_signature(MyMeasurement2)#|> JSON3.pretty
     expected_output = Dict{String, Any}("name" => "MyMeasurement2_extractor",
-        "parameters" => Dict{String, Any}("properties" => Dict{String, Any}("height" => Dict{
+        "parameters" => Dict{String, Any}(
+            "properties" => Dict{String, Any}(
+                "height" => Dict{
                     String,
-                    Any,
+                    Any
                 }("type" => "integer"),
                 "weight" => Dict{String, Any}("type" => "number"),
                 "age" => Dict{String, Any}("type" => "integer")),
@@ -240,3 +242,120 @@ end
     schema = function_call_signature(MaybeExtract{MyMeasurement2})
     @test schema["name"] == "MaybeExtractMyMeasurement2_extractor"
 end
+@testset "to_json_schema-primitive_types" begin
+    @test to_json_schema(Int) == Dict("type" => "integer")
+    @test to_json_schema(Float64) == Dict("type" => "number")
+    @test to_json_schema(Bool) == Dict("type" => "boolean")
+    @test to_json_schema(String) == Dict("type" => "string")
+    @test_throws ArgumentError to_json_schema(Any) # Type Any is not supported
+end
+@testset "to_json_schema-structs" begin
+    # Function to check the equivalence of two JSON strings, since Dict is
+    # unordered, we need to sort keys before comparison.
+    function check_json_equivalence(json1::AbstractString, json2::AbstractString)
+        # JSON dictionary
+        d1 = JSON3.read(json1)
+        d2 = JSON3.read(json2)
+
+        # Get all the keys
+        k1 = sort(collect(keys(d1)))
+        k2 = sort(collect(keys(d2)))
+
+        # Test that all the keys are present
+        @test setdiff(k1, k2) == []
+        @test setdiff(k2, k1) == []
+
+        # Test that all the values are equivalent
+        for (k, v) in d1
+            @test d2[k] == v
+        end
+    end
+    function check_json_equivalence(d::Dict, s::AbstractString)
+        return check_json_equivalence(JSON3.write(d), s)
+    end
+
+    # Simple flat structure where each field is a primitive type
+    struct SimpleSingleton
+        singleton_value::Int
+    end
+
+    check_json_equivalence(
+        JSON3.write(typed_json_schema(SimpleSingleton)),
+        "{\"singleton_value\":\"integer\"}"
+    )
+
+    # Test a struct that contains another struct.
+    struct Nested
+        inside_element::SimpleSingleton
+    end
+
+    check_json_equivalence(
+        JSON3.write(typed_json_schema(Nested)),
+        "{\"inside_element\":{\"singleton_value\":\"integer\"}}"
+    )
+
+    # Test a struct with two primitive types
+    struct IntFloatFlat
+        int_value::Int
+        float_value::Float64
+    end
+    check_json_equivalence(
+        typed_json_schema(IntFloatFlat),
+        "{\"int_value\":\"integer\",\"float_value\":\"number\"}"
+    )
+
+    # Test a struct that contains all primitive types
+    struct AllJSONPrimitives
+        int::Integer
+        float::Real
+        string::AbstractString
+        bool::Bool
+        nothing::Nothing
+        missing::Missing
+
+        # Array types
+        array_of_strings::Vector{String}
+        array_of_ints::Vector{Int}
+        array_of_floats::Vector{Float64}
+        array_of_bools::Vector{Bool}
+        array_of_nothings::Vector{Nothing}
+        array_of_missings::Vector{Missing}
+    end
+
+    check_json_equivalence(
+        typed_json_schema(AllJSONPrimitives),
+        "{\"int\":\"integer\",\"float\":\"number\",\"string\":\"string\",\"bool\":\"boolean\",\"nothing\":\"null\",\"missing\":\"null\",\"array_of_strings\":\"string[]\",\"array_of_ints\":\"integer[]\",\"array_of_floats\":\"number[]\",\"array_of_bools\":\"boolean[]\",\"array_of_nothings\":\"null[]\",\"array_of_missings\":\"null[]\"}"
+    )
+
+    # Test a struct with a vector of primitives
+    struct ABunchOfVectors
+        strings::Vector{String}
+        ints::Vector{Int}
+        floats::Vector{Float64}
+        nested_vector::Vector{Nested}
+    end
+
+    check_json_equivalence(
+        typed_json_schema(ABunchOfVectors),
+        "{\"strings\":\"string[]\",\"ints\":\"integer[]\",\"nested_vector\":{\"list[Object]\":\"{\\\"inside_element\\\":{\\\"singleton_value\\\":\\\"integer\\\"}}\"},\"floats\":\"number[]\"}"
+    )
+
+    # Weird struct with a bunch of different types
+    struct Monster
+        name::String
+        age::Int
+        height::Float64
+        friends::Vector{String}
+        nested::Nested
+        flat::IntFloatFlat
+    end
+
+    check_json_equivalence(
+        typed_json_schema(Monster),
+        "{\"flat\":{\"float_value\":\"number\",\"int_value\":\"integer\"},\"nested\":{\"inside_element\":{\"singleton_value\":\"integer\"}},\"age\":\"integer\",\"name\":\"string\",\"height\":\"number\",\"friends\":\"string[]\"}"
+    )
+
+    # A fieldless, non-primitive type must be rejected instead of recursing forever
+    struct NoFields end
+    @test_throws ArgumentError typed_json_schema(NoFields)
+end;