Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai: Add Streaming Support #20

Merged
merged 5 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
659 changes: 657 additions & 2 deletions README.md

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions examples/function_calling/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ func (s *Search) execute() {

type Searches = []Search

// type Searches struct {
// Items []Search `json:"searches" jsonschema:"title=Searches,description=A list of search results"`
// }

func segment(ctx context.Context, data string) *Searches {

client, err := instructor.FromOpenAI(
Expand Down
132 changes: 132 additions & 0 deletions examples/streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package main

import (
"context"
"fmt"
"os"

"github.com/instructor-ai/instructor-go/pkg/instructor"
openai "github.com/sashabaranov/go-openai"
)

// Product identifies a catalogue item by its ID and display name.
type Product struct {
	ID   string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"`
	Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"`
}

// String returns a single-line, human-readable description of the product.
func (p *Product) String() string {
	const layout = "Product [ID: %s, Name: %s]"

	id, name := p.ID, p.Name
	return fmt.Sprintf(layout, id, name)
}

// Recommendation is a Product together with the reason it was suggested.
type Recommendation struct {
	Product
	Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"`
}

// String returns a multi-line description made of the embedded product's own
// String output followed by the recommendation reason.
func (r *Recommendation) String() string {
	const layout = `
Recommendation [
%s
Reason [%s]
]`

	product := r.Product.String()
	return fmt.Sprintf(layout, product, r.Reason)
}

// main asks the model for product recommendations matching a sample customer
// profile and prints each recommendation as it arrives on the returned channel.
//
// It reads the API key from the OPENAI_API_KEY environment variable and panics
// if the client cannot be created or the stream cannot be started.
func main() {
	ctx := context.Background()

	client, err := instructor.FromOpenAI(
		openai.NewClient(os.Getenv("OPENAI_API_KEY")),
		instructor.WithMode(instructor.ModeJSON),
	)
	if err != nil {
		panic(err)
	}

	profileData := `
Customer ID: 12345
Recent Purchases: [Laptop, Wireless Headphones, Smart Watch]
Frequently Browsed Categories: [Electronics, Books, Fitness Equipment]
Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars}
Recent Search History: [best budget laptops 2023, latest sci-fi books, yoga mats]
Preferred Brands: [Apple, AllBirds, Bench]
Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested}
Loyalty Program Status: Gold Member
Average Monthly Spend: $500
Preferred Shopping Times: Weekend Evenings
...
`

	products := []Product{
		{ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"},
		{ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"},
		{ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"},
		{ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"},
		{ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"},
		{ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"},
		{ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"},
		{ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"},
		{ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"},
		{ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"},
	}

	// Render the catalogue one product per line for inclusion in the prompt.
	productList := ""
	for _, product := range products {
		productList += product.String() + "\n"
	}

	recommendationChan, err := client.CreateChatCompletionStream(
		ctx,
		instructor.Request{
			Model: openai.GPT4o20240513,
			Messages: []instructor.Message{
				{
					Role: instructor.RoleSystem,
					Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile.
Return in order of highest recommended first.
Product list:
%s`, productList),
				},
				{
					Role:    instructor.RoleUser,
					Content: fmt.Sprintf("User profile:\n%s", profileData),
				},
			},
			Stream: true,
		},
		// A zero value is passed to indicate the element type of the stream.
		Recommendation{},
	)
	if err != nil {
		panic(err)
	}

	for instance := range recommendationChan {
		recommendation, ok := instance.(*Recommendation)
		if !ok {
			// A failed assertion leaves recommendation nil, and calling
			// String on it would dereference a nil pointer; skip instead.
			fmt.Fprintf(os.Stderr, "skipping unexpected stream element of type %T\n", instance)
			continue
		}
		fmt.Println(recommendation.String())
	}
	/*
		Recommendation [
		Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable]
		Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.]
		]

		Recommendation [
		Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem]
		Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.]
		]

		Recommendation [
		Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author]
		Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.]
		]

		Recommendation [
		Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly]
		Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.]
		]

		Recommendation [
		Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers]
		Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.]
		]
	*/
}
5 changes: 5 additions & 0 deletions pkg/instructor/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"

anthropic "github.com/liushuangls/go-anthropic/v2"
openai "github.com/sashabaranov/go-openai"
)

type AnthropicClient struct {
Expand Down Expand Up @@ -163,3 +164,7 @@ func toAnthropicMessages(request *Request) (*[]anthropic.Message, error) {

return &messages, nil
}

// CreateChatCompletionStream is the streaming entry point for the Anthropic
// client. Streaming is not implemented for this client, so every call returns
// a nil channel and a non-nil error; none of the arguments are inspected.
//
// An error is returned rather than panicking so that callers can handle the
// missing feature through their normal error path.
func (a *AnthropicClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest, mode string, schema *Schema) (<-chan string, error) {
	return nil, fmt.Errorf("anthropic: streaming chat completions are not implemented")
}
12 changes: 6 additions & 6 deletions pkg/instructor/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ type Client interface {
schema *Schema,
) (string, error)

// TODO: implement streaming
// CreateChatCompletionStream(
// ctx context.Context,
// request ChatCompletionRequest,
// opts ...ClientOptions,
// ) (*T, error)
CreateChatCompletionStream(
ctx context.Context,
request Request,
mode Mode,
schema *Schema,
) (<-chan string, error)
}
143 changes: 100 additions & 43 deletions pkg/instructor/instructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,59 +90,116 @@ func (i *Instructor) CreateChatCompletion(ctx context.Context, request Request,
return errors.New("hit max retry attempts")
}

func processResponse(responseStr string, response *any) error {
const WRAPPER_END = `"items": [`

err := json.Unmarshal([]byte(responseStr), response)
// StreamWrapper is a generic envelope holding a slice of T that is
// (un)marshalled under the JSON key "items".
type StreamWrapper[T any] struct {
Items []T `json:"items"`
}

func (i *Instructor) CreateChatCompletionStream(ctx context.Context, request Request, response any) (chan any, error) {

responseType := reflect.TypeOf(response)

streamWrapperType := reflect.StructOf([]reflect.StructField{
{
Name: "Items",
Type: reflect.SliceOf(responseType),
Tag: `json:"items"`,
Anonymous: false,
},
})

schema, err := NewSchema(streamWrapperType)
if err != nil {
return err
return nil, err
}

// TODO: if direct unmarshal fails: check common erors like wrapping struct with key name of struct, instead of just the value

return nil
}
request.Stream = true

// Removes any prefixes before the JSON (like "Sure, here you go:")
func trimPrefix(jsonStr string) string {
startObject := strings.IndexByte(jsonStr, '{')
startArray := strings.IndexByte(jsonStr, '[')

var start int
if startObject == -1 && startArray == -1 {
return jsonStr // No opening brace or bracket found, return the original string
} else if startObject == -1 {
start = startArray
} else if startArray == -1 {
start = startObject
} else {
start = min(startObject, startArray)
ch, err := i.Client.CreateChatCompletionStream(ctx, request, i.Mode, schema)
if err != nil {
return nil, err
}

return jsonStr[start:]
parsedChan := make(chan any) // Unbuffered channel for parsed objects

go func() {
defer close(parsedChan)
var buffer strings.Builder
inArray := false

for {
select {
case <-ctx.Done():
return
case text, ok := <-ch:
if !ok {
// Stream closed

// Get last element out of stream wrapper

data := buffer.String()

if idx := strings.LastIndex(data, "]"); idx != -1 {
data = data[:idx] + data[idx+1:]
}

// Process the remaining data in the buffer
decoder := json.NewDecoder(strings.NewReader(data))
for decoder.More() {
instance := reflect.New(responseType).Interface()
err := decoder.Decode(instance)
if err != nil {
break
}
parsedChan <- instance
}
return
}
buffer.WriteString(text)

// eat all input until elements stream starts
if !inArray {
idx := strings.Index(buffer.String(), WRAPPER_END)
if idx == -1 {
continue
}

inArray = true
bufferStr := buffer.String()
trimmed := strings.TrimSpace(bufferStr[idx+len(WRAPPER_END):])
buffer.Reset()
buffer.WriteString(trimmed)
}

data := buffer.String()
decoder := json.NewDecoder(strings.NewReader(data))

for decoder.More() {
instance := reflect.New(responseType).Interface()
err := decoder.Decode(instance)
if err != nil {
break
}
parsedChan <- instance

buffer.Reset()
buffer.WriteString(data[len(data):])
}
}
}
}()

return parsedChan, nil
}

// Removes any postfixes after the JSON
func trimPostfix(jsonStr string) string {
endObject := strings.LastIndexByte(jsonStr, '}')
endArray := strings.LastIndexByte(jsonStr, ']')

var end int
if endObject == -1 && endArray == -1 {
return jsonStr // No closing brace or bracket found, return the original string
} else if endObject == -1 {
end = endArray
} else if endArray == -1 {
end = endObject
} else {
end = max(endObject, endArray)
func processResponse(responseStr string, response *any) error {
err := json.Unmarshal([]byte(responseStr), response)
if err != nil {
return err
}

return jsonStr[:end+1]
}
// TODO: if direct unmarshal fails: check common errors like wrapping struct with key name of struct, instead of just the value

// Extracts the JSON by trimming prefixes and postfixes
func extractJSON(jsonStr string) string {
trimmedPrefix := trimPrefix(jsonStr)
trimmedJSON := trimPostfix(trimmedPrefix)
return trimmedJSON
return nil
}
16 changes: 7 additions & 9 deletions pkg/instructor/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ import (
openai "github.com/sashabaranov/go-openai"
)

type Message = openai.ChatCompletionMessage

type Request = openai.ChatCompletionRequest

type ChatMessagePart = openai.ChatMessagePart

type ChatMessageImageURL = openai.ChatMessageImageURL

type ChatMessagePartType = openai.ChatMessagePartType
// Aliases for the go-openai chat types, exposed under this package's names.
type (
// Message is an alias for openai.ChatCompletionMessage.
Message = openai.ChatCompletionMessage
// Request is an alias for openai.ChatCompletionRequest.
Request = openai.ChatCompletionRequest
// ChatMessagePart is an alias for openai.ChatMessagePart.
ChatMessagePart = openai.ChatMessagePart
// ChatMessageImageURL is an alias for openai.ChatMessageImageURL.
ChatMessageImageURL = openai.ChatMessageImageURL
// ChatMessagePartType is an alias for openai.ChatMessagePartType.
ChatMessagePartType = openai.ChatMessagePartType
)

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
Expand Down
Loading