Skip to content

Commit

Permalink
Add Streaming Support (#20)
Browse files Browse the repository at this point in the history
* start streaming

* messy but working

* use generic stream array wrapper struct

* cleanup print

* Update readme

---------

Co-authored-by: Robby <[email protected]>
  • Loading branch information
h0rv and h0rv authored May 24, 2024
1 parent 7a8358a commit c5d4f7b
Show file tree
Hide file tree
Showing 9 changed files with 1,051 additions and 110 deletions.
659 changes: 657 additions & 2 deletions README.md

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions examples/function_calling/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ func (s *Search) execute() {

type Searches = []Search

// type Searches struct {
// Items []Search `json:"searches" jsonschema:"title=Searches,description=A list of search results"`
// }

func segment(ctx context.Context, data string) *Searches {

client, err := instructor.FromOpenAI(
Expand Down
132 changes: 132 additions & 0 deletions examples/streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package main

import (
"context"
"fmt"
"os"

"github.com/instructor-ai/instructor-go/pkg/instructor"
openai "github.com/sashabaranov/go-openai"
)

type Product struct {
ID string `json:"product_id" jsonschema:"title=Product ID,description=ID of the product,required=True"`
Name string `json:"product_name" jsonschema:"title=Product Name,description=Name of the product,required=True"`
}

func (p *Product) String() string {
return fmt.Sprintf("Product [ID: %s, Name: %s]", p.ID, p.Name)
}

type Recommendation struct {
Product
Reason string `json:"reason" jsonschema:"title=Recommendation Reason,description=Reason for the product recommendation"`
}

func (r *Recommendation) String() string {
return fmt.Sprintf(`
Recommendation [
%s
Reason [%s]
]`, r.Product.String(), r.Reason)
}

func main() {
ctx := context.Background()

client, err := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeJSON),
)
if err != nil {
panic(err)
}

profileData := `
Customer ID: 12345
Recent Purchases: [Laptop, Wireless Headphones, Smart Watch]
Frequently Browsed Categories: [Electronics, Books, Fitness Equipment]
Product Ratings: {Laptop: 5 stars, Wireless Headphones: 4 stars}
Recent Search History: [best budget laptops 2023, latest sci-fi books, yoga mats]
Preferred Brands: [Apple, AllBirds, Bench]
Responses to Previous Recommendations: {Philips: Not Interested, Adidas: Not Interested}
Loyalty Program Status: Gold Member
Average Monthly Spend: $500
Preferred Shopping Times: Weekend Evenings
...
`

products := []Product{
{ID: "1", Name: "Sony WH-1000XM4 Wireless Headphones - Noise-canceling, long battery life"},
{ID: "2", Name: "Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem"},
{ID: "3", Name: "Kindle Oasis - Premium e-reader with adjustable warm light"},
{ID: "4", Name: "AllBirds Wool Runners - Comfortable, eco-friendly sneakers"},
{ID: "5", Name: "Manduka PRO Yoga Mat - High-quality, durable, eco-friendly"},
{ID: "6", Name: "Bench Hooded Jacket - Stylish, durable, suitable for outdoor activities"},
{ID: "7", Name: "Apple MacBook Air (2023) - Latest model, high performance, portable"},
{ID: "8", Name: "GoPro HERO9 Black - 5K video, waterproof, for action photography"},
{ID: "9", Name: "Nespresso Vertuo Next Coffee Machine - Quality coffee, easy to use, compact design"},
{ID: "10", Name: "Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author"},
}

productList := ""
for _, product := range products {
productList += product.String() + "\n"
}

recommendationChan, err := client.CreateChatCompletionStream(
ctx,
instructor.Request{
Model: openai.GPT4o20240513,
Messages: []instructor.Message{
{
Role: instructor.RoleSystem,
Content: fmt.Sprintf(`Generate the product recommendations from the product list based on the customer profile.
Return in order of highest recommended first.
Product list:
%s`, productList),
},
{
Role: instructor.RoleUser,
Content: fmt.Sprintf("User profile:\n%s", profileData),
},
},
Stream: true,
},
*new(Recommendation),
)
if err != nil {
panic(err)
}

for instance := range recommendationChan {
recommendation, _ := instance.(*Recommendation)
println(recommendation.String())
}
/*
Recommendation [
Product [ID: 7, Name: Apple MacBook Air (2023) - Latest model, high performance, portable]
Reason [As you have recently searched for budget laptops of 2023 and previously purchased a laptop, we believe the latest Apple MacBook Air will meet your high-performance requirements. Additionally, Apple is one of your preferred brands.]
]
Recommendation [
Product [ID: 2, Name: Apple Watch Series 7 - Advanced fitness tracking, seamless integration with Apple ecosystem]
Reason [Based on your recent purchase history which includes a smart watch and your preference for Apple products, we recommend the Apple Watch Series 7 for its advanced fitness tracking features.]
]
Recommendation [
Product [ID: 10, Name: Project Hail Mary by Andy Weir - Latest sci-fi book from a renowned author]
Reason [Given your recent search for the latest sci-fi books and frequent browsing in the Books category, 'Project Hail Mary' by Andy Weir may interest you.]
]
Recommendation [
Product [ID: 5, Name: Manduka PRO Yoga Mat - High-quality, durable, eco-friendly]
Reason [Since you recently searched for yoga mats and frequently browse fitness equipment, we recommend the Manduka PRO Yoga Mat to support your fitness activities.]
]
Recommendation [
Product [ID: 4, Name: AllBirds Wool Runners - Comfortable, eco-friendly sneakers]
Reason [Considering your preference for the AllBirds brand and your frequent browsing in fitness categories, the AllBirds Wool Runners would be a great fit for your lifestyle.]
]
*/
}
5 changes: 5 additions & 0 deletions pkg/instructor/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"

anthropic "github.com/liushuangls/go-anthropic/v2"
openai "github.com/sashabaranov/go-openai"
)

type AnthropicClient struct {
Expand Down Expand Up @@ -163,3 +164,7 @@ func toAnthropicMessages(request *Request) (*[]anthropic.Message, error) {

return &messages, nil
}

func (a *AnthropicClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest, mode string, schema *Schema) (<-chan string, error) {
panic("unimplemented")
}
12 changes: 6 additions & 6 deletions pkg/instructor/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ type Client interface {
schema *Schema,
) (string, error)

// TODO: implement streaming
// CreateChatCompletionStream(
// ctx context.Context,
// request ChatCompletionRequest,
// opts ...ClientOptions,
// ) (*T, error)
CreateChatCompletionStream(
ctx context.Context,
request Request,
mode Mode,
schema *Schema,
) (<-chan string, error)
}
143 changes: 100 additions & 43 deletions pkg/instructor/instructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,59 +90,116 @@ func (i *Instructor) CreateChatCompletion(ctx context.Context, request Request,
return errors.New("hit max retry attempts")
}

func processResponse(responseStr string, response *any) error {
const WRAPPER_END = `"items": [`

err := json.Unmarshal([]byte(responseStr), response)
type StreamWrapper[T any] struct {
Items []T `json:"items"`
}

func (i *Instructor) CreateChatCompletionStream(ctx context.Context, request Request, response any) (chan any, error) {

responseType := reflect.TypeOf(response)

streamWrapperType := reflect.StructOf([]reflect.StructField{
{
Name: "Items",
Type: reflect.SliceOf(responseType),
Tag: `json:"items"`,
Anonymous: false,
},
})

schema, err := NewSchema(streamWrapperType)
if err != nil {
return err
return nil, err
}

// TODO: if direct unmarshal fails: check common erors like wrapping struct with key name of struct, instead of just the value

return nil
}
request.Stream = true

// Removes any prefixes before the JSON (like "Sure, here you go:")
func trimPrefix(jsonStr string) string {
startObject := strings.IndexByte(jsonStr, '{')
startArray := strings.IndexByte(jsonStr, '[')

var start int
if startObject == -1 && startArray == -1 {
return jsonStr // No opening brace or bracket found, return the original string
} else if startObject == -1 {
start = startArray
} else if startArray == -1 {
start = startObject
} else {
start = min(startObject, startArray)
ch, err := i.Client.CreateChatCompletionStream(ctx, request, i.Mode, schema)
if err != nil {
return nil, err
}

return jsonStr[start:]
parsedChan := make(chan any) // Buffered channel for parsed objects

go func() {
defer close(parsedChan)
var buffer strings.Builder
inArray := false

for {
select {
case <-ctx.Done():
return
case text, ok := <-ch:
if !ok {
// Steeam closed

// Get last element out of stream wrapper

data := buffer.String()

if idx := strings.LastIndex(data, "]"); idx != -1 {
data = data[:idx] + data[idx+1:]
}

// Process the remaining data in the buffer
decoder := json.NewDecoder(strings.NewReader(data))
for decoder.More() {
instance := reflect.New(responseType).Interface()
err := decoder.Decode(instance)
if err != nil {
break
}
parsedChan <- instance
}
return
}
buffer.WriteString(text)

// eat all input until elements stream starts
if !inArray {
idx := strings.Index(buffer.String(), WRAPPER_END)
if idx == -1 {
continue
}

inArray = true
bufferStr := buffer.String()
trimmed := strings.TrimSpace(bufferStr[idx+len(WRAPPER_END):])
buffer.Reset()
buffer.WriteString(trimmed)
}

data := buffer.String()
decoder := json.NewDecoder(strings.NewReader(data))

for decoder.More() {
instance := reflect.New(responseType).Interface()
err := decoder.Decode(instance)
if err != nil {
break
}
parsedChan <- instance

buffer.Reset()
buffer.WriteString(data[len(data):])
}
}
}
}()

return parsedChan, nil
}

// Removes any postfixes after the JSON
func trimPostfix(jsonStr string) string {
endObject := strings.LastIndexByte(jsonStr, '}')
endArray := strings.LastIndexByte(jsonStr, ']')

var end int
if endObject == -1 && endArray == -1 {
return jsonStr // No closing brace or bracket found, return the original string
} else if endObject == -1 {
end = endArray
} else if endArray == -1 {
end = endObject
} else {
end = max(endObject, endArray)
func processResponse(responseStr string, response *any) error {
err := json.Unmarshal([]byte(responseStr), response)
if err != nil {
return err
}

return jsonStr[:end+1]
}
// TODO: if direct unmarshal fails: check common errors like wrapping struct with key name of struct, instead of just the value

// Extracts the JSON by trimming prefixes and postfixes
func extractJSON(jsonStr string) string {
trimmedPrefix := trimPrefix(jsonStr)
trimmedJSON := trimPostfix(trimmedPrefix)
return trimmedJSON
return nil
}
16 changes: 7 additions & 9 deletions pkg/instructor/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ import (
openai "github.com/sashabaranov/go-openai"
)

type Message = openai.ChatCompletionMessage

type Request = openai.ChatCompletionRequest

type ChatMessagePart = openai.ChatMessagePart

type ChatMessageImageURL = openai.ChatMessageImageURL

type ChatMessagePartType = openai.ChatMessagePartType
type (
Message = openai.ChatCompletionMessage
Request = openai.ChatCompletionRequest
ChatMessagePart = openai.ChatMessagePart
ChatMessageImageURL = openai.ChatMessageImageURL
ChatMessagePartType = openai.ChatMessagePartType
)

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
Expand Down
Loading

0 comments on commit c5d4f7b

Please sign in to comment.