From c9be9ffd67dfa3cb01a86d6f4cdae8e748cb024e Mon Sep 17 00:00:00 2001 From: ericvanlare Date: Mon, 15 Jul 2024 06:09:26 -0700 Subject: [PATCH] count all usage from all retries and failures (#29) * Count all usage from failures to unmarshal and validate json * usage counting: move provider-specific logic into provider chat files --- pkg/instructor/anthropic_chat.go | 63 +++++++++++++++++++++++++++-- pkg/instructor/chat.go | 19 +++++++-- pkg/instructor/cohere_chat.go | 61 +++++++++++++++++++++++++++- pkg/instructor/instructor.go | 7 ++++ pkg/instructor/openai_chat.go | 68 ++++++++++++++++++++++++++++++-- 5 files changed, 207 insertions(+), 11 deletions(-) diff --git a/pkg/instructor/anthropic_chat.go b/pkg/instructor/anthropic_chat.go index 683c124..3c7a57f 100644 --- a/pkg/instructor/anthropic_chat.go +++ b/pkg/instructor/anthropic_chat.go @@ -13,7 +13,10 @@ func (i *InstructorAnthropic) CreateMessages(ctx context.Context, request anthro resp, err := chatHandler(i, ctx, request, responseType) if err != nil { - return anthropic.MessagesResponse{}, err + if resp == nil { + return anthropic.MessagesResponse{}, err + } + return *nilAnthropicRespWithUsage(resp.(*anthropic.MessagesResponse)), err } response = *(resp.(*anthropic.MessagesResponse)) @@ -68,13 +71,13 @@ func (i *InstructorAnthropic) completionToolCall(ctx context.Context, request *a toolInput, err := json.Marshal(c.Input) if err != nil { - return "", nil, err + return "", nilAnthropicRespWithUsage(&resp), err } // TODO: handle more than 1 tool use return string(toolInput), &resp, nil } - return "", nil, errors.New("more than 1 tool response at a time is not implemented") + return "", nilAnthropicRespWithUsage(&resp), errors.New("more than 1 tool response at a time is not implemented") } @@ -103,3 +106,57 @@ Make sure to return an instance of the JSON, not the schema itself. return *text, &resp, nil } + +func (i *InstructorAnthropic) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &anthropic.MessagesResponse{ + Usage: anthropic.MessagesUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + }, + } +} + +func (i *InstructorAnthropic) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok || resp == nil { + return nil + } + + return &anthropic.MessagesResponse{ + Usage: resp.Usage, + } +} + +func (i *InstructorAnthropic) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *anthropic.MessagesResponse, got %T", response) + } + + resp.Usage.InputTokens += usage.InputTokens + resp.Usage.OutputTokens += usage.OutputTokens + + return response, nil +} + +func (i *InstructorAnthropic) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*anthropic.MessagesResponse) + if !ok { + return usage + } + + usage.InputTokens += resp.Usage.InputTokens + usage.OutputTokens += resp.Usage.OutputTokens + + return usage +} + +func nilAnthropicRespWithUsage(resp *anthropic.MessagesResponse) *anthropic.MessagesResponse { + if resp == nil { + return nil + } + + return &anthropic.MessagesResponse{ + Usage: resp.Usage, + } +} diff --git a/pkg/instructor/chat.go b/pkg/instructor/chat.go index d4a1478..4a6ec41 100644 --- a/pkg/instructor/chat.go +++ b/pkg/instructor/chat.go @@ -9,6 +9,12 @@ import ( "github.com/go-playground/validator/v10" ) +type UsageSum struct { + InputTokens int + OutputTokens int + TotalTokens int +} + func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) { var err error @@ -20,12 +26,15 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons return nil, err } + // keep a running total of usage + usage := &UsageSum{} + for attempt := 0; attempt < i.MaxRetries(); attempt++ { text, resp, err := i.chat(ctx, request, schema) if err != nil { // no retry on non-marshalling/validation errors - return nil, err + return i.emptyResponseWithResponseUsage(resp), err } text = extractJSON(&text) @@ -37,6 +46,8 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons // // Currently, its just recalling with no new information // or attempt to fix the error with the last generated JSON + + i.countUsageFromResponse(resp, usage) continue } @@ -48,12 +59,14 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons if err != nil { // TODO: // add more sophisticated retry logic (send back validator error and parse error for model to fix). + + i.countUsageFromResponse(resp, usage) continue } } - return resp, nil + return i.addUsageSumToResponse(resp, usage) } - return nil, errors.New("hit max retry attempts") + return i.emptyResponseWithUsageSum(usage), errors.New("hit max retry attempts") } diff --git a/pkg/instructor/cohere_chat.go b/pkg/instructor/cohere_chat.go index 28373a2..5648143 100644 --- a/pkg/instructor/cohere_chat.go +++ b/pkg/instructor/cohere_chat.go @@ -17,7 +17,10 @@ func (i *InstructorCohere) Chat( resp, err := chatHandler(i, ctx, request, response) if err != nil { - return nil, err + if resp == nil { + return &cohere.NonStreamedChatResponse{}, err + } + return nilCohereRespWithUsage(resp.(*cohere.NonStreamedChatResponse)), err } return resp.(*cohere.NonStreamedChatResponse), nil @@ -80,6 +83,52 @@ func (i *InstructorCohere) addOrConcatJSONSystemPrompt(request *cohere.ChatReque } } +func (i *InstructorCohere) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &cohere.NonStreamedChatResponse{ + Meta: &cohere.ApiMeta{ + Tokens: &cohere.ApiMetaTokens{ + InputTokens: toPtr(float64(usage.InputTokens)), + OutputTokens: toPtr(float64(usage.OutputTokens)), + }, + }, + } +} + +func (i *InstructorCohere) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok || resp == nil { + return nil + } + + return &cohere.NonStreamedChatResponse{ + Meta: resp.Meta, + } +} + +func (i *InstructorCohere) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *cohere.NonStreamedChatResponse, got %T", response) + } + + *resp.Meta.Tokens.InputTokens += float64(usage.InputTokens) + *resp.Meta.Tokens.OutputTokens += float64(usage.OutputTokens) + + return response, nil +} + +func (i *InstructorCohere) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*cohere.NonStreamedChatResponse) + if !ok { + return usage + } + + usage.InputTokens += int(*resp.Meta.Tokens.InputTokens) + usage.OutputTokens += int(*resp.Meta.Tokens.OutputTokens) + + return usage +} + func createCohereTools(schema *Schema) *cohere.Tool { tool := &cohere.Tool{ @@ -98,3 +147,13 @@ func createCohereTools(schema *Schema) *cohere.Tool { return tool } + +func nilCohereRespWithUsage(resp *cohere.NonStreamedChatResponse) *cohere.NonStreamedChatResponse { + if resp == nil { + return nil + } + + return &cohere.NonStreamedChatResponse{ + Meta: resp.Meta, + } +} diff --git a/pkg/instructor/instructor.go b/pkg/instructor/instructor.go index 47b86f7..b1e5688 100644 --- a/pkg/instructor/instructor.go +++ b/pkg/instructor/instructor.go @@ -29,4 +29,11 @@ type Instructor interface { request interface{}, schema *Schema, ) (<-chan string, error) + + // Usage counting + + emptyResponseWithUsageSum(usage *UsageSum) interface{} + emptyResponseWithResponseUsage(response interface{}) interface{} + addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) + countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum } diff --git a/pkg/instructor/openai_chat.go b/pkg/instructor/openai_chat.go index 7505de4..b6f73a7 100644 --- a/pkg/instructor/openai_chat.go +++ b/pkg/instructor/openai_chat.go @@ -17,7 +17,10 @@ func (i *InstructorOpenAI) CreateChatCompletion( resp, err := chatHandler(i, ctx, request, responseType) if err != nil { - return openai.ChatCompletionResponse{}, err + if resp == nil { + return openai.ChatCompletionResponse{}, err + } + return *nilOpenaiRespWithUsage(resp.(*openai.ChatCompletionResponse)), err } response = *(resp.(*openai.ChatCompletionResponse)) @@ -69,7 +72,7 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha numTools := len(toolCalls) if numTools < 1 { - return "", nil, errors.New("recieved no tool calls from model, expected at least 1") + return "", nilOpenaiRespWithUsage(&resp), errors.New("received no tool calls from model, expected at least 1") } if numTools == 1 { @@ -84,14 +87,14 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha var jsonObj map[string]interface{} err = json.Unmarshal([]byte(toolCall.Function.Arguments), &jsonObj) if err != nil { - return "", nil, err + return "", nilOpenaiRespWithUsage(&resp), err } jsonArray[i] = jsonObj } resultJSON, err := json.Marshal(jsonArray) if err != nil { - return "", nil, err + return "", nilOpenaiRespWithUsage(&resp), err } return string(resultJSON), &resp, nil @@ -128,6 +131,53 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C return text, &resp, nil } +func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} { + return &openai.ChatCompletionResponse{ + Usage: openai.Usage{ + PromptTokens: usage.InputTokens, + CompletionTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + }, + } +} + +func (i *InstructorOpenAI) emptyResponseWithResponseUsage(response interface{}) interface{} { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok || resp == nil { + return nil + } + + return &openai.ChatCompletionResponse{ + Usage: resp.Usage, + } +} + +func (i *InstructorOpenAI) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok { + return response, fmt.Errorf("internal type error: expected *openai.ChatCompletionResponse, got %T", response) + } + + resp.Usage.PromptTokens += usage.InputTokens + resp.Usage.CompletionTokens += usage.OutputTokens + resp.Usage.TotalTokens += usage.TotalTokens + + return response, nil +} + +func (i *InstructorOpenAI) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum { + resp, ok := response.(*openai.ChatCompletionResponse) + if !ok { + return usage + } + + usage.InputTokens += resp.Usage.PromptTokens + usage.OutputTokens += resp.Usage.CompletionTokens + usage.TotalTokens += resp.Usage.TotalTokens + + return usage +} + func createJSONMessage(schema *Schema) *openai.ChatCompletionMessage { message := fmt.Sprintf(` Please respond with JSON in the following JSON schema: @@ -144,3 +194,13 @@ Make sure to return an instance of the JSON, not the schema itself return msg } + +func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse { + if resp == nil { + return nil + } + + return &openai.ChatCompletionResponse{ + Usage: resp.Usage, + } +}