diff --git a/README.md b/README.md index 108b545..778a816 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,16 @@ As shown in the example below, by adding extra metadata to each struct field (vi > For more information on the `jsonschema` tags available, see the [`jsonschema` godoc](https://pkg.go.dev/github.com/invopop/jsonschema?utm_source=godoc). +
+Running + +```bash +export OPENAI_API_KEY= +go run examples/user/main.go +``` + +
+ ```go package main @@ -76,9 +86,654 @@ Age: %d } ``` -## Coming Soon +### Other Examples + +
+Function Calling with OpenAI + +
+Running + +```bash +export OPENAI_API_KEY= +go run examples/function_calling/main.go +``` + +
+ +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type SearchType string + +const ( + Web SearchType = "web" + Image SearchType = "image" + Video SearchType = "video" +) + +type Search struct { + Topic string `json:"topic" jsonschema:"title=Topic,description=Topic of the search,example=golang"` + Query string `json:"query" jsonschema:"title=Query,description=Query to search for relevant content,example=what is golang"` + Type SearchType `json:"type" jsonschema:"title=Type,description=Type of search,default=web,enum=web,enum=image,enum=video"` +} + +func (s *Search) execute() { + fmt.Printf("Searching for `%s` with query `%s` using `%s`\n", s.Topic, s.Query, s.Type) +} + +type Searches = []Search + +func segment(ctx context.Context, data string) *Searches { + + client, err := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeToolCall), + instructor.WithMaxRetries(3), + ) + if err != nil { + panic(err) + } + + var searches Searches + err = client.CreateChatCompletion( + ctx, + instructor.Request{ + Model: openai.GPT4o, + Messages: []instructor.Message{ + { + Role: instructor.RoleUser, + Content: fmt.Sprintf("Consider the data below: '\n%s' and segment it into multiple search queries", data), + }, + }, + }, + &searches, + ) + if err != nil { + panic(err) + } + + return &searches +} + +func main() { + ctx := context.Background() + + q := "Search for a picture of a cat, a video of a dog, and the taxonomy of each" + for _, search := range *segment(ctx, q) { + search.execute() + } + /* + Searching for `cat` with query `picture of a cat` using `image` + Searching for `dog` with query `video of a dog` using `video` + Searching for `cat` with query `taxonomy of a cat` using `web` + Searching for `dog` with query `taxonomy of a dog` using `web` + */ +} +``` + +
+ +
+Text Classification with Anthropic + +
+Running + +```bash +export ANTHROPIC_API_KEY= +go run examples/classification/main.go +``` + +
+ +```go + package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + anthropic "github.com/liushuangls/go-anthropic/v2" +) + +type LabelType string + +const ( + LabelTechIssue LabelType = "tech_issue" + LabelBilling LabelType = "billing" + LabelGeneralQuery LabelType = "general_query" +) + +type Label struct { + Type LabelType `json:"type" jsonschema:"title=Label type,description=Type of the label,enum=tech_issue,enum=billing,enum=general_query"` +} + +type Prediction struct { + Labels []Label `json:"labels" jsonschema:"title=Predicted labels,description=Labels of the prediction"` +} + +func classify(data string) *Prediction { + ctx := context.Background() + + client, err := instructor.FromAnthropic( + anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")), + instructor.WithMode(instructor.ModeToolCall), + instructor.WithMaxRetries(3), + ) + if err != nil { + panic(err) + } + + var prediction Prediction + err = client.CreateChatCompletion( + ctx, + instructor.Request{ + Model: anthropic.ModelClaude3Haiku20240307, + Messages: []instructor.Message{ + { + Role: instructor.RoleUser, + Content: fmt.Sprintf("Classify the following support ticket: %s", data), + }, + }, + }, + &prediction, + ) + if err != nil { + panic(err) + } + + return &prediction +} + +func main() { + + ticket := "My account is locked and I can't access my billing info." 
+ prediction := classify(ticket) + + assert(prediction.contains(LabelTechIssue), "Expected ticket to be related to tech issue") + assert(prediction.contains(LabelBilling), "Expected ticket to be related to billing") + assert(!prediction.contains(LabelGeneralQuery), "Expected ticket NOT to be a general query") + + fmt.Printf("%+v\n", prediction) + /* + &{Labels:[{Type:tech_issue} {Type:billing}]} + */ +} + +/******/ + +func (p *Prediction) contains(label LabelType) bool { + for _, l := range p.Labels { + if l.Type == label { + return true + } + } + return false +} + +func assert(condition bool, message string) { + if !condition { + fmt.Println("Assertion failed:", message) + } +} +``` + +
+ +
+Images with OpenAI + +
+Running + +```bash +export OPENAI_API_KEY= +go run examples/images/openai/main.go +``` + +
+ +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type Book struct { + Title string `json:"title,omitempty" jsonschema:"title=title,description=The title of the book,example=Harry Potter and the Philosopher's Stone"` + Author *string `json:"author,omitempty" jsonschema:"title=author,description=The author of the book,example=J.K. Rowling"` +} + +type BookCatalog struct { + Catalog []Book `json:"catalog"` +} + +func (bc *BookCatalog) PrintCatalog() { + fmt.Printf("Number of books in the catalog: %d\n\n", len(bc.Catalog)) + for _, book := range bc.Catalog { + fmt.Printf("Title: %s\n", book.Title) + fmt.Printf("Author: %s\n", *book.Author) + fmt.Println("--------------------") + } +} + +func main() { + ctx := context.Background() + + client, err := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSON), + instructor.WithMaxRetries(3), + ) + if err != nil { + panic(err) + } + + url := "https://utfs.io/f/fe55d6bd-e920-4a6f-8e93-a4c9dd851b90-eivhb2.png" + + var bookCatalog BookCatalog + err = client.CreateChatCompletion( + ctx, + instructor.Request{ + Model: openai.GPT4o, + Messages: []instructor.Message{ + { + Role: instructor.RoleUser, + MultiContent: []instructor.ChatMessagePart{ + { + Type: instructor.ChatMessagePartTypeText, + Text: "Extract book catelog from the image", + }, + { + Type: instructor.ChatMessagePartTypeImageURL, + ImageURL: &instructor.ChatMessageImageURL{ + URL: url, + }, + }, + }, + }, + }, + }, + &bookCatalog, + ) + + if err != nil { + panic(err) + } + + bookCatalog.PrintCatalog() + /* + Number of books in the catalog: 15 + + Title: Pride and Prejudice + Author: Jane Austen + -------------------- + Title: The Great Gatsby + Author: F. Scott Fitzgerald + -------------------- + Title: The Catcher in the Rye + Author: J. D. 
Salinger + -------------------- + Title: Don Quixote + Author: Miguel de Cervantes + -------------------- + Title: One Hundred Years of Solitude + Author: Gabriel García Márquez + -------------------- + Title: To Kill a Mockingbird + Author: Harper Lee + -------------------- + Title: Beloved + Author: Toni Morrison + -------------------- + Title: Ulysses + Author: James Joyce + -------------------- + Title: Harry Potter and the Cursed Child + Author: J.K. Rowling + -------------------- + Title: The Grapes of Wrath + Author: John Steinbeck + -------------------- + Title: 1984 + Author: George Orwell + -------------------- + Title: Lolita + Author: Vladimir Nabokov + -------------------- + Title: Anna Karenina + Author: Leo Tolstoy + -------------------- + Title: Moby-Dick + Author: Herman Melville + -------------------- + Title: Wuthering Heights + Author: Emily Brontë + -------------------- + */ +} +``` + +
+ +
+Images with Anthropic + +
+Running + +```bash +export ANTHROPIC_API_KEY= +go run examples/images/anthropic/main.go +``` + +
+ +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + "github.com/liushuangls/go-anthropic/v2" +) + +type Movie struct { + Title string `json:"title" jsonschema:"title=title,description=The title of the movie,required=true,example=Ex Machina"` + Year int `json:"year,omitempty" jsonschema:"title=year,description=The year of the movie,required=false,example=2014"` +} + +type MovieCatalog struct { + Catalog []Movie `json:"catalog"` +} + +func (bc *MovieCatalog) PrintCatalog() { + fmt.Printf("Number of movies in the catalog: %d\n\n", len(bc.Catalog)) + for _, movie := range bc.Catalog { + fmt.Printf("Title: %s\n", movie.Title) + if movie.Year != 0 { + fmt.Printf("Year: %d\n", movie.Year) + } + fmt.Println("--------------------") + } +} + +func main() { + ctx := context.Background() + + client, err := instructor.FromAnthropic( + anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")), + instructor.WithMode(instructor.ModeJSONSchema), + instructor.WithMaxRetries(3), + ) + if err != nil { + panic(err) + } + + url := "https://utfs.io/f/bd0dbae6-27e3-4604-b640-fd2ffea891b8-fxyywt.jpeg" + + var movieCatalog MovieCatalog + err = client.CreateChatCompletion( + ctx, + instructor.Request{ + Model: "claude-3-haiku-20240307", + Messages: []instructor.Message{ + { + Role: instructor.RoleUser, + MultiContent: []instructor.ChatMessagePart{ + { + Type: instructor.ChatMessagePartTypeText, + Text: "Extract the movie catalog from the screenshot", + }, + { + Type: instructor.ChatMessagePartTypeImageURL, + ImageURL: &instructor.ChatMessageImageURL{ + URL: url, + }, + }, + }, + }, + }, + }, + &movieCatalog, + ) + if err != nil { + panic(err) + } + + movieCatalog.PrintCatalog() + /* + Number of movies in the catalog: 18 + + Title: Oppenheimer + Year: 2023 + -------------------- + Title: The Dark Knight + Year: 2008 + -------------------- + Title: Interstellar + Year: 2014 + -------------------- + Title: Inception + Year: 2010 
+ -------------------- + Title: Tenet + Year: 2020 + -------------------- + Title: Dunkirk + Year: 2017 + -------------------- + Title: Memento + Year: 2000 + -------------------- + Title: The Dark Knight Rises + Year: 2012 + -------------------- + Title: Batman Begins + Year: 2005 + -------------------- + Title: The Prestige + Year: 2006 + -------------------- + Title: Insomnia + Year: 2002 + -------------------- + Title: Following + Year: 1998 + -------------------- + Title: Man of Steel + Year: 2013 + -------------------- + Title: Transcendence + Year: 2014 + -------------------- + Title: Justice League + Year: 2017 + -------------------- + Title: Batman v Superman: Dawn of Justice + Year: 2016 + -------------------- + Title: Ending the Knight + Year: 2016 + -------------------- + Title: Larceny + -------------------- + */ +} +``` + +
+ +
+Streaming with OpenAI + +
+Running + +```bash +export OPENAI_API_KEY= +go run examples/streaming/main.go +``` + +
+ +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type Product struct { + ID string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"` + Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"` +} + +func (p *Product) String() string { + return fmt.Sprintf("Product [ID: %s, Name: %s]", p.ID, p.Name) +} + +type Recommendation struct { + Product + Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"` +} + +func (r *Recommendation) String() string { + return fmt.Sprintf(` +Recommendation [ + %s + Reason [%s] +]`, r.Product.String(), r.Reason) +} + +func main() { + ctx := context.Background() + + client, err := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSON), + ) + if err != nil { + panic(err) + } + + profileData := ` +Customer ID: 12345 +Recent Purchases: [Laptop, Wireless Headphones, Smart Watch] +Frequently Browsed Categories: [Electronics, Books, Fitness Equipment] +Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars} +Recent Search History: [best budget laptops 2023, latest sci-fi books, yoga mats] +Preferred Brands: [Apple, AllBirds, Bench] +Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested} +Loyalty Program Status: Gold Member +Average Monthly Spend: $500 +Preferred Shopping Times: Weekend Evenings +... 
+` + + products := []Product{ + {ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"}, + {ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"}, + {ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"}, + {ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"}, + {ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"}, + {ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"}, + {ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"}, + {ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"}, + {ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"}, + {ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"}, + } + + productList := "" + for _, product := range products { + productList += product.String() + "\n" + } + + recommendationChan, err := client.CreateChatCompletionStream( + ctx, + instructor.Request{ + Model: openai.GPT4o20240513, + Messages: []instructor.Message{ + { + Role: instructor.RoleSystem, + Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile. +Return in order of highest recommended first. 
+Product list: +%s`, productList), + }, + { + Role: instructor.RoleUser, + Content: fmt.Sprintf("User profile:\n%s", profileData), + }, + }, + Stream: true, + }, + *new(Recommendation), + ) + if err != nil { + panic(err) + } + + for instance := range recommendationChan { + recommendation, _ := instance.(*Recommendation) + println(recommendation.String()) + } + /* + Recommendation [ + Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable] + Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.] + ] + + Recommendation [ + Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem] + Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.] + ] + + Recommendation [ + Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author] + Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.] + ] + + Recommendation [ + Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly] + Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.] + ] + + Recommendation [ + Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers] + Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.] + ] + */ +} +``` -1. Streaming support +
## Providers diff --git a/examples/function_calling/main.go b/examples/function_calling/main.go index 1e6d7f2..9f9eed4 100644 --- a/examples/function_calling/main.go +++ b/examples/function_calling/main.go @@ -29,10 +29,6 @@ func (s *Search) execute() { type Searches = []Search -// type Searches struct { -// Items []Search `json:"searches" jsonschema:"title=Searches,description=A list of search results"` -// } - func segment(ctx context.Context, data string) *Searches { client, err := instructor.FromOpenAI( diff --git a/examples/streaming/main.go b/examples/streaming/main.go new file mode 100644 index 0000000..5a5220f --- /dev/null +++ b/examples/streaming/main.go @@ -0,0 +1,132 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/instructor-ai/instructor-go/pkg/instructor" + openai "github.com/sashabaranov/go-openai" +) + +type Product struct { + ID string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"` + Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"` +} + +func (p *Product) String() string { + return fmt.Sprintf("Product [ID: %s, Name: %s]", p.ID, p.Name) +} + +type Recommendation struct { + Product + Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"` +} + +func (r *Recommendation) String() string { + return fmt.Sprintf(` +Recommendation [ + %s + Reason [%s] +]`, r.Product.String(), r.Reason) +} + +func main() { + ctx := context.Background() + + client, err := instructor.FromOpenAI( + openai.NewClient(os.Getenv("OPENAI_API_KEY")), + instructor.WithMode(instructor.ModeJSON), + ) + if err != nil { + panic(err) + } + + profileData := ` +Customer ID: 12345 +Recent Purchases: [Laptop, Wireless Headphones, Smart Watch] +Frequently Browsed Categories: [Electronics, Books, Fitness Equipment] +Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars} +Recent Search 
History: [best budget laptops 2023, latest sci-fi books, yoga mats] +Preferred Brands: [Apple, AllBirds, Bench] +Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested} +Loyalty Program Status: Gold Member +Average Monthly Spend: $500 +Preferred Shopping Times: Weekend Evenings +... +` + + products := []Product{ + {ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"}, + {ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"}, + {ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"}, + {ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"}, + {ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"}, + {ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"}, + {ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"}, + {ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"}, + {ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"}, + {ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"}, + } + + productList := "" + for _, product := range products { + productList += product.String() + "\n" + } + + recommendationChan, err := client.CreateChatCompletionStream( + ctx, + instructor.Request{ + Model: openai.GPT4o20240513, + Messages: []instructor.Message{ + { + Role: instructor.RoleSystem, + Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile. +Return in order of highest recommended first. 
+Product list: +%s`, productList), + }, + { + Role: instructor.RoleUser, + Content: fmt.Sprintf("User profile:\n%s", profileData), + }, + }, + Stream: true, + }, + *new(Recommendation), + ) + if err != nil { + panic(err) + } + + for instance := range recommendationChan { + recommendation, _ := instance.(*Recommendation) + println(recommendation.String()) + } + /* + Recommendation [ + Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable] + Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.] + ] + + Recommendation [ + Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem] + Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.] + ] + + Recommendation [ + Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author] + Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.] + ] + + Recommendation [ + Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly] + Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.] + ] + + Recommendation [ + Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers] + Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.] 
+ ] + */ +} diff --git a/pkg/instructor/anthropic.go b/pkg/instructor/anthropic.go index 5906ba2..8b4b208 100644 --- a/pkg/instructor/anthropic.go +++ b/pkg/instructor/anthropic.go @@ -7,6 +7,7 @@ import ( "fmt" anthropic "github.com/liushuangls/go-anthropic/v2" + openai "github.com/sashabaranov/go-openai" ) type AnthropicClient struct { @@ -163,3 +164,7 @@ func toAnthropicMessages(request *Request) (*[]anthropic.Message, error) { return &messages, nil } + +func (a *AnthropicClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest, mode string, schema *Schema) (<-chan string, error) { + panic("unimplemented") +} diff --git a/pkg/instructor/client.go b/pkg/instructor/client.go index dfeb75a..7d7a657 100644 --- a/pkg/instructor/client.go +++ b/pkg/instructor/client.go @@ -12,10 +12,10 @@ type Client interface { schema *Schema, ) (string, error) - // TODO: implement streaming - // CreateChatCompletionStream( - // ctx context.Context, - // request ChatCompletionRequest, - // opts ...ClientOptions, - // ) (*T, error) + CreateChatCompletionStream( + ctx context.Context, + request Request, + mode Mode, + schema *Schema, + ) (<-chan string, error) } diff --git a/pkg/instructor/instructor.go b/pkg/instructor/instructor.go index e119e21..c1a2fcb 100644 --- a/pkg/instructor/instructor.go +++ b/pkg/instructor/instructor.go @@ -90,59 +90,116 @@ func (i *Instructor) CreateChatCompletion(ctx context.Context, request Request, return errors.New("hit max retry attempts") } -func processResponse(responseStr string, response *any) error { +const WRAPPER_END = `"items": [` - err := json.Unmarshal([]byte(responseStr), response) +type StreamWrapper[T any] struct { + Items []T `json:"items"` +} + +func (i *Instructor) CreateChatCompletionStream(ctx context.Context, request Request, response any) (chan any, error) { + + responseType := reflect.TypeOf(response) + + streamWrapperType := reflect.StructOf([]reflect.StructField{ + { + Name: "Items", + Type: 
reflect.SliceOf(responseType), + Tag: `json:"items"`, + Anonymous: false, + }, + }) + + schema, err := NewSchema(streamWrapperType) if err != nil { - return err + return nil, err } - // TODO: if direct unmarshal fails: check common erors like wrapping struct with key name of struct, instead of just the value - - return nil -} + request.Stream = true -// Removes any prefixes before the JSON (like "Sure, here you go:") -func trimPrefix(jsonStr string) string { - startObject := strings.IndexByte(jsonStr, '{') - startArray := strings.IndexByte(jsonStr, '[') - - var start int - if startObject == -1 && startArray == -1 { - return jsonStr // No opening brace or bracket found, return the original string - } else if startObject == -1 { - start = startArray - } else if startArray == -1 { - start = startObject - } else { - start = min(startObject, startArray) + ch, err := i.Client.CreateChatCompletionStream(ctx, request, i.Mode, schema) + if err != nil { + return nil, err } - return jsonStr[start:] + parsedChan := make(chan any) // Unbuffered channel for parsed objects + + go func() { + defer close(parsedChan) + var buffer strings.Builder + inArray := false + + for { + select { + case <-ctx.Done(): + return + case text, ok := <-ch: + if !ok { + // Stream closed + + // Get last element out of stream wrapper + + data := buffer.String() + + if idx := strings.LastIndex(data, "]"); idx != -1 { + data = data[:idx] + data[idx+1:] + } + + // Process the remaining data in the buffer + decoder := json.NewDecoder(strings.NewReader(data)) + for decoder.More() { + instance := reflect.New(responseType).Interface() + err := decoder.Decode(instance) + if err != nil { + break + } + parsedChan <- instance + } + return + } + buffer.WriteString(text) + + // eat all input until elements stream starts + if !inArray { + idx := strings.Index(buffer.String(), WRAPPER_END) + if idx == -1 { + continue + } + + inArray = true + bufferStr := buffer.String() + trimmed :=
strings.TrimSpace(bufferStr[idx+len(WRAPPER_END):]) + buffer.Reset() + buffer.WriteString(trimmed) + } + + data := buffer.String() + decoder := json.NewDecoder(strings.NewReader(data)) + + for decoder.More() { + instance := reflect.New(responseType).Interface() + err := decoder.Decode(instance) + if err != nil { + break + } + parsedChan <- instance + + buffer.Reset() + buffer.WriteString(data[len(data):]) + } + } + } + }() + + return parsedChan, nil } -// Removes any postfixes after the JSON -func trimPostfix(jsonStr string) string { - endObject := strings.LastIndexByte(jsonStr, '}') - endArray := strings.LastIndexByte(jsonStr, ']') - - var end int - if endObject == -1 && endArray == -1 { - return jsonStr // No closing brace or bracket found, return the original string - } else if endObject == -1 { - end = endArray - } else if endArray == -1 { - end = endObject - } else { - end = max(endObject, endArray) +func processResponse(responseStr string, response *any) error { + err := json.Unmarshal([]byte(responseStr), response) + if err != nil { + return err } - return jsonStr[:end+1] -} + // TODO: if direct unmarshal fails: check common errors like wrapping struct with key name of struct, instead of just the value -// Extracts the JSON by trimming prefixes and postfixes -func extractJSON(jsonStr string) string { - trimmedPrefix := trimPrefix(jsonStr) - trimmedJSON := trimPostfix(trimmedPrefix) - return trimmedJSON + return nil } diff --git a/pkg/instructor/messages.go b/pkg/instructor/messages.go index 637d8b6..fec52d9 100644 --- a/pkg/instructor/messages.go +++ b/pkg/instructor/messages.go @@ -4,15 +4,13 @@ import ( openai "github.com/sashabaranov/go-openai" ) -type Message = openai.ChatCompletionMessage - -type Request = openai.ChatCompletionRequest - -type ChatMessagePart = openai.ChatMessagePart - -type ChatMessageImageURL = openai.ChatMessageImageURL - -type ChatMessagePartType = openai.ChatMessagePartType +type ( + Message = openai.ChatCompletionMessage + Request 
= openai.ChatCompletionRequest + ChatMessagePart = openai.ChatMessagePart + ChatMessageImageURL = openai.ChatMessageImageURL + ChatMessagePartType = openai.ChatMessagePartType +) const ( ChatMessagePartTypeText ChatMessagePartType = "text" diff --git a/pkg/instructor/openai.go b/pkg/instructor/openai.go index 8f6eaf2..1dcdc8e 100644 --- a/pkg/instructor/openai.go +++ b/pkg/instructor/openai.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" openai "github.com/sashabaranov/go-openai" ) @@ -26,6 +27,10 @@ func NewOpenAIClient(client *openai.Client) (*OpenAIClient, error) { } func (o *OpenAIClient) CreateChatCompletion(ctx context.Context, request Request, mode Mode, schema *Schema) (string, error) { + if request.Stream { + return "", errors.New("streaming is not supported by this method; use CreateChatCompletionStream instead") + } + switch mode { case ModeToolCall: return o.completionToolCall(ctx, request, schema) @@ -40,22 +45,7 @@ func (o *OpenAIClient) CreateChatCompletion(ctx context.Context, request Request func (o *OpenAIClient) completionToolCall(ctx context.Context, request Request, schema *Schema) (string, error) { - tools := []openai.Tool{} - - for _, function := range schema.Functions { - f := openai.FunctionDefinition{ - Name: function.Name, - Description: function.Description, - Parameters: function.Parameters, - } - t := openai.Tool{ - Type: "function", - Function: &f, - } - tools = append(tools, t) - } - - request.Tools = tools + request.Tools = createTools(schema) resp, err := o.Client.CreateChatCompletion(ctx, request) if err != nil { @@ -103,20 +93,8 @@ func (o *OpenAIClient) completionToolCall(ctx context.Context, request Request, } func (o *OpenAIClient) completionJSON(ctx context.Context, request Request, schema *Schema) (string, error) { - message := fmt.Sprintf(` -Please responsd with json in the following json_schema: - -%s - -Make sure to return an instance of the JSON, not the schema itself -`, schema.String) - - msg := 
Message{ - Role: RoleSystem, - Content: message, - } - request.Messages = prepend(request.Messages, msg) + request.Messages = prepend(request.Messages, *createJSONMessage(schema)) // Set JSON mode request.ResponseFormat = &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject} @@ -133,27 +111,115 @@ Make sure to return an instance of the JSON, not the schema itself func (o *OpenAIClient) completionJSONSchema(ctx context.Context, request Request, schema *Schema) (string, error) { + request.Messages = prepend(request.Messages, *createJSONMessage(schema)) + + resp, err := o.Client.CreateChatCompletion(ctx, request) + if err != nil { + return "", err + } + + text := resp.Choices[0].Message.Content + + return text, nil +} + +func createJSONMessage(schema *Schema) *Message { message := fmt.Sprintf(` -Please responsd with json in the following json_schema: +Please respond with JSON in the following JSON schema: %s Make sure to return an instance of the JSON, not the schema itself `, schema.String) - - msg := Message{ + return &Message{ Role: RoleSystem, Content: message, } +} + +func (o *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request Request, mode Mode, schema *Schema) (<-chan string, error) { + switch mode { + case ModeToolCall: + return o.completionToolCallStream(ctx, request, schema) + case ModeJSON: + return o.completionJSONStream(ctx, request, schema) + case ModeJSONSchema: + return o.completionJSONSchemaStream(ctx, request, schema) + default: + return nil, fmt.Errorf("mode '%s' is not supported for %s", mode, o.Name) + } +} - request.Messages = prepend(request.Messages, msg) +func (o *OpenAIClient) completionToolCallStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) { + request.Tools = createTools(schema) + return o.createStream(ctx, request) +} - resp, err := o.Client.CreateChatCompletion(ctx, request) - if err != nil { - return "", err +func (o *OpenAIClient) 
completionJSONStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) { + request.Messages = prepend(request.Messages, *createJSONMessageStream(schema)) + // Set JSON mode + request.ResponseFormat = &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject} + return o.createStream(ctx, request) +} + +func (o *OpenAIClient) completionJSONSchemaStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) { + request.Messages = prepend(request.Messages, *createJSONMessageStream(schema)) + return o.createStream(ctx, request) +} + +func createTools(schema *Schema) []openai.Tool { + tools := make([]openai.Tool, 0, len(schema.Functions)) + for _, function := range schema.Functions { + f := openai.FunctionDefinition{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + } + t := openai.Tool{ + Type: "function", + Function: &f, + } + tools = append(tools, t) } + return tools +} - text := resp.Choices[0].Message.Content +func createJSONMessageStream(schema *Schema) *Message { + message := fmt.Sprintf(` +Please respond with a JSON array where the elements following JSON schema: - return text, nil +%s + +Make sure to return an array with the elements an instance of the JSON, not the schema itself. 
+`, schema.String) + return &Message{ + Role: RoleSystem, + Content: message, + } +} + +func (o *OpenAIClient) createStream(ctx context.Context, request Request) (<-chan string, error) { + stream, err := o.Client.CreateChatCompletionStream(ctx, request) + if err != nil { + return nil, err + } + + ch := make(chan string) + + go func() { + defer stream.Close() + defer close(ch) + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + return + } + if err != nil { + return + } + text := response.Choices[0].Delta.Content + ch <- text + } + }() + return ch, nil } diff --git a/pkg/instructor/utils.go b/pkg/instructor/utils.go index c4a07a0..7489b60 100644 --- a/pkg/instructor/utils.go +++ b/pkg/instructor/utils.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strings" ) func toPtr[T any](val T) *T { @@ -45,16 +46,47 @@ func urlToBase64(url string) (string, error) { return base64.StdEncoding.EncodeToString(data), nil } -func min(a, b int) int { - if a < b { - return a +// Removes any prefixes before the JSON (like "Sure, here you go:") +func trimPrefixBeforeJSON(jsonStr string) string { + startObject := strings.IndexByte(jsonStr, '{') + startArray := strings.IndexByte(jsonStr, '[') + + var start int + if startObject == -1 && startArray == -1 { + return jsonStr // No opening brace or bracket found, return the original string + } else if startObject == -1 { + start = startArray + } else if startArray == -1 { + start = startObject + } else { + start = min(startObject, startArray) } - return b + + return jsonStr[start:] } -func max(a, b int) int { - if a > b { - return a +// Removes any postfixes after the JSON +func trimPostfixAfterJSON(jsonStr string) string { + endObject := strings.LastIndexByte(jsonStr, '}') + endArray := strings.LastIndexByte(jsonStr, ']') + + var end int + if endObject == -1 && endArray == -1 { + return jsonStr // No closing brace or bracket found, return the original string + } else if endObject == -1 { + end = endArray + } else if 
endArray == -1 { + end = endObject + } else { + end = max(endObject, endArray) } - return b + + return jsonStr[:end+1] +} + +// Extracts the JSON by trimming prefixes and postfixes +func extractJSON(jsonStr string) string { + trimmedPrefix := trimPrefixBeforeJSON(jsonStr) + trimmedJSON := trimPostfixAfterJSON(trimmedPrefix) + return trimmedJSON }