diff --git a/README.md b/README.md
index 108b545..778a816 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,16 @@ As shown in the example below, by adding extra metadata to each struct field (vi
> For more information on the `jsonschema` tags available, see the [`jsonschema` godoc](https://pkg.go.dev/github.com/invopop/jsonschema?utm_source=godoc).
+
+Running
+
+```bash
+export OPENAI_API_KEY=
+go run examples/user/main.go
+```
+
+
+
```go
package main
@@ -76,9 +86,654 @@ Age: %d
}
```
-## Coming Soon
+### Other Examples
+
+
+Function Calling with OpenAI
+
+
+Running
+
+```bash
+export OPENAI_API_KEY=
+go run examples/function_calling/main.go
+```
+
+
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ openai "github.com/sashabaranov/go-openai"
+)
+
+type SearchType string
+
+const (
+ Web SearchType = "web"
+ Image SearchType = "image"
+ Video SearchType = "video"
+)
+
+type Search struct {
+ Topic string `json:"topic" jsonschema:"title=Topic,description=Topic of the search,example=golang"`
+ Query string `json:"query" jsonschema:"title=Query,description=Query to search for relevant content,example=what is golang"`
+ Type SearchType `json:"type" jsonschema:"title=Type,description=Type of search,default=web,enum=web,enum=image,enum=video"`
+}
+
+func (s *Search) execute() {
+ fmt.Printf("Searching for `%s` with query `%s` using `%s`\n", s.Topic, s.Query, s.Type)
+}
+
+type Searches = []Search
+
+func segment(ctx context.Context, data string) *Searches {
+
+ client, err := instructor.FromOpenAI(
+ openai.NewClient(os.Getenv("OPENAI_API_KEY")),
+ instructor.WithMode(instructor.ModeToolCall),
+ instructor.WithMaxRetries(3),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ var searches Searches
+ err = client.CreateChatCompletion(
+ ctx,
+ instructor.Request{
+ Model: openai.GPT4o,
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleUser,
+ Content: fmt.Sprintf("Consider the data below: '\n%s' and segment it into multiple search queries", data),
+ },
+ },
+ },
+ &searches,
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ return &searches
+}
+
+func main() {
+ ctx := context.Background()
+
+ q := "Search for a picture of a cat, a video of a dog, and the taxonomy of each"
+ for _, search := range *segment(ctx, q) {
+ search.execute()
+ }
+ /*
+ Searching for `cat` with query `picture of a cat` using `image`
+ Searching for `dog` with query `video of a dog` using `video`
+ Searching for `cat` with query `taxonomy of a cat` using `web`
+ Searching for `dog` with query `taxonomy of a dog` using `web`
+ */
+}
+```
+
+
+
+
+Text Classification with Anthropic
+
+
+Running
+
+```bash
+export ANTHROPIC_API_KEY=
+go run examples/classification/main.go
+```
+
+
+
+```go
+ package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ anthropic "github.com/liushuangls/go-anthropic/v2"
+)
+
+type LabelType string
+
+const (
+ LabelTechIssue LabelType = "tech_issue"
+ LabelBilling LabelType = "billing"
+ LabelGeneralQuery LabelType = "general_query"
+)
+
+type Label struct {
+ Type LabelType `json:"type" jsonschema:"title=Label type,description=Type of the label,enum=tech_issue,enum=billing,enum=general_query"`
+}
+
+type Prediction struct {
+ Labels []Label `json:"labels" jsonschema:"title=Predicted labels,description=Labels of the prediction"`
+}
+
+func classify(data string) *Prediction {
+ ctx := context.Background()
+
+ client, err := instructor.FromAnthropic(
+ anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")),
+ instructor.WithMode(instructor.ModeToolCall),
+ instructor.WithMaxRetries(3),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ var prediction Prediction
+ err = client.CreateChatCompletion(
+ ctx,
+ instructor.Request{
+ Model: anthropic.ModelClaude3Haiku20240307,
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleUser,
+ Content: fmt.Sprintf("Classify the following support ticket: %s", data),
+ },
+ },
+ },
+ &prediction,
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ return &prediction
+}
+
+func main() {
+
+ ticket := "My account is locked and I can't access my billing info."
+ prediction := classify(ticket)
+
+ assert(prediction.contains(LabelTechIssue), "Expected ticket to be related to tech issue")
+ assert(prediction.contains(LabelBilling), "Expected ticket to be related to billing")
+ assert(!prediction.contains(LabelGeneralQuery), "Expected ticket NOT to be a general query")
+
+ fmt.Printf("%+v\n", prediction)
+ /*
+ &{Labels:[{Type:tech_issue} {Type:billing}]}
+ */
+}
+
+/******/
+
+func (p *Prediction) contains(label LabelType) bool {
+ for _, l := range p.Labels {
+ if l.Type == label {
+ return true
+ }
+ }
+ return false
+}
+
+func assert(condition bool, message string) {
+ if !condition {
+ fmt.Println("Assertion failed:", message)
+ }
+}
+```
+
+
+
+
+Images with OpenAI
+
+
+Running
+
+```bash
+export OPENAI_API_KEY=
+go run examples/images/openai/main.go
+```
+
+
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ openai "github.com/sashabaranov/go-openai"
+)
+
+type Book struct {
+ Title string `json:"title,omitempty" jsonschema:"title=title,description=The title of the book,example=Harry Potter and the Philosopher's Stone"`
+ Author *string `json:"author,omitempty" jsonschema:"title=author,description=The author of the book,example=J.K. Rowling"`
+}
+
+type BookCatalog struct {
+ Catalog []Book `json:"catalog"`
+}
+
+func (bc *BookCatalog) PrintCatalog() {
+ fmt.Printf("Number of books in the catalog: %d\n\n", len(bc.Catalog))
+ for _, book := range bc.Catalog {
+ fmt.Printf("Title: %s\n", book.Title)
+ fmt.Printf("Author: %s\n", *book.Author)
+ fmt.Println("--------------------")
+ }
+}
+
+func main() {
+ ctx := context.Background()
+
+ client, err := instructor.FromOpenAI(
+ openai.NewClient(os.Getenv("OPENAI_API_KEY")),
+ instructor.WithMode(instructor.ModeJSON),
+ instructor.WithMaxRetries(3),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ url := "https://utfs.io/f/fe55d6bd-e920-4a6f-8e93-a4c9dd851b90-eivhb2.png"
+
+ var bookCatalog BookCatalog
+ err = client.CreateChatCompletion(
+ ctx,
+ instructor.Request{
+ Model: openai.GPT4o,
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleUser,
+ MultiContent: []instructor.ChatMessagePart{
+ {
+ Type: instructor.ChatMessagePartTypeText,
+ Text: "Extract book catelog from the image",
+ },
+ {
+ Type: instructor.ChatMessagePartTypeImageURL,
+ ImageURL: &instructor.ChatMessageImageURL{
+ URL: url,
+ },
+ },
+ },
+ },
+ },
+ },
+ &bookCatalog,
+ )
+
+ if err != nil {
+ panic(err)
+ }
+
+ bookCatalog.PrintCatalog()
+ /*
+ Number of books in the catalog: 15
+
+ Title: Pride and Prejudice
+ Author: Jane Austen
+ --------------------
+ Title: The Great Gatsby
+ Author: F. Scott Fitzgerald
+ --------------------
+ Title: The Catcher in the Rye
+ Author: J. D. Salinger
+ --------------------
+ Title: Don Quixote
+ Author: Miguel de Cervantes
+ --------------------
+ Title: One Hundred Years of Solitude
+ Author: Gabriel García Márquez
+ --------------------
+ Title: To Kill a Mockingbird
+ Author: Harper Lee
+ --------------------
+ Title: Beloved
+ Author: Toni Morrison
+ --------------------
+ Title: Ulysses
+ Author: James Joyce
+ --------------------
+ Title: Harry Potter and the Cursed Child
+ Author: J.K. Rowling
+ --------------------
+ Title: The Grapes of Wrath
+ Author: John Steinbeck
+ --------------------
+ Title: 1984
+ Author: George Orwell
+ --------------------
+ Title: Lolita
+ Author: Vladimir Nabokov
+ --------------------
+ Title: Anna Karenina
+ Author: Leo Tolstoy
+ --------------------
+ Title: Moby-Dick
+ Author: Herman Melville
+ --------------------
+ Title: Wuthering Heights
+ Author: Emily Brontë
+ --------------------
+ */
+}
+```
+
+
+
+
+Images with Anthropic
+
+
+Running
+
+```bash
+export ANTHROPIC_API_KEY=
+go run examples/images/anthropic/main.go
+```
+
+
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ "github.com/liushuangls/go-anthropic/v2"
+)
+
+type Movie struct {
+ Title string `json:"title" jsonschema:"title=title,description=The title of the movie,required=true,example=Ex Machina"`
+ Year int `json:"year,omitempty" jsonschema:"title=year,description=The year of the movie,required=false,example=2014"`
+}
+
+type MovieCatalog struct {
+ Catalog []Movie `json:"catalog"`
+}
+
+func (bc *MovieCatalog) PrintCatalog() {
+ fmt.Printf("Number of movies in the catalog: %d\n\n", len(bc.Catalog))
+ for _, movie := range bc.Catalog {
+ fmt.Printf("Title: %s\n", movie.Title)
+ if movie.Year != 0 {
+ fmt.Printf("Year: %d\n", movie.Year)
+ }
+ fmt.Println("--------------------")
+ }
+}
+
+func main() {
+ ctx := context.Background()
+
+ client, err := instructor.FromAnthropic(
+ anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")),
+ instructor.WithMode(instructor.ModeJSONSchema),
+ instructor.WithMaxRetries(3),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ url := "https://utfs.io/f/bd0dbae6-27e3-4604-b640-fd2ffea891b8-fxyywt.jpeg"
+
+ var movieCatalog MovieCatalog
+ err = client.CreateChatCompletion(
+ ctx,
+ instructor.Request{
+ Model: "claude-3-haiku-20240307",
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleUser,
+ MultiContent: []instructor.ChatMessagePart{
+ {
+ Type: instructor.ChatMessagePartTypeText,
+ Text: "Extract the movie catalog from the screenshot",
+ },
+ {
+ Type: instructor.ChatMessagePartTypeImageURL,
+ ImageURL: &instructor.ChatMessageImageURL{
+ URL: url,
+ },
+ },
+ },
+ },
+ },
+ },
+ &movieCatalog,
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ movieCatalog.PrintCatalog()
+ /*
+ Number of movies in the catalog: 18
+
+ Title: Oppenheimer
+ Year: 2023
+ --------------------
+ Title: The Dark Knight
+ Year: 2008
+ --------------------
+ Title: Interstellar
+ Year: 2014
+ --------------------
+ Title: Inception
+ Year: 2010
+ --------------------
+ Title: Tenet
+ Year: 2020
+ --------------------
+ Title: Dunkirk
+ Year: 2017
+ --------------------
+ Title: Memento
+ Year: 2000
+ --------------------
+ Title: The Dark Knight Rises
+ Year: 2012
+ --------------------
+ Title: Batman Begins
+ Year: 2005
+ --------------------
+ Title: The Prestige
+ Year: 2006
+ --------------------
+ Title: Insomnia
+ Year: 2002
+ --------------------
+ Title: Following
+ Year: 1998
+ --------------------
+ Title: Man of Steel
+ Year: 2013
+ --------------------
+ Title: Transcendence
+ Year: 2014
+ --------------------
+ Title: Justice League
+ Year: 2017
+ --------------------
+ Title: Batman v Superman: Dawn of Justice
+ Year: 2016
+ --------------------
+ Title: Ending the Knight
+ Year: 2016
+ --------------------
+ Title: Larceny
+ --------------------
+ */
+}
+```
+
+
+
+
+Streaming with OpenAI
+
+
+Running
+
+```bash
+export OPENAI_API_KEY=
+go run examples/streaming/main.go
+```
+
+
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ openai "github.com/sashabaranov/go-openai"
+)
+
+type Product struct {
+ ID string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"`
+ Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"`
+}
+
+func (p *Product) String() string {
+ return fmt.Sprintf("Product [ID: %s, Name: %s]", p.ID, p.Name)
+}
+
+type Recommendation struct {
+ Product
+ Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"`
+}
+
+func (r *Recommendation) String() string {
+ return fmt.Sprintf(`
+Recommendation [
+ %s
+ Reason [%s]
+]`, r.Product.String(), r.Reason)
+}
+
+func main() {
+ ctx := context.Background()
+
+ client, err := instructor.FromOpenAI(
+ openai.NewClient(os.Getenv("OPENAI_API_KEY")),
+ instructor.WithMode(instructor.ModeJSON),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ profileData := `
+Customer ID: 12345
+Recent Purchases: [Laptop, Wireless Headphones, Smart Watch]
+Frequently Browsed Categories: [Electronics, Books, Fitness Equipment]
+Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars}
+Recent Search History: [best budget laptops 2023, latest sci-fi books, yoga mats]
+Preferred Brands: [Apple, AllBirds, Bench]
+Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested}
+Loyalty Program Status: Gold Member
+Average Monthly Spend: $500
+Preferred Shopping Times: Weekend Evenings
+...
+`
+
+ products := []Product{
+ {ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"},
+ {ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"},
+ {ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"},
+ {ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"},
+ {ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"},
+ {ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"},
+ {ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"},
+ {ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"},
+ {ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"},
+ {ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"},
+ }
+
+ productList := ""
+ for _, product := range products {
+ productList += product.String() + "\n"
+ }
+
+ recommendationChan, err := client.CreateChatCompletionStream(
+ ctx,
+ instructor.Request{
+ Model: openai.GPT4o20240513,
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleSystem,
+ Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile.
+Return in order of highest recommended first.
+Product list:
+%s`, productList),
+ },
+ {
+ Role: instructor.RoleUser,
+ Content: fmt.Sprintf("User profile:\n%s", profileData),
+ },
+ },
+ Stream: true,
+ },
+ *new(Recommendation),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ for instance := range recommendationChan {
+ recommendation, _ := instance.(*Recommendation)
+ println(recommendation.String())
+ }
+ /*
+ Recommendation [
+ Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable]
+ Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.]
+ ]
+
+ Recommendation [
+ Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem]
+ Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.]
+ ]
+
+ Recommendation [
+ Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author]
+ Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.]
+ ]
+
+ Recommendation [
+ Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly]
+ Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.]
+ ]
+
+ Recommendation [
+ Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers]
+ Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.]
+ ]
+ */
+}
+```
-1. Streaming support
+
## Providers
diff --git a/examples/function_calling/main.go b/examples/function_calling/main.go
index 1e6d7f2..9f9eed4 100644
--- a/examples/function_calling/main.go
+++ b/examples/function_calling/main.go
@@ -29,10 +29,6 @@ func (s *Search) execute() {
type Searches = []Search
-// type Searches struct {
-// Items []Search `json:"searches" jsonschema:"title=Searches,description=A list of search results"`
-// }
-
func segment(ctx context.Context, data string) *Searches {
client, err := instructor.FromOpenAI(
diff --git a/examples/streaming/main.go b/examples/streaming/main.go
new file mode 100644
index 0000000..5a5220f
--- /dev/null
+++ b/examples/streaming/main.go
@@ -0,0 +1,132 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/instructor-ai/instructor-go/pkg/instructor"
+ openai "github.com/sashabaranov/go-openai"
+)
+
+type Product struct {
+ ID string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"`
+ Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"`
+}
+
+func (p *Product) String() string {
+ return fmt.Sprintf("Product [ID: %s, Name: %s]", p.ID, p.Name)
+}
+
+type Recommendation struct {
+ Product
+ Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"`
+}
+
+func (r *Recommendation) String() string {
+ return fmt.Sprintf(`
+Recommendation [
+ %s
+ Reason [%s]
+]`, r.Product.String(), r.Reason)
+}
+
+func main() {
+ ctx := context.Background()
+
+ client, err := instructor.FromOpenAI(
+ openai.NewClient(os.Getenv("OPENAI_API_KEY")),
+ instructor.WithMode(instructor.ModeJSON),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ profileData := `
+Customer ID: 12345
+Recent Purchases: [Laptop, Wireless Headphones, Smart Watch]
+Frequently Browsed Categories: [Electronics, Books, Fitness Equipment]
+Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars}
+Recent Search History: [best budget laptops 2023, latest sci-fi books, yoga mats]
+Preferred Brands: [Apple, AllBirds, Bench]
+Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested}
+Loyalty Program Status: Gold Member
+Average Monthly Spend: $500
+Preferred Shopping Times: Weekend Evenings
+...
+`
+
+ products := []Product{
+ {ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"},
+ {ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"},
+ {ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"},
+ {ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"},
+ {ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"},
+ {ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"},
+ {ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"},
+ {ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"},
+ {ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"},
+ {ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"},
+ }
+
+ productList := ""
+ for _, product := range products {
+ productList += product.String() + "\n"
+ }
+
+ recommendationChan, err := client.CreateChatCompletionStream(
+ ctx,
+ instructor.Request{
+ Model: openai.GPT4o20240513,
+ Messages: []instructor.Message{
+ {
+ Role: instructor.RoleSystem,
+ Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile.
+Return in order of highest recommended first.
+Product list:
+%s`, productList),
+ },
+ {
+ Role: instructor.RoleUser,
+ Content: fmt.Sprintf("User profile:\n%s", profileData),
+ },
+ },
+ Stream: true,
+ },
+ *new(Recommendation),
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ for instance := range recommendationChan {
+ recommendation, _ := instance.(*Recommendation)
+ println(recommendation.String())
+ }
+ /*
+ Recommendation [
+ Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable]
+ Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.]
+ ]
+
+ Recommendation [
+ Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem]
+ Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.]
+ ]
+
+ Recommendation [
+ Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author]
+ Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.]
+ ]
+
+ Recommendation [
+ Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly]
+ Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.]
+ ]
+
+ Recommendation [
+ Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers]
+ Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.]
+ ]
+ */
+}
diff --git a/pkg/instructor/anthropic.go b/pkg/instructor/anthropic.go
index 5906ba2..8b4b208 100644
--- a/pkg/instructor/anthropic.go
+++ b/pkg/instructor/anthropic.go
@@ -7,6 +7,7 @@ import (
"fmt"
anthropic "github.com/liushuangls/go-anthropic/v2"
+ openai "github.com/sashabaranov/go-openai"
)
type AnthropicClient struct {
@@ -163,3 +164,7 @@ func toAnthropicMessages(request *Request) (*[]anthropic.Message, error) {
return &messages, nil
}
+
+func (a *AnthropicClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest, mode string, schema *Schema) (<-chan string, error) {
+ panic("unimplemented")
+}
diff --git a/pkg/instructor/client.go b/pkg/instructor/client.go
index dfeb75a..7d7a657 100644
--- a/pkg/instructor/client.go
+++ b/pkg/instructor/client.go
@@ -12,10 +12,10 @@ type Client interface {
schema *Schema,
) (string, error)
- // TODO: implement streaming
- // CreateChatCompletionStream(
- // ctx context.Context,
- // request ChatCompletionRequest,
- // opts ...ClientOptions,
- // ) (*T, error)
+ CreateChatCompletionStream(
+ ctx context.Context,
+ request Request,
+ mode Mode,
+ schema *Schema,
+ ) (<-chan string, error)
}
diff --git a/pkg/instructor/instructor.go b/pkg/instructor/instructor.go
index e119e21..c1a2fcb 100644
--- a/pkg/instructor/instructor.go
+++ b/pkg/instructor/instructor.go
@@ -90,59 +90,116 @@ func (i *Instructor) CreateChatCompletion(ctx context.Context, request Request,
return errors.New("hit max retry attempts")
}
-func processResponse(responseStr string, response *any) error {
+const WRAPPER_END = `"items": [`
- err := json.Unmarshal([]byte(responseStr), response)
+type StreamWrapper[T any] struct {
+ Items []T `json:"items"`
+}
+
+func (i *Instructor) CreateChatCompletionStream(ctx context.Context, request Request, response any) (chan any, error) {
+
+ responseType := reflect.TypeOf(response)
+
+ streamWrapperType := reflect.StructOf([]reflect.StructField{
+ {
+ Name: "Items",
+ Type: reflect.SliceOf(responseType),
+ Tag: `json:"items"`,
+ Anonymous: false,
+ },
+ })
+
+ schema, err := NewSchema(streamWrapperType)
if err != nil {
- return err
+ return nil, err
}
- // TODO: if direct unmarshal fails: check common erors like wrapping struct with key name of struct, instead of just the value
-
- return nil
-}
+ request.Stream = true
-// Removes any prefixes before the JSON (like "Sure, here you go:")
-func trimPrefix(jsonStr string) string {
- startObject := strings.IndexByte(jsonStr, '{')
- startArray := strings.IndexByte(jsonStr, '[')
-
- var start int
- if startObject == -1 && startArray == -1 {
- return jsonStr // No opening brace or bracket found, return the original string
- } else if startObject == -1 {
- start = startArray
- } else if startArray == -1 {
- start = startObject
- } else {
- start = min(startObject, startArray)
+ ch, err := i.Client.CreateChatCompletionStream(ctx, request, i.Mode, schema)
+ if err != nil {
+ return nil, err
}
- return jsonStr[start:]
+ parsedChan := make(chan any) // Buffered channel for parsed objects
+
+ go func() {
+ defer close(parsedChan)
+ var buffer strings.Builder
+ inArray := false
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case text, ok := <-ch:
+ if !ok {
+ // Steeam closed
+
+ // Get last element out of stream wrapper
+
+ data := buffer.String()
+
+ if idx := strings.LastIndex(data, "]"); idx != -1 {
+ data = data[:idx] + data[idx+1:]
+ }
+
+ // Process the remaining data in the buffer
+ decoder := json.NewDecoder(strings.NewReader(data))
+ for decoder.More() {
+ instance := reflect.New(responseType).Interface()
+ err := decoder.Decode(instance)
+ if err != nil {
+ break
+ }
+ parsedChan <- instance
+ }
+ return
+ }
+ buffer.WriteString(text)
+
+ // eat all input until elements stream starts
+ if !inArray {
+ idx := strings.Index(buffer.String(), WRAPPER_END)
+ if idx == -1 {
+ continue
+ }
+
+ inArray = true
+ bufferStr := buffer.String()
+ trimmed := strings.TrimSpace(bufferStr[idx+len(WRAPPER_END):])
+ buffer.Reset()
+ buffer.WriteString(trimmed)
+ }
+
+ data := buffer.String()
+ decoder := json.NewDecoder(strings.NewReader(data))
+
+ for decoder.More() {
+ instance := reflect.New(responseType).Interface()
+ err := decoder.Decode(instance)
+ if err != nil {
+ break
+ }
+ parsedChan <- instance
+
+ buffer.Reset()
+ buffer.WriteString(data[len(data):])
+ }
+ }
+ }
+ }()
+
+ return parsedChan, nil
}
-// Removes any postfixes after the JSON
-func trimPostfix(jsonStr string) string {
- endObject := strings.LastIndexByte(jsonStr, '}')
- endArray := strings.LastIndexByte(jsonStr, ']')
-
- var end int
- if endObject == -1 && endArray == -1 {
- return jsonStr // No closing brace or bracket found, return the original string
- } else if endObject == -1 {
- end = endArray
- } else if endArray == -1 {
- end = endObject
- } else {
- end = max(endObject, endArray)
+func processResponse(responseStr string, response *any) error {
+ err := json.Unmarshal([]byte(responseStr), response)
+ if err != nil {
+ return err
}
- return jsonStr[:end+1]
-}
+ // TODO: if direct unmarshal fails: check common errors like wrapping struct with key name of struct, instead of just the value
-// Extracts the JSON by trimming prefixes and postfixes
-func extractJSON(jsonStr string) string {
- trimmedPrefix := trimPrefix(jsonStr)
- trimmedJSON := trimPostfix(trimmedPrefix)
- return trimmedJSON
+ return nil
}
diff --git a/pkg/instructor/messages.go b/pkg/instructor/messages.go
index 637d8b6..fec52d9 100644
--- a/pkg/instructor/messages.go
+++ b/pkg/instructor/messages.go
@@ -4,15 +4,13 @@ import (
openai "github.com/sashabaranov/go-openai"
)
-type Message = openai.ChatCompletionMessage
-
-type Request = openai.ChatCompletionRequest
-
-type ChatMessagePart = openai.ChatMessagePart
-
-type ChatMessageImageURL = openai.ChatMessageImageURL
-
-type ChatMessagePartType = openai.ChatMessagePartType
+type (
+ Message = openai.ChatCompletionMessage
+ Request = openai.ChatCompletionRequest
+ ChatMessagePart = openai.ChatMessagePart
+ ChatMessageImageURL = openai.ChatMessageImageURL
+ ChatMessagePartType = openai.ChatMessagePartType
+)
const (
ChatMessagePartTypeText ChatMessagePartType = "text"
diff --git a/pkg/instructor/openai.go b/pkg/instructor/openai.go
index 8f6eaf2..1dcdc8e 100644
--- a/pkg/instructor/openai.go
+++ b/pkg/instructor/openai.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
openai "github.com/sashabaranov/go-openai"
)
@@ -26,6 +27,10 @@ func NewOpenAIClient(client *openai.Client) (*OpenAIClient, error) {
}
func (o *OpenAIClient) CreateChatCompletion(ctx context.Context, request Request, mode Mode, schema *Schema) (string, error) {
+ if request.Stream {
+ return "", errors.New("streaming is not supported by this method; use CreateChatCompletionStream instead")
+ }
+
switch mode {
case ModeToolCall:
return o.completionToolCall(ctx, request, schema)
@@ -40,22 +45,7 @@ func (o *OpenAIClient) CreateChatCompletion(ctx context.Context, request Request
func (o *OpenAIClient) completionToolCall(ctx context.Context, request Request, schema *Schema) (string, error) {
- tools := []openai.Tool{}
-
- for _, function := range schema.Functions {
- f := openai.FunctionDefinition{
- Name: function.Name,
- Description: function.Description,
- Parameters: function.Parameters,
- }
- t := openai.Tool{
- Type: "function",
- Function: &f,
- }
- tools = append(tools, t)
- }
-
- request.Tools = tools
+ request.Tools = createTools(schema)
resp, err := o.Client.CreateChatCompletion(ctx, request)
if err != nil {
@@ -103,20 +93,8 @@ func (o *OpenAIClient) completionToolCall(ctx context.Context, request Request,
}
func (o *OpenAIClient) completionJSON(ctx context.Context, request Request, schema *Schema) (string, error) {
- message := fmt.Sprintf(`
-Please responsd with json in the following json_schema:
-
-%s
-
-Make sure to return an instance of the JSON, not the schema itself
-`, schema.String)
-
- msg := Message{
- Role: RoleSystem,
- Content: message,
- }
- request.Messages = prepend(request.Messages, msg)
+ request.Messages = prepend(request.Messages, *createJSONMessage(schema))
// Set JSON mode
request.ResponseFormat = &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}
@@ -133,27 +111,115 @@ Make sure to return an instance of the JSON, not the schema itself
func (o *OpenAIClient) completionJSONSchema(ctx context.Context, request Request, schema *Schema) (string, error) {
+ request.Messages = prepend(request.Messages, *createJSONMessage(schema))
+
+ resp, err := o.Client.CreateChatCompletion(ctx, request)
+ if err != nil {
+ return "", err
+ }
+
+ text := resp.Choices[0].Message.Content
+
+ return text, nil
+}
+
+func createJSONMessage(schema *Schema) *Message {
message := fmt.Sprintf(`
-Please responsd with json in the following json_schema:
+Please respond with JSON in the following JSON schema:
%s
Make sure to return an instance of the JSON, not the schema itself
`, schema.String)
-
- msg := Message{
+ return &Message{
Role: RoleSystem,
Content: message,
}
+}
+
+func (o *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request Request, mode Mode, schema *Schema) (<-chan string, error) {
+ switch mode {
+ case ModeToolCall:
+ return o.completionToolCallStream(ctx, request, schema)
+ case ModeJSON:
+ return o.completionJSONStream(ctx, request, schema)
+ case ModeJSONSchema:
+ return o.completionJSONSchemaStream(ctx, request, schema)
+ default:
+ return nil, fmt.Errorf("mode '%s' is not supported for %s", mode, o.Name)
+ }
+}
- request.Messages = prepend(request.Messages, msg)
+func (o *OpenAIClient) completionToolCallStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) {
+ request.Tools = createTools(schema)
+ return o.createStream(ctx, request)
+}
- resp, err := o.Client.CreateChatCompletion(ctx, request)
- if err != nil {
- return "", err
+func (o *OpenAIClient) completionJSONStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) {
+ request.Messages = prepend(request.Messages, *createJSONMessageStream(schema))
+ // Set JSON mode
+ request.ResponseFormat = &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}
+ return o.createStream(ctx, request)
+}
+
+func (o *OpenAIClient) completionJSONSchemaStream(ctx context.Context, request Request, schema *Schema) (<-chan string, error) {
+ request.Messages = prepend(request.Messages, *createJSONMessageStream(schema))
+ return o.createStream(ctx, request)
+}
+
+func createTools(schema *Schema) []openai.Tool {
+ tools := make([]openai.Tool, 0, len(schema.Functions))
+ for _, function := range schema.Functions {
+ f := openai.FunctionDefinition{
+ Name: function.Name,
+ Description: function.Description,
+ Parameters: function.Parameters,
+ }
+ t := openai.Tool{
+ Type: "function",
+ Function: &f,
+ }
+ tools = append(tools, t)
}
+ return tools
+}
- text := resp.Choices[0].Message.Content
+func createJSONMessageStream(schema *Schema) *Message {
+ message := fmt.Sprintf(`
+Please respond with a JSON array where the elements following JSON schema:
- return text, nil
+%s
+
+Make sure to return an array with the elements an instance of the JSON, not the schema itself.
+`, schema.String)
+ return &Message{
+ Role: RoleSystem,
+ Content: message,
+ }
+}
+
+func (o *OpenAIClient) createStream(ctx context.Context, request Request) (<-chan string, error) {
+ stream, err := o.Client.CreateChatCompletionStream(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+
+ ch := make(chan string)
+
+ go func() {
+ defer stream.Close()
+ defer close(ch)
+ for {
+ response, err := stream.Recv()
+ if errors.Is(err, io.EOF) {
+ return
+ }
+ if err != nil {
+ return
+ }
+ text := response.Choices[0].Delta.Content
+ ch <- text
+ }
+ }()
+ return ch, nil
}
diff --git a/pkg/instructor/utils.go b/pkg/instructor/utils.go
index c4a07a0..7489b60 100644
--- a/pkg/instructor/utils.go
+++ b/pkg/instructor/utils.go
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
+ "strings"
)
func toPtr[T any](val T) *T {
@@ -45,16 +46,47 @@ func urlToBase64(url string) (string, error) {
return base64.StdEncoding.EncodeToString(data), nil
}
-func min(a, b int) int {
- if a < b {
- return a
+// Removes any prefixes before the JSON (like "Sure, here you go:")
+func trimPrefixBeforeJSON(jsonStr string) string {
+ startObject := strings.IndexByte(jsonStr, '{')
+ startArray := strings.IndexByte(jsonStr, '[')
+
+ var start int
+ if startObject == -1 && startArray == -1 {
+ return jsonStr // No opening brace or bracket found, return the original string
+ } else if startObject == -1 {
+ start = startArray
+ } else if startArray == -1 {
+ start = startObject
+ } else {
+ start = min(startObject, startArray)
}
- return b
+
+ return jsonStr[start:]
}
-func max(a, b int) int {
- if a > b {
- return a
+// Removes any postfixes after the JSON
+func trimPostfixAfterJSON(jsonStr string) string {
+ endObject := strings.LastIndexByte(jsonStr, '}')
+ endArray := strings.LastIndexByte(jsonStr, ']')
+
+ var end int
+ if endObject == -1 && endArray == -1 {
+ return jsonStr // No closing brace or bracket found, return the original string
+ } else if endObject == -1 {
+ end = endArray
+ } else if endArray == -1 {
+ end = endObject
+ } else {
+ end = max(endObject, endArray)
}
- return b
+
+ return jsonStr[:end+1]
+}
+
+// Extracts the JSON by trimming prefixes and postfixes
+func extractJSON(jsonStr string) string {
+ trimmedPrefix := trimPrefixBeforeJSON(jsonStr)
+ trimmedJSON := trimPostfixAfterJSON(trimmedPrefix)
+ return trimmedJSON
}