Skip to content

Commit

Permalink
OpenAI Retry Client (#77)
Browse files Browse the repository at this point in the history
* OpenaiRetryClient wrapper
  • Loading branch information
danielchalef authored May 24, 2023
1 parent 21f9661 commit 27b132e
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 58 deletions.
2 changes: 1 addition & 1 deletion cmd/zep/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func run() {
// and creates the OpenAI client
func NewAppState(cfg *config.Config) *models.AppState {
appState := &models.AppState{
OpenAIClient: llms.CreateOpenAIClient(cfg),
OpenAIClient: llms.NewOpenAIRetryClient(cfg),
Config: cfg,
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/extractors/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (ee *EmbeddingExtractor) Extract(
for i, r := range messageEvent.Messages {
embeddingRecords[i] = models.Embeddings{
TextUUID: r.UUID,
Embedding: (*embeddings)[i].Embedding,
Embedding: embeddings[i].Embedding,
}
}
err = appState.MemoryStore.PutMessageVectors(
Expand Down
4 changes: 2 additions & 2 deletions pkg/extractors/embedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestEmbeddingExtractor_Extract(t *testing.T) {
store, err := memorystore.NewPostgresMemoryStore(appState, db)
assert.NoError(t, err)
appState.MemoryStore = store
appState.OpenAIClient = llms.CreateOpenAIClient(cfg)
appState.OpenAIClient = llms.NewOpenAIRetryClient(cfg)

sessionID, err := test.GenerateRandomSessionID(16)
assert.NoError(t, err)
Expand Down Expand Up @@ -64,7 +64,7 @@ func TestEmbeddingExtractor_Extract(t *testing.T) {
expectedEmbeddingRecords[i] = models.Embeddings{
TextUUID: r.UUID,
Text: r.Content,
Embedding: (*embeddings)[i].Embedding,
Embedding: embeddings[i].Embedding,
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/extractors/summarizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestSummarize(t *testing.T) {
store, err := memorystore.NewPostgresMemoryStore(appState, db)
assert.NoError(t, err)

appState.OpenAIClient = llms.CreateOpenAIClient(cfg)
appState.OpenAIClient = llms.NewOpenAIRetryClient(cfg)
appState.MemoryStore = store

windowSize := 10
Expand Down
25 changes: 3 additions & 22 deletions pkg/llms/embeddings_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"

"github.com/avast/retry-go/v4"

"github.com/getzep/zep/pkg/models"
"github.com/sashabaranov/go-openai"
)
Expand All @@ -14,7 +12,7 @@ func EmbedMessages(
ctx context.Context,
appState *models.AppState,
text []string,
) (*[]openai.Embedding, error) {
) ([]openai.Embedding, error) {
if len(text) == 0 {
return nil, NewLLMError("no text to embed", nil)
}
Expand All @@ -33,27 +31,10 @@ func EmbedMessages(
User: "zep_user",
}

// Retry up to 3 times with exponential backoff, cancel after openAIAPITimeout
retryCtx, cancel := context.WithTimeout(ctx, openAIAPITimeout)
defer cancel()
var resp openai.EmbeddingResponse
err := retry.Do(
func() error {
var err error
resp, err = appState.OpenAIClient.CreateEmbeddings(ctx, req)
return err
},
retry.Attempts(3),
retry.Context(retryCtx),
retry.DelayType(retry.BackOffDelay),
retry.OnRetry(func(n uint, err error) {
log.Warningf("Retrying OpenAI API attempt #%d: %s\n", n, err)
}),
)

resp, err := appState.OpenAIClient.CreateEmbeddings(ctx, req)
if err != nil {
return nil, NewLLMError("error while creating embedding", err)
}

return &resp.Data, nil
return resp.Data, nil
}
6 changes: 3 additions & 3 deletions pkg/llms/embeddings_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestEmbedMessages(t *testing.T) {
cfg := test.NewTestConfig()

appState := &models.AppState{Config: cfg}
appState.OpenAIClient = CreateOpenAIClient(cfg)
appState.OpenAIClient = NewOpenAIRetryClient(cfg)

vectorLength := 1536

Expand All @@ -27,10 +27,10 @@ func TestEmbedMessages(t *testing.T) {
embeddings, err := EmbedMessages(ctx, appState, messageContents)
assert.NoError(t, err)
assert.NotNil(t, embeddings)
assert.Len(t, *embeddings, 2)
assert.Len(t, embeddings, 2)

// Check if the embeddings are of the correct length
for _, embedding := range *embeddings {
for _, embedding := range embeddings {
assert.Len(t, embedding.Embedding, int(vectorLength))
}
}
45 changes: 21 additions & 24 deletions pkg/llms/llm_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"sync"
"time"

"github.com/avast/retry-go/v4"
"github.com/getzep/zep/pkg/llms/openairetryclient"

"github.com/getzep/zep/config"
"github.com/getzep/zep/pkg/models"
"github.com/pkoukk/tiktoken-go"
Expand All @@ -22,6 +23,24 @@ var (
tkmError error
)

// NewOpenAIRetryClient builds an OpenAIRetryClient from the API key found in
// cfg.LLM.OpenAIAPIKey.
//
// If the key is empty, log.Fatal is called with OpenAIAPIKeyNotSetError.
// The returned client is configured with a Timeout of openAIAPITimeout and
// a MaxAttempts of 3.
func NewOpenAIRetryClient(cfg *config.Config) *openairetryclient.OpenAIRetryClient {
	key := cfg.LLM.OpenAIAPIKey
	if key == "" {
		log.Fatal(OpenAIAPIKeyNotSetError)
	}

	// The wrapper embeds openai.Client by value, so the freshly created
	// client is dereferenced into it.
	retryClient := &openairetryclient.OpenAIRetryClient{Client: *openai.NewClient(key)}
	retryClient.Config.Timeout = openAIAPITimeout
	retryClient.Config.MaxAttempts = 3
	return retryClient
}

func getTokenCountObject() (*tiktoken.Tiktoken, error) {
once.Do(func() {
encoding := "cl100k_base"
Expand All @@ -31,14 +50,6 @@ func getTokenCountObject() (*tiktoken.Tiktoken, error) {
return tkm, tkmError
}

// CreateOpenAIClient returns a plain go-openai client built from the API key
// in cfg.LLM.OpenAIAPIKey. If the key is empty, log.Fatal is called with
// OpenAIAPIKeyNotSetError.
func CreateOpenAIClient(cfg *config.Config) *openai.Client {
	if cfg.LLM.OpenAIAPIKey == "" {
		log.Fatal(OpenAIAPIKeyNotSetError)
	}
	return openai.NewClient(cfg.LLM.OpenAIAPIKey)
}

func RunChatCompletion(
ctx context.Context,
appState *models.AppState,
Expand All @@ -60,21 +71,7 @@ func RunChatCompletion(
},
Temperature: DefaultTemperature,
}
// Retry up to 3 times with exponential backoff, cancel after openAIAPITimeout
retryCtx, cancel := context.WithTimeout(ctx, openAIAPITimeout)
defer cancel()
err = retry.Do(
func() error {
resp, err = appState.OpenAIClient.CreateChatCompletion(retryCtx, req)
return err
},
retry.Attempts(3),
retry.Context(retryCtx),
retry.DelayType(retry.BackOffDelay),
retry.OnRetry(func(n uint, err error) {
log.Warningf("Retrying OpenAI API attempt #%d: %s\n", n, err)
}),
)
resp, err = appState.OpenAIClient.CreateChatCompletion(ctx, req)
if err != nil {
return openai.ChatCompletionResponse{}, err
}
Expand Down
98 changes: 98 additions & 0 deletions pkg/llms/openairetryclient/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package openairetryclient

import (
"context"
"errors"
"fmt"
"time"

"github.com/getzep/zep/internal"

"github.com/avast/retry-go/v4"
"github.com/sashabaranov/go-openai"
)

// log is the package-level logger obtained from internal.GetLogger; it is
// used to report retry attempts.
var log = internal.GetLogger()

// OpenAIRetryClient wraps the go-openai client with retrying variants of
// selected API calls (see the *WithRetry methods in this file).
//
// openai.Client is embedded, so its own methods remain callable directly on
// an OpenAIRetryClient; those promoted methods do not pass through
// retryFunction and therefore are not retried.
type OpenAIRetryClient struct {
openai.Client
// Config holds the retry settings forwarded to retryFunction.
Config struct {
// Timeout is the deadline placed on the context shared by all attempts of
// a single *WithRetry call.
Timeout time.Duration
// MaxAttempts is handed to retry.Attempts for each *WithRetry call.
MaxAttempts uint
}
}

// CreateChatCompletionWithRetry sends request through the embedded client's
// CreateChatCompletion, retrying via retryFunction with the client's
// configured Timeout and MaxAttempts.
//
// On success it returns a pointer to the response. If retryFunction reports
// an error, that error is wrapped with %w and returned; a result of an
// unexpected dynamic type yields a plain error.
func (c *OpenAIRetryClient) CreateChatCompletionWithRetry(
	ctx context.Context,
	request openai.ChatCompletionRequest,
) (*openai.ChatCompletionResponse, error) {
	// Adapter matching the untyped callback shape retryFunction expects.
	attempt := func(attemptCtx context.Context, payload interface{}) (interface{}, error) {
		return c.CreateChatCompletion(attemptCtx, payload.(openai.ChatCompletionRequest))
	}

	raw, err := c.retryFunction(ctx, c.Config.Timeout, c.Config.MaxAttempts, attempt, request)
	if err != nil {
		return nil, fmt.Errorf("unexpected response from OpenAI API: %w", err)
	}

	if completion, ok := raw.(openai.ChatCompletionResponse); ok {
		return &completion, nil
	}
	return nil, errors.New(
		"unexpected type returned from openai client CreateChatCompletion",
	)
}

// CreateEmbeddingsWithRetry sends request through the embedded client's
// CreateEmbeddings, retrying via retryFunction with the client's configured
// Timeout and MaxAttempts.
//
// On success it returns a pointer to the response. If retryFunction reports
// an error, that error is wrapped with %w and returned; a result of an
// unexpected dynamic type yields a plain error.
func (c *OpenAIRetryClient) CreateEmbeddingsWithRetry(
	ctx context.Context,
	request openai.EmbeddingRequest,
) (*openai.EmbeddingResponse, error) {
	// Adapter matching the untyped callback shape retryFunction expects.
	attempt := func(attemptCtx context.Context, payload interface{}) (interface{}, error) {
		return c.CreateEmbeddings(attemptCtx, payload.(openai.EmbeddingRequest))
	}

	raw, err := c.retryFunction(ctx, c.Config.Timeout, c.Config.MaxAttempts, attempt, request)
	if err != nil {
		return nil, fmt.Errorf("unexpected response from OpenAI API: %w", err)
	}

	if embedded, ok := raw.(openai.EmbeddingResponse); ok {
		return &embedded, nil
	}
	return nil, errors.New("unexpected type returned from openai client CreateEmbeddings")
}

// retryFunction runs fn(retryCtx, arg) under retry.Do and returns the value
// produced by the attempt that succeeded.
//
// A single context derived from ctx with the given timeout is created once,
// before the first attempt, and is passed both to fn and to retry.Context, so
// the timeout covers all attempts together rather than each one separately.
// maxAttempts is forwarded to retry.Attempts and delays use
// retry.BackOffDelay. Each retry is logged as a warning. If retry.Do returns
// an error, the result is discarded and (nil, err) is returned.
func (c *OpenAIRetryClient) retryFunction(
	ctx context.Context,
	timeout time.Duration,
	maxAttempts uint,
	fn func(context.Context, interface{}) (interface{}, error),
	arg interface{}) (interface{}, error) {
	retryCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// result is written by every attempt; only the value from a successful
	// attempt is ever returned to the caller.
	var result interface{}
	attempt := func() error {
		var attemptErr error
		result, attemptErr = fn(retryCtx, arg)
		return attemptErr
	}
	logRetry := func(n uint, err error) {
		log.Warningf("retrying function attempt #%d: %s\n", n, err)
	}

	if err := retry.Do(
		attempt,
		retry.Attempts(maxAttempts),
		retry.Context(retryCtx),
		retry.DelayType(retry.BackOffDelay),
		retry.OnRetry(logRetry),
	); err != nil {
		return nil, err
	}
	return result, nil
}
2 changes: 1 addition & 1 deletion pkg/memorystore/postgres_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func searchMessages(
if err != nil {
return nil, NewStorageError("failed to embed query", err)
}
vector := pgvector.NewVector((*e)[0].Embedding)
vector := pgvector.NewVector(e[0].Embedding)

var results []models.SearchResult
err = db.NewSelect().
Expand Down
2 changes: 1 addition & 1 deletion pkg/memorystore/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func setup() {
cfg := test.NewTestConfig()

appState = &models.AppState{}
appState.OpenAIClient = llms.CreateOpenAIClient(cfg)
appState.OpenAIClient = llms.NewOpenAIRetryClient(cfg)
appState.Config = cfg
store, err := NewPostgresMemoryStore(appState, testDB)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/models/appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package models

import (
"github.com/getzep/zep/config"
"github.com/sashabaranov/go-openai"
openairetryclient "github.com/getzep/zep/pkg/llms/openairetryclient"
)

// AppState is a struct that holds the state of the application
// Use cmd.NewAppState to create a new instance
type AppState struct {
OpenAIClient *openai.Client
OpenAIClient *openairetryclient.OpenAIRetryClient
MemoryStore MemoryStore[any]
Config *config.Config
}

0 comments on commit 27b132e

Please sign in to comment.