Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function support #68

Merged
merged 14 commits into from
Dec 14, 2023
Merged
147 changes: 88 additions & 59 deletions schema/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
type Function interface {
ID() string
Parameters() []Type
Output([]Type) (Type, error)
// Output determines the output type. This can be static, or it can depend on the input types.
// It also returns whether the handler may self-report an error.
Output([]Type) (Type, bool, error)
Display() Display
String() string
}
Expand All @@ -20,18 +22,37 @@ type CallableFunction interface {
Call(arguments []any) (any, error)
}

type FunctionCallError struct {
// isFunctionReportedError is true when the error originated from the function itself from its return value or a panic
// It is false when the function did something unexpected, or the call is invalid.
IsFunctionReportedError bool
SourceError error
}

func NewFunctionCallError(err error, isFunctionReportedError bool) *FunctionCallError {
return &FunctionCallError{
isFunctionReportedError,
err,
}
}

func (e *FunctionCallError) Error() string {
return e.SourceError.Error()
}

const errorType = "error"

// NewCallableFunction creates a CallableFunction schema type for the strictly typed function.
//
// - The handler types must match the input and output types specified.
// - The return type of the handler is determined by the value specified for output. If output is nil, it's a void
// function that may have no return, or a single error return.
// - If output is not nil, the return type must match, plus it may have an error return type, too.
// function that may have no return, or a single error return if outputsError is true.
// - If output is not nil, the return type must match, plus it must have an error return if outputsError is true.
func NewCallableFunction(
id string,
inputs []Type,
output Type,
outputsError bool,
display Display,
handler any,
) (CallableFunction, error) {
Expand All @@ -42,60 +63,56 @@ func NewCallableFunction(
return nil, err
}
// Validate the output type
if output == nil {
err := validateVoidFunc(parsedHandler)
if err != nil {
return nil, err
}
} else {
err := validateTypedReturnFunc(parsedHandler, output)
if err != nil {
return nil, err
}
err = validateTypedReturnFunc(parsedHandler, outputsError, output)
if err != nil {
return nil, err
}
return &CallableFunctionSchema{
IDValue: id,
InputsValue: inputs,
StaticOutputValue: output,
OutputsError: outputsError,
DisplayValue: display,
Handler: parsedHandler,
}, nil
}

func validateVoidFunc(parsedHandler reflect.Value) error {
func validateTypedReturnFunc(parsedHandler reflect.Value, errorExpected bool, outputType Type) error {
returnCount := parsedHandler.Type().NumOut()
// A void function can have no returns or an error return
if returnCount > 1 {
return fmt.Errorf(
"parameter output is nil, meaning it's a void function, or a function with just an error return, but got %d return types",
returnCount,
)
} else if returnCount == 1 {
// Validate that it's just an error return
returnTypeName := parsedHandler.Type().Out(0).Name()
if returnTypeName != errorType {
return fmt.Errorf("expected void or error return, but got %s", returnTypeName)
}
var expectedReturnCount int
// Validate error return
if errorExpected {
expectedReturnCount = 1
} else {
expectedReturnCount = 0
}
// Currently designed to allow only one output type. However, the output type could be an object with multiple fields.
if outputType != nil {
expectedReturnCount += 1
}
jaredoconnell marked this conversation as resolved.
Show resolved Hide resolved
if expectedReturnCount != returnCount {
return fmt.Errorf("incorrect return count '%d'; expected '%d', with type(s) '%s'",
returnCount, expectedReturnCount, getReturnTypeString(outputType, errorExpected))
}
return nil
}

func validateTypedReturnFunc(parsedHandler reflect.Value, outputType Type) error {
returnCount := parsedHandler.Type().NumOut()
// Validate error return
if errorExpected {
// Validate the last type as error
handlerLastTypeName := parsedHandler.Type().Out(returnCount - 1).Name()
if handlerLastTypeName != errorType {
return fmt.Errorf("expected last return type from handler to be error, but instead found '%s'", handlerLastTypeName)
}
}

if returnCount > 2 || returnCount < 1 {
return fmt.Errorf("expected handler to have one return, or one plus an error return, but got %d return types", returnCount)
} else {
// Validate the other return, if applicable
if outputType != nil {
// Validate the return type
expectedType := outputType.ReflectedType()
handlerType := parsedHandler.Type().Out(0)

if expectedType != handlerType {
return fmt.Errorf("mismatched return type. expected %s, handler has %s", expectedType, handlerType)
}
// Validate error return, if applicable.
if returnCount == 2 && parsedHandler.Type().Out(1).Name() != errorType {
return fmt.Errorf("expected additional return type to be an error return, but got %s", parsedHandler.Type().Out(1).Name())
}
}
return nil
}
Expand Down Expand Up @@ -135,6 +152,7 @@ func NewDynamicCallableFunction(
IDValue: id,
InputsValue: inputs,
StaticOutputValue: nil,
OutputsError: true,
DisplayValue: display,
Handler: parsedHandler,
DynamicTypeHandler: typeHandler,
Expand Down Expand Up @@ -198,23 +216,36 @@ func (f FunctionSchema) Display() Display {

func (f FunctionSchema) String() string {
result := f.ID() + "(" + strings.Join(f.ParameterTypeNames(), ", ") + ") "
if f.OutputValue != nil {
result += string(f.OutputValue.TypeID())
} else {
result += "void"
}
result += getReturnTypeString(f.OutputValue, false)
return result
}

func getReturnTypeString(returnType Type, hasError bool) string {
switch {
case returnType != nil:
if hasError {
return "(" + string(returnType.TypeID()) + ", error)"
} else {
return string(returnType.TypeID())
}
case hasError:
return "error"
default:
return "void"
}
jaredoconnell marked this conversation as resolved.
Show resolved Hide resolved
}

type CallableFunctionSchema struct {
IDValue string `json:"id"`
InputsValue []Type `json:"inputs"`
// The output type when the output type does not change. Nil for void.
StaticOutputValue Type `json:"output"`
OutputsError bool `json:"outputs_error"`
DisplayValue Display `json:"display"`
// A callable function whose parameters (if any) match the type schema specified in InputsValue,
// and whose return value type matches StaticOutputValue, the return type from DynamicTypeHandler,
// or is void if both StaticOutputValue and DynamicTypeHandler are nil.
// or is void if both StaticOutputValue and DynamicTypeHandler are nil. An error return must be present
// if OutputsError is true.
// The handler may also return an error type.
Handler reflect.Value
// Returns the output type based on the input type. For advanced use cases. Cannot be void.
Expand All @@ -237,11 +268,12 @@ func (f CallableFunctionSchema) ParameterTypeNames() []string {
return parameterNames
}

func (f CallableFunctionSchema) Output(inputType []Type) (Type, error) {
func (f CallableFunctionSchema) Output(inputType []Type) (Type, bool, error) {
if f.DynamicTypeHandler == nil {
return f.StaticOutputValue, nil
return f.StaticOutputValue, f.OutputsError, nil
} else {
return f.DynamicTypeHandler(inputType)
dynamicTypes, err := f.DynamicTypeHandler(inputType)
return dynamicTypes, f.OutputsError, err
}
}

Expand All @@ -251,13 +283,10 @@ func (f CallableFunctionSchema) Display() Display {

func (f CallableFunctionSchema) String() string {
result := f.ID() + "(" + strings.Join(f.ParameterTypeNames(), ", ") + ") "
switch {
case f.DynamicTypeHandler != nil:
result += "dynamic"
case f.StaticOutputValue != nil:
result += string(f.StaticOutputValue.TypeID())
default:
result += "void"
if f.DynamicTypeHandler != nil {
result += "(dynamic, error)"
jaredoconnell marked this conversation as resolved.
Show resolved Hide resolved
} else {
result += getReturnTypeString(f.StaticOutputValue, f.OutputsError)
}
return result
}
Expand All @@ -279,12 +308,12 @@ func (f CallableFunctionSchema) Call(arguments []any) (any, error) {
gotArgs := len(arguments)
expectedArgs := f.Handler.Type().NumIn()
if gotArgs != expectedArgs {
return nil, fmt.Errorf(
return nil, NewFunctionCallError(fmt.Errorf(
"incorrect number of args sent to function with ID '%s'. Expected %d, got %d",
f.ID(),
expectedArgs,
gotArgs,
)
), false)
}
// Convert to reflect values
args := make([]reflect.Value, gotArgs)
Expand All @@ -311,10 +340,10 @@ func (f CallableFunctionSchema) Call(arguments []any) (any, error) {
if !errorVal.IsNil() {
err, isError := errorVal.Interface().(error)
if !isError {
return nil, fmt.Errorf("error return val isn't an error '%w'", err)
return nil, NewFunctionCallError(fmt.Errorf("error return val isn't an error '%w'", err), false)
}
if err != nil {
return nil, fmt.Errorf("function returned error: %w", err)
return nil, NewFunctionCallError(err, true)
}
}
// Expected return plus error return
Expand All @@ -324,7 +353,7 @@ func (f CallableFunctionSchema) Call(arguments []any) (any, error) {
return result[0].Interface(), nil
}
default:
return nil, fmt.Errorf("unexpected return count. Expected %d or %d, got %d",
expectedReturnVals, expectedReturnVals+1, gotReturns)
return nil, NewFunctionCallError(fmt.Errorf("unexpected return count. Expected %d or %d, got %d",
expectedReturnVals, expectedReturnVals+1, gotReturns), false)
}
}
Loading
Loading