Skip to content

Commit

Permalink
Add Validation Support (#26)
Browse files Browse the repository at this point in the history
* instructor validator

* update simple example

* validate in stream

* add with validator option

* add example for validator

* update example

* rename withValidator varible to validate

* rename with validator to withValidation

* remove required param for WithValidation

---------

Co-authored-by: Robby <[email protected]>
  • Loading branch information
Sushmithamallesh and h0rv authored Jun 14, 2024
1 parent 038c15c commit 91ffa0d
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 12 deletions.
85 changes: 85 additions & 0 deletions examples/validator/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package main

import (
"context"
"fmt"
"os"

"github.com/instructor-ai/instructor-go/pkg/instructor"
openai "github.com/sashabaranov/go-openai"
)

type User struct {
FirstName string `json:"first_name" jsonschema:"title=First Name,description=The first name of the user" validate:"required"`
LastName string `json:"last_name" jsonschema:"title=Last Name,description=The last name of the user" validate:"required"`
Age uint8 `json:"age" jsonschema:"title=Age,description=The age of the user" validate:"gte=0,lte=130"`
Email string `json:"email" jsonschema:"title=Email,description=The email address of the user" validate:"required,email"`
Gender string `json:"gender" jsonschema:"title=Gender,description=The gender of the user" validate:"oneof=male female prefer_not_to"`
FavouriteColor string `json:"favourite_color" jsonschema:"title=Favourite Color,description=The favourite color of the user" validate:"iscolor"`
Addresses []*Address `json:"addresses" jsonschema:"title=Addresses,description=The addresses of the user" validate:"required,dive,required"`
}

type Address struct {
Street string `json:"street" jsonschema:"title=Street,description=The street address" validate:"required"`
City string `json:"city" jsonschema:"title=City,description=The city" validate:"required"`
Planet string `json:"planet" jsonschema:"title=Planet,description=The planet" validate:"required"`
Phone string `json:"phone" jsonschema:"title=Phone,description=The phone number" validate:"required"`
}

func (u User) String() string {
result := fmt.Sprintf("First Name: %s\nLast Name: %s\nAge: %d\nEmail: %s\nGender: %s\nFavourite Color: %s\nAddresses:\n",
u.FirstName, u.LastName, u.Age, u.Email, u.Gender, u.FavouriteColor)
for _, address := range u.Addresses {
result += fmt.Sprintf(" %s\n", address)
}
return result
}

func (a Address) String() string {
return fmt.Sprintf("Street: %s, City: %s, Planet: %s, Phone: %s", a.Street, a.City, a.Planet, a.Phone)
}

func main() {
ctx := context.Background()

client := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeJSON),
instructor.WithMaxRetries(3),
instructor.WithValidation(),
)

var user User
_, err := client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4o,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Meet Jane Doe: a 30-year-old adventurer who can be reached at [email protected]. " +
"Jane loves the vibrant hue of #FF5733. She resides in Metropolis at 456 Oak St, on the wonderful planet Earth. " +
"To chat with her, dial (555) 555-1234. Jane also spends her weekends at her cottage located at 789 Pine St, " +
"in Smallville, on the same planet. You can contact her there at (555) 555-5678.",
},
},
},
&user,
)
if err != nil {
panic(err)
}

fmt.Println(user)
/*
First Name: Jane
Last Name: Doe
Age: 30
Email: [email protected]
Gender: female
Favourite Color: #FF5733
Addresses:
Street: 456 Oak St, City: Metropolis, Planet: Earth, Phone: (555) 555-1234
Street: 789 Pine St, City: Smallville, Planet: Earth, Phone: (555) 555-5678
*/
}
9 changes: 9 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.21.8

require (
github.com/cohere-ai/cohere-go/v2 v2.8.1
github.com/go-playground/validator/v10 v10.21.0
github.com/invopop/jsonschema v0.12.0
github.com/liushuangls/go-anthropic/v2 v2.1.0
github.com/sashabaranov/go-openai v1.24.1
Expand All @@ -12,8 +13,16 @@ require (
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
26 changes: 22 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@ github.com/cohere-ai/cohere-go/v2 v2.8.1 h1:7+MCdXtz8onJLRmJik/cD5XGfgDNLhte4aW4
github.com/cohere-ai/cohere-go/v2 v2.8.1/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.21.0 h1:4fZA11ovvtkdgaeev9RGWPgc1uj3H8W+rNYyH/ySBb0=
github.com/go-playground/validator/v10 v10.21.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI=
github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/liushuangls/go-anthropic/v2 v2.1.0 h1:5ntOeehozlMin0+hgnhxbTru+tmBH84ADaSPelG5fPg=
github.com/liushuangls/go-anthropic/v2 v2.1.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
Expand All @@ -21,10 +31,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI=
github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
5 changes: 5 additions & 0 deletions pkg/instructor/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type InstructorAnthropic struct {
provider Provider
mode Mode
maxRetries int
validate bool
}

var _ Instructor = &InstructorAnthropic{}
Expand All @@ -24,6 +25,7 @@ func FromAnthropic(client *anthropic.Client, opts ...Options) *InstructorAnthrop
provider: ProviderOpenAI,
mode: *options.Mode,
maxRetries: *options.MaxRetries,
validate: *options.validate,
}
return i
}
Expand All @@ -39,3 +41,6 @@ func (i *InstructorAnthropic) Mode() string {
func (i *InstructorAnthropic) Provider() string {
return i.provider
}
func (i *InstructorAnthropic) Validate() bool {
return i.validate
}
14 changes: 14 additions & 0 deletions pkg/instructor/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"reflect"

"github.com/go-playground/validator/v10"
)

func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) {
Expand Down Expand Up @@ -38,6 +40,18 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
continue
}

if i.Validate() {
validate = validator.New()
// Validate the response structure against the defined model using the validator
err = validate.Struct(response)

if err != nil {
// TODO:
// add more sophisticated retry logic (send back validator error and parse error for model to fix).
continue
}
}

return resp, nil
}

Expand Down
31 changes: 24 additions & 7 deletions pkg/instructor/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"reflect"
"strings"

"github.com/go-playground/validator/v10"
)

type StreamWrapper[T any] struct {
Expand Down Expand Up @@ -36,12 +38,17 @@ func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, r
return nil, err
}

parsedChan := parseStream(ctx, ch, responseType)
shouldValidate := i.Validate()
if shouldValidate {
validate = validator.New()
}

parsedChan := parseStream(ctx, ch, shouldValidate, responseType)

return parsedChan, nil
}

func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Type) <-chan interface{} {
func parseStream(ctx context.Context, ch <-chan string, shouldValidate bool, responseType reflect.Type) <-chan interface{} {

parsedChan := make(chan any)

Expand All @@ -58,7 +65,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ
case text, ok := <-ch:
if !ok {
// Stream closed
processRemainingBuffer(buffer, parsedChan, responseType)
processRemainingBuffer(buffer, parsedChan, shouldValidate, responseType)
return
}

Expand All @@ -69,7 +76,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ
inArray = startArray(buffer)
}

processBuffer(buffer, parsedChan, responseType)
processBuffer(buffer, parsedChan, shouldValidate, responseType)
}
}
}()
Expand All @@ -93,7 +100,7 @@ func startArray(buffer *strings.Builder) bool {
return true
}

func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) {
func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) {

data := buffer.String()

Expand All @@ -107,14 +114,23 @@ func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, respo
if err != nil {
break
}

if shouldValidate {
// Validate the instance
err = validate.Struct(instance)
if err != nil {
break
}
}

parsedChan <- instance

buffer.Reset()
buffer.WriteString(remaining)
}
}

func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) {
func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) {

data := buffer.String()

Expand All @@ -124,5 +140,6 @@ func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface
data = data[:idx]
}

processBuffer(buffer, parsedChan, responseType)
processBuffer(buffer, parsedChan, shouldValidate, responseType)

}
4 changes: 4 additions & 0 deletions pkg/instructor/cohere_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type InstructorCohere struct {
provider Provider
mode Mode
maxRetries int
validate bool
}

var _ Instructor = &InstructorCohere{}
Expand Down Expand Up @@ -39,3 +40,6 @@ func (i *InstructorCohere) Mode() string {
func (i *InstructorCohere) MaxRetries() int {
return i.maxRetries
}
func (i *InstructorCohere) Validate() bool {
return i.validate
}
5 changes: 5 additions & 0 deletions pkg/instructor/instructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package instructor

import (
"context"

"github.com/go-playground/validator/v10"
)

var validate *validator.Validate

type Instructor interface {
Provider() Provider
Mode() Mode
MaxRetries() int
Validate() bool

// Chat / Messages

Expand Down
5 changes: 5 additions & 0 deletions pkg/instructor/openai_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type InstructorOpenAI struct {
provider Provider
mode Mode
maxRetries int
validate bool
}

var _ Instructor = &InstructorOpenAI{}
Expand All @@ -24,6 +25,7 @@ func FromOpenAI(client *openai.Client, opts ...Options) *InstructorOpenAI {
provider: ProviderOpenAI,
mode: *options.Mode,
maxRetries: *options.MaxRetries,
validate: *options.validate,
}
return i
}
Expand All @@ -37,3 +39,6 @@ func (i *InstructorOpenAI) Mode() Mode {
func (i *InstructorOpenAI) MaxRetries() int {
return i.maxRetries
}
func (i *InstructorOpenAI) Validate() bool {
return i.validate
}
11 changes: 10 additions & 1 deletion pkg/instructor/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package instructor

const (
DefaultMaxRetries = 3
DefaultValidator = false
)

type Options struct {
Mode *Mode
MaxRetries *int

validate *bool
// Provider specific options:
}

var defaultOptions = Options{
Mode: toPtr(ModeDefault),
MaxRetries: toPtr(DefaultMaxRetries),
validate: toPtr(DefaultValidator),
}

func WithMode(mode Mode) Options {
Expand All @@ -24,13 +26,20 @@ func WithMaxRetries(maxRetries int) Options {
return Options{MaxRetries: toPtr(maxRetries)}
}

func WithValidation() Options {
return Options{validate: toPtr(true)}
}

func mergeOption(old, new Options) Options {
if new.Mode != nil {
old.Mode = new.Mode
}
if new.MaxRetries != nil {
old.MaxRetries = new.MaxRetries
}
if new.validate != nil {
old.validate = new.validate
}

return old
}
Expand Down

0 comments on commit 91ffa0d

Please sign in to comment.