diff --git a/examples/validator/main.go b/examples/validator/main.go new file mode 100644 index 0000000..74fb143 --- /dev/null +++ b/examples/validator/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type User struct { + FirstName string `json:"first_name" jsonschema:"title=First Name,description=The first name of the user" validate:"required"` + LastName string `json:"last_name" jsonschema:"title=Last Name,description=The last name of the user" validate:"required"` + Age uint8 `json:"age" jsonschema:"title=Age,description=The age of the user" validate:"gte=0,lte=130"` + Email string `json:"email" jsonschema:"title=Email,description=The email address of the user" validate:"required,email"` + Gender string `json:"gender" jsonschema:"title=Gender,description=The gender of the user" validate:"oneof=male female prefer_not_to"` + FavouriteColor string `json:"favourite_color" jsonschema:"title=Favourite Color,description=The favourite color of the user" validate:"iscolor"` + Addresses []*Address `json:"addresses" jsonschema:"title=Addresses,description=The addresses of the user" validate:"required,dive,required"` +} + +type Address struct { + Street string `json:"street" jsonschema:"title=Street,description=The street address" validate:"required"` + City string `json:"city" jsonschema:"title=City,description=The city" validate:"required"` + Planet string `json:"planet" jsonschema:"title=Planet,description=The planet" validate:"required"` + Phone string `json:"phone" jsonschema:"title=Phone,description=The phone number" validate:"required"` +} + +func (u User) String() string { + result := fmt.Sprintf("First Name: %s\nLast Name: %s\nAge: %d\nEmail: %s\nGender: %s\nFavourite Color: %s\nAddresses:\n", + u.FirstName, u.LastName, u.Age, u.Email, u.Gender, u.FavouriteColor) + for _, address := range u.Addresses { + result += fmt.Sprintf(" %s\n", address) + } + return result +} + +func (a Address) String() string { + return fmt.Sprintf("Street: %s, City: %s, Planet: %s, Phone: %s", a.Street, a.City, a.Planet, a.Phone) +} + +func main() { + ctx := context.Background() + + client := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSON), + instructor.WithMaxRetries(3), + instructor.WithValidation(), + ) + + var user User + _, err := client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4o, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Meet Jane Doe: a 30-year-old adventurer who can be reached at janed@example.com. " + + "Jane loves the vibrant hue of #FF5733. She resides in Metropolis at 456 Oak St, on the wonderful planet Earth. " + + "To chat with her, dial (555) 555-1234. Jane also spends her weekends at her cottage located at 789 Pine St, " + + "in Smallville, on the same planet. You can contact her there at (555) 555-5678.", + }, + }, + }, + &user, + ) + if err != nil { + panic(err) + } + + fmt.Println(user) + /* + First Name: Jane + Last Name: Doe + Age: 30 + Email: janed@example.com + Gender: female + Favourite Color: #FF5733 + Addresses: + Street: 456 Oak St, City: Metropolis, Planet: Earth, Phone: (555) 555-1234 + Street: 789 Pine St, City: Smallville, Planet: Earth, Phone: (555) 555-5678 + */ +} diff --git a/go.mod b/go.mod index a11bb7a..f190eeb 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.8 require ( github.com/cohere-ai/cohere-go/v2 v2.8.1 + github.com/go-playground/validator/v10 v10.21.0 github.com/invopop/jsonschema v0.12.0 github.com/liushuangls/go-anthropic/v2 v2.1.0 github.com/sashabaranov/go-openai v1.24.1 @@ -12,8 +13,16 @@ require ( require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ec9cab3..87c39f9 100644 --- a/go.sum +++ b/go.sum @@ -6,13 +6,23 @@ github.com/cohere-ai/cohere-go/v2 v2.8.1 h1:7+MCdXtz8onJLRmJik/cD5XGfgDNLhte4aW4 github.com/cohere-ai/cohere-go/v2 v2.8.1/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.21.0 h1:4fZA11ovvtkdgaeev9RGWPgc1uj3H8W+rNYyH/ySBb0= +github.com/go-playground/validator/v10 v10.21.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/liushuangls/go-anthropic/v2 v2.1.0 h1:5ntOeehozlMin0+hgnhxbTru+tmBH84ADaSPelG5fPg= github.com/liushuangls/go-anthropic/v2 v2.1.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -21,10 +31,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI= github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/instructor/anthropic.go b/pkg/instructor/anthropic.go index 3b4870d..ea14b2b 100644 --- a/pkg/instructor/anthropic.go +++ b/pkg/instructor/anthropic.go @@ -10,6 +10,7 @@ type InstructorAnthropic struct { provider Provider mode Mode maxRetries int + validate bool } var _ Instructor = &InstructorAnthropic{} @@ -24,6 +25,7 @@ func FromAnthropic(client *anthropic.Client, opts ...Options) *InstructorAnthrop provider: ProviderOpenAI, mode: *options.Mode, maxRetries: *options.MaxRetries, + validate: *options.validate, } return i } @@ -39,3 +41,6 @@ func (i *InstructorAnthropic) Mode() string { func (i *InstructorAnthropic) Provider() string { return i.provider } +func (i *InstructorAnthropic) Validate() bool { + return i.validate +} diff --git a/pkg/instructor/chat.go b/pkg/instructor/chat.go index 6a4aeaa..d4a1478 100644 --- a/pkg/instructor/chat.go +++ b/pkg/instructor/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "reflect" + + "github.com/go-playground/validator/v10" ) func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) { @@ -38,6 +40,18 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons continue } + if i.Validate() { + validate = validator.New() + // Validate the response structure against the defined model using the validator + err = validate.Struct(response) + + if err != nil { + // TODO: + // add more sophisticated retry logic (send back validator error and parse error for model to fix). + continue + } + } + return resp, nil } diff --git a/pkg/instructor/chat_stream.go b/pkg/instructor/chat_stream.go index 80642db..68718b9 100644 --- a/pkg/instructor/chat_stream.go +++ b/pkg/instructor/chat_stream.go @@ -5,6 +5,8 @@ import ( "encoding/json" "reflect" "strings" + + "github.com/go-playground/validator/v10" ) type StreamWrapper[T any] struct { @@ -36,12 +38,17 @@ func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, r return nil, err } - parsedChan := parseStream(ctx, ch, responseType) + shouldValidate := i.Validate() + if shouldValidate { + validate = validator.New() + } + + parsedChan := parseStream(ctx, ch, shouldValidate, responseType) return parsedChan, nil } -func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Type) <-chan interface{} { +func parseStream(ctx context.Context, ch <-chan string, shouldValidate bool, responseType reflect.Type) <-chan interface{} { parsedChan := make(chan any) @@ -58,7 +65,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ case text, ok := <-ch: if !ok { // Stream closed - processRemainingBuffer(buffer, parsedChan, responseType) + processRemainingBuffer(buffer, parsedChan, shouldValidate, responseType) return } @@ -69,7 +76,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ inArray = startArray(buffer) } - processBuffer(buffer, parsedChan, responseType) + processBuffer(buffer, parsedChan, shouldValidate, responseType) } } }() @@ -93,7 +100,7 @@ func startArray(buffer *strings.Builder) bool { return true } -func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) { +func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) { data := buffer.String() @@ -107,6 +114,15 @@ func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, respo if err != nil { break } + + if shouldValidate { + // Validate the instance + err = validate.Struct(instance) + if err != nil { + break + } + } + parsedChan <- instance buffer.Reset() @@ -114,7 +130,7 @@ func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, respo } } -func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) { +func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) { data := buffer.String() @@ -124,5 +140,6 @@ func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface data = data[:idx] } - processBuffer(buffer, parsedChan, responseType) + processBuffer(buffer, parsedChan, shouldValidate, responseType) + } diff --git a/pkg/instructor/cohere_struct.go b/pkg/instructor/cohere_struct.go index e6cdb3d..08f1aa7 100644 --- a/pkg/instructor/cohere_struct.go +++ b/pkg/instructor/cohere_struct.go @@ -10,6 +10,7 @@ type InstructorCohere struct { provider Provider mode Mode maxRetries int + validate bool } var _ Instructor = &InstructorCohere{} @@ -39,3 +40,6 @@ func (i *InstructorCohere) Mode() string { func (i *InstructorCohere) MaxRetries() int { return i.maxRetries } +func (i *InstructorCohere) Validate() bool { + return i.validate +} diff --git a/pkg/instructor/instructor.go b/pkg/instructor/instructor.go index de0e111..47b86f7 100644 --- a/pkg/instructor/instructor.go +++ b/pkg/instructor/instructor.go @@ -2,12 +2,17 @@ package instructor import ( "context" + + "github.com/go-playground/validator/v10" ) +var validate *validator.Validate + type Instructor interface { Provider() Provider Mode() Mode MaxRetries() int + Validate() bool // Chat / Messages diff --git a/pkg/instructor/openai_struct.go b/pkg/instructor/openai_struct.go index a827771..9219839 100644 --- a/pkg/instructor/openai_struct.go +++ b/pkg/instructor/openai_struct.go @@ -10,6 +10,7 @@ type InstructorOpenAI struct { provider Provider mode Mode maxRetries int + validate bool } var _ Instructor = &InstructorOpenAI{} @@ -24,6 +25,7 @@ func FromOpenAI(client *openai.Client, opts ...Options) *InstructorOpenAI { provider: ProviderOpenAI, mode: *options.Mode, maxRetries: *options.MaxRetries, + validate: *options.validate, } return i } @@ -37,3 +39,6 @@ func (i *InstructorOpenAI) Mode() Mode { func (i *InstructorOpenAI) MaxRetries() int { return i.maxRetries } +func (i *InstructorOpenAI) Validate() bool { + return i.validate +} diff --git a/pkg/instructor/options.go b/pkg/instructor/options.go index 2afcfa5..5eb5342 100644 --- a/pkg/instructor/options.go +++ b/pkg/instructor/options.go @@ -2,18 +2,20 @@ package instructor const ( DefaultMaxRetries = 3 + DefaultValidator = false ) type Options struct { Mode *Mode MaxRetries *int - + validate *bool // Provider specific options: } var defaultOptions = Options{ Mode: toPtr(ModeDefault), MaxRetries: toPtr(DefaultMaxRetries), + validate: toPtr(DefaultValidator), } func WithMode(mode Mode) Options { @@ -24,6 +26,10 @@ func WithMaxRetries(maxRetries int) Options { return Options{MaxRetries: toPtr(maxRetries)} } +func WithValidation() Options { + return Options{validate: toPtr(true)} +} + func mergeOption(old, new Options) Options { if new.Mode != nil { old.Mode = new.Mode @@ -31,6 +37,9 @@ func mergeOption(old, new Options) Options { if new.MaxRetries != nil { old.MaxRetries = new.MaxRetries } + if new.validate != nil { + old.validate = new.validate + } return old }