count all usage from all retries and failures (#29)
* Count all usage from failures to unmarshal and validate json

* usage counting: move provider-specific logic into provider chat files
ericvanlare authored Jul 15, 2024
1 parent ea5dfe9 commit c9be9ff
Showing 5 changed files with 207 additions and 11 deletions.
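The practical effect of this change, sketched below from the caller's side: even when structured extraction fails (malformed JSON on every retry, or a hard API error), the returned response now carries the token usage accumulated across all attempts instead of coming back empty. This is a minimal sketch, not part of the diff; the constructor and option names (FromOpenAI, WithMode, ModeJSON, WithMaxRetries), the import path, and the Person struct are assumptions based on the library's usual setup.

```go
package main

import (
	"context"
	"fmt"

	"github.com/instructor-ai/instructor-go/pkg/instructor" // assumed import path
	openai "github.com/sashabaranov/go-openai"
)

// Person is an illustrative extraction target, not part of this commit.
type Person struct {
	Name string `json:"name"`
	Age  int    `json:"age"`
}

func main() {
	// Assumed setup; constructor and option names may differ slightly.
	client := instructor.FromOpenAI(
		openai.NewClient("<api-key>"),
		instructor.WithMode(instructor.ModeJSON),
		instructor.WithMaxRetries(3),
	)

	var person Person
	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model: openai.GPT4o,
			Messages: []openai.ChatCompletionMessage{
				{Role: openai.ChatMessageRoleUser, Content: "Extract: Jane is 30."},
			},
		},
		&person,
	)
	if err != nil {
		// Before this change, usage from failed attempts was lost; after it,
		// the returned response still reports tokens consumed across retries.
		fmt.Println("extraction failed, tokens used:", resp.Usage.TotalTokens)
		return
	}
	fmt.Printf("%+v (tokens: %d)\n", person, resp.Usage.TotalTokens)
}
```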
63 changes: 60 additions & 3 deletions pkg/instructor/anthropic_chat.go
@@ -13,7 +13,10 @@ func (i *InstructorAnthropic) CreateMessages(ctx context.Context, request anthro

resp, err := chatHandler(i, ctx, request, responseType)
if err != nil {
return anthropic.MessagesResponse{}, err
if resp == nil {
return anthropic.MessagesResponse{}, err
}
return *nilAnthropicRespWithUsage(resp.(*anthropic.MessagesResponse)), err
}

response = *(resp.(*anthropic.MessagesResponse))
@@ -68,13 +71,13 @@ func (i *InstructorAnthropic) completionToolCall(ctx context.Context, request *a

toolInput, err := json.Marshal(c.Input)
if err != nil {
return "", nil, err
return "", nilAnthropicRespWithUsage(&resp), err
}
// TODO: handle more than 1 tool use
return string(toolInput), &resp, nil
}

return "", nil, errors.New("more than 1 tool response at a time is not implemented")
return "", nilAnthropicRespWithUsage(&resp), errors.New("more than 1 tool response at a time is not implemented")

}

@@ -103,3 +106,57 @@ Make sure to return an instance of the JSON, not the schema itself.

return *text, &resp, nil
}

func (i *InstructorAnthropic) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
return &anthropic.MessagesResponse{
Usage: anthropic.MessagesUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
},
}
}

func (i *InstructorAnthropic) emptyResponseWithResponseUsage(response interface{}) interface{} {
resp, ok := response.(*anthropic.MessagesResponse)
if !ok || resp == nil {
return nil
}

return &anthropic.MessagesResponse{
Usage: resp.Usage,
}
}

func (i *InstructorAnthropic) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
resp, ok := response.(*anthropic.MessagesResponse)
if !ok {
return response, fmt.Errorf("internal type error: expected *anthropic.MessagesResponse, got %T", response)
}

resp.Usage.InputTokens += usage.InputTokens
resp.Usage.OutputTokens += usage.OutputTokens

return response, nil
}

func (i *InstructorAnthropic) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
resp, ok := response.(*anthropic.MessagesResponse)
if !ok {
return usage
}

usage.InputTokens += resp.Usage.InputTokens
usage.OutputTokens += resp.Usage.OutputTokens

return usage
}

func nilAnthropicRespWithUsage(resp *anthropic.MessagesResponse) *anthropic.MessagesResponse {
if resp == nil {
return nil
}

return &anthropic.MessagesResponse{
Usage: resp.Usage,
}
}
19 changes: 16 additions & 3 deletions pkg/instructor/chat.go
@@ -9,6 +9,12 @@ import (
"github.com/go-playground/validator/v10"
)

type UsageSum struct {
InputTokens int
OutputTokens int
TotalTokens int
}

func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) {

var err error
@@ -20,12 +26,15 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
return nil, err
}

// keep a running total of usage
usage := &UsageSum{}

for attempt := 0; attempt < i.MaxRetries(); attempt++ {

text, resp, err := i.chat(ctx, request, schema)
if err != nil {
// no retry on non-marshalling/validation errors
return nil, err
return i.emptyResponseWithResponseUsage(resp), err
}

text = extractJSON(&text)
@@ -37,6 +46,8 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
//
// Currently, its just recalling with no new information
// or attempt to fix the error with the last generated JSON

i.countUsageFromResponse(resp, usage)
continue
}

@@ -48,12 +59,14 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
if err != nil {
// TODO:
// add more sophisticated retry logic (send back validator error and parse error for model to fix).

i.countUsageFromResponse(resp, usage)
continue
}
}

return resp, nil
return i.addUsageSumToResponse(resp, usage)
}

return nil, errors.New("hit max retry attempts")
return i.emptyResponseWithUsageSum(usage), errors.New("hit max retry attempts")
}
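For orientation, here is a simplified, self-contained sketch of the retry-and-count pattern that chatHandler now follows. The callOnce function and attemptResult type are hypothetical stand-ins for the provider-specific chat calls; in the real code the running total lives in UsageSum and is merged into the final provider response via addUsageSumToResponse.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// UsageSum mirrors the struct added in chat.go.
type UsageSum struct {
	InputTokens, OutputTokens, TotalTokens int
}

// attemptResult is a hypothetical stand-in for a provider response.
type attemptResult struct {
	Text         string
	InputTokens  int
	OutputTokens int
}

// extract retries until the model returns valid JSON, counting tokens from
// every attempt, including the ones whose output fails to unmarshal.
func extract(callOnce func() (attemptResult, error), maxRetries int, out any) (UsageSum, error) {
	usage := UsageSum{}
	for attempt := 0; attempt < maxRetries; attempt++ {
		res, err := callOnce()
		if err != nil {
			return usage, err // hard failure: still report what was spent so far
		}
		usage.InputTokens += res.InputTokens
		usage.OutputTokens += res.OutputTokens
		usage.TotalTokens += res.InputTokens + res.OutputTokens
		if err := json.Unmarshal([]byte(res.Text), out); err != nil {
			continue // malformed JSON: retry, keeping the usage just counted
		}
		return usage, nil
	}
	return usage, errors.New("hit max retry attempts")
}

func main() {
	attempt := 0
	callOnce := func() (attemptResult, error) {
		attempt++
		if attempt == 1 {
			return attemptResult{Text: "not json", InputTokens: 10, OutputTokens: 5}, nil
		}
		return attemptResult{Text: `{"name":"Jane"}`, InputTokens: 12, OutputTokens: 7}, nil
	}

	var person struct {
		Name string `json:"name"`
	}
	usage, err := extract(callOnce, 3, &person)
	fmt.Println(person.Name, usage, err) // Jane {22 12 34} <nil>
}
```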
61 changes: 60 additions & 1 deletion pkg/instructor/cohere_chat.go
@@ -17,7 +17,10 @@ func (i *InstructorCohere) Chat(

resp, err := chatHandler(i, ctx, request, response)
if err != nil {
return nil, err
if resp == nil {
return &cohere.NonStreamedChatResponse{}, err
}
return nilCohereRespWithUsage(resp.(*cohere.NonStreamedChatResponse)), err
}

return resp.(*cohere.NonStreamedChatResponse), nil
@@ -80,6 +83,52 @@ func (i *InstructorCohere) addOrConcatJSONSystemPrompt(request *cohere.ChatReque
}
}

func (i *InstructorCohere) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
return &cohere.NonStreamedChatResponse{
Meta: &cohere.ApiMeta{
Tokens: &cohere.ApiMetaTokens{
InputTokens: toPtr(float64(usage.InputTokens)),
OutputTokens: toPtr(float64(usage.OutputTokens)),
},
},
}
}

func (i *InstructorCohere) emptyResponseWithResponseUsage(response interface{}) interface{} {
resp, ok := response.(*cohere.NonStreamedChatResponse)
if !ok || resp == nil {
return nil
}

return &cohere.NonStreamedChatResponse{
Meta: resp.Meta,
}
}

func (i *InstructorCohere) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
resp, ok := response.(*cohere.NonStreamedChatResponse)
if !ok {
return response, fmt.Errorf("internal type error: expected *cohere.NonStreamedChatResponse, got %T", response)
}

*resp.Meta.Tokens.InputTokens += float64(usage.InputTokens)
*resp.Meta.Tokens.OutputTokens += float64(usage.OutputTokens)

return response, nil
}

func (i *InstructorCohere) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
resp, ok := response.(*cohere.NonStreamedChatResponse)
if !ok {
return usage
}

usage.InputTokens += int(*resp.Meta.Tokens.InputTokens)
usage.OutputTokens += int(*resp.Meta.Tokens.OutputTokens)

return usage
}

func createCohereTools(schema *Schema) *cohere.Tool {

tool := &cohere.Tool{
@@ -98,3 +147,13 @@ func createCohereTools(schema *Schema) *cohere.Tool {

return tool
}

func nilCohereRespWithUsage(resp *cohere.NonStreamedChatResponse) *cohere.NonStreamedChatResponse {
if resp == nil {
return nil
}

return &cohere.NonStreamedChatResponse{
Meta: resp.Meta,
}
}
7 changes: 7 additions & 0 deletions pkg/instructor/instructor.go
@@ -29,4 +29,11 @@ type Instructor interface {
request interface{},
schema *Schema,
) (<-chan string, error)

// Usage counting

emptyResponseWithUsageSum(usage *UsageSum) interface{}
emptyResponseWithResponseUsage(response interface{}) interface{}
addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error)
countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum
}
68 changes: 64 additions & 4 deletions pkg/instructor/openai_chat.go
@@ -17,7 +17,10 @@ func (i *InstructorOpenAI) CreateChatCompletion(

resp, err := chatHandler(i, ctx, request, responseType)
if err != nil {
return openai.ChatCompletionResponse{}, err
if resp == nil {
return openai.ChatCompletionResponse{}, err
}
return *nilOpenaiRespWithUsage(resp.(*openai.ChatCompletionResponse)), err
}

response = *(resp.(*openai.ChatCompletionResponse))
@@ -69,7 +72,7 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha
numTools := len(toolCalls)

if numTools < 1 {
return "", nil, errors.New("recieved no tool calls from model, expected at least 1")
return "", nilOpenaiRespWithUsage(&resp), errors.New("received no tool calls from model, expected at least 1")
}

if numTools == 1 {
Expand All @@ -84,14 +87,14 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha
var jsonObj map[string]interface{}
err = json.Unmarshal([]byte(toolCall.Function.Arguments), &jsonObj)
if err != nil {
return "", nil, err
return "", nilOpenaiRespWithUsage(&resp), err
}
jsonArray[i] = jsonObj
}

resultJSON, err := json.Marshal(jsonArray)
if err != nil {
return "", nil, err
return "", nilOpenaiRespWithUsage(&resp), err
}

return string(resultJSON), &resp, nil
@@ -128,6 +131,53 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C
return text, &resp, nil
}

func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
return &openai.ChatCompletionResponse{
Usage: openai.Usage{
PromptTokens: usage.InputTokens,
CompletionTokens: usage.OutputTokens,
TotalTokens: usage.TotalTokens,
},
}
}

func (i *InstructorOpenAI) emptyResponseWithResponseUsage(response interface{}) interface{} {
resp, ok := response.(*openai.ChatCompletionResponse)
if !ok || resp == nil {
return nil
}

return &openai.ChatCompletionResponse{
Usage: resp.Usage,
}
}

func (i *InstructorOpenAI) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
resp, ok := response.(*openai.ChatCompletionResponse)
if !ok {
return response, fmt.Errorf("internal type error: expected *openai.ChatCompletionResponse, got %T", response)
}

resp.Usage.PromptTokens += usage.InputTokens
resp.Usage.CompletionTokens += usage.OutputTokens
resp.Usage.TotalTokens += usage.TotalTokens

return response, nil
}

func (i *InstructorOpenAI) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
resp, ok := response.(*openai.ChatCompletionResponse)
if !ok {
return usage
}

usage.InputTokens += resp.Usage.PromptTokens
usage.OutputTokens += resp.Usage.CompletionTokens
usage.TotalTokens += resp.Usage.TotalTokens

return usage
}

func createJSONMessage(schema *Schema) *openai.ChatCompletionMessage {
message := fmt.Sprintf(`
Please respond with JSON in the following JSON schema:
@@ -144,3 +194,13 @@ Make sure to return an instance of the JSON, not the schema itself

return msg
}

func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse {
if resp == nil {
return nil
}

return &openai.ChatCompletionResponse{
Usage: resp.Usage,
}
}
