From 61c70f6d383c90f97c5aaa696c3d112d9451443e Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Mon, 30 Oct 2023 12:19:05 -0600 Subject: [PATCH 1/8] creating separate embeddings client option logic --- cmd/zep/run.go | 15 +++- config.yaml | 15 ++++ config/config.go | 9 ++- config/models.go | 33 +++++--- pkg/llms/embeddings.go | 13 +++- pkg/llms/embeddings_openai.go | 88 +++++++++++++++++++++ pkg/llms/embeddings_openai_test.go | 118 +++++++++++++++++++++++++++++ pkg/llms/llm_base.go | 88 ++++++++++++++++++++- pkg/llms/llm_openai.go | 37 ++------- pkg/models/appstate.go | 15 ++-- pkg/models/zembeddings.go | 14 ++++ pkg/server/webhandlers/settings.go | 1 + pkg/testutils/utils.go | 2 + 13 files changed, 386 insertions(+), 62 deletions(-) create mode 100644 pkg/llms/embeddings_openai.go create mode 100644 pkg/llms/embeddings_openai_test.go create mode 100644 pkg/models/zembeddings.go diff --git a/cmd/zep/run.go b/cmd/zep/run.go index 0db92035..e77cd061 100644 --- a/cmd/zep/run.go +++ b/cmd/zep/run.go @@ -67,9 +67,20 @@ func NewAppState(cfg *config.Config) *models.AppState { log.Fatal(err) } + var embeddingsClient models.ZepEmbeddingsClient = nil + + if cfg.EmbeddingsClient.Enabled { + embeddingsClient, err = llms.NewEmbeddingsClient(ctx, cfg) + } + + if err != nil { + log.Fatal(err) + } + appState := &models.AppState{ - LLMClient: llmClient, - Config: cfg, + LLMClient: llmClient, + EmbeddingsClient: embeddingsClient, + Config: cfg, } initializeStores(ctx, appState) diff --git a/config.yaml b/config.yaml index 3fe3e92e..d3fa94f4 100644 --- a/config.yaml +++ b/config.yaml @@ -18,6 +18,18 @@ llm: # Use only with an alternate OpenAI-compatible API endpoint openai_endpoint: openai_org_id: +embeddings_client: + # Enable if, when not using local embeddings, you want to use distinct clients for your llm and embeddings. + # For now, only OpenAI embeddings are available as external embeddings client. + enabled: false + # When enabled, please make sure to have your ZEP_EMBEDDINGS_OPENAI_API_KEY environment variable set + service: "openai" + azure_openai_endpoint: + azure_openai: + #embeddings deployment is required when using azure deployment for embeddings + embedding_deployment: "text-embedding-ada-002-customname" + openai_endpoint: + openai_org_id: nlp: server_url: "http://localhost:5557" memory: @@ -96,6 +108,9 @@ custom_prompts: # {{.MessagesJoined}} # Response without preamble. # + # For Open Source models compatible with the OpenAI client, + # follow the prompt guidelines provided by the model owners. + # # If left empty, the default Anthropic summary prompt from zep/pkg/extractors/prompts.go will be used. anthropic: | diff --git a/config/config.go b/config/config.go index def08e27..ab925551 100644 --- a/config/config.go +++ b/config/config.go @@ -16,10 +16,11 @@ var log = logrus.New() // EnvVars is a set of secrets that should be stored in the environment, not config file var EnvVars = map[string]string{ - "llm.anthropic_api_key": "ZEP_ANTHROPIC_API_KEY", - "llm.openai_api_key": "ZEP_OPENAI_API_KEY", - "auth.secret": "ZEP_AUTH_SECRET", - "development": "ZEP_DEVELOPMENT", + "llm.anthropic_api_key": "ZEP_ANTHROPIC_API_KEY", + "llm.openai_api_key": "ZEP_OPENAI_API_KEY", + "embeddings_client.openai_api_key": "ZEP_EMBEDDINGS_OPENAI_API_KEY", + "auth.secret": "ZEP_AUTH_SECRET", + "development": "ZEP_DEVELOPMENT", } // LoadConfig loads the config file and ENV variables into a Config struct diff --git a/config/models.go b/config/models.go index b58b1466..88bce1f3 100644 --- a/config/models.go +++ b/config/models.go @@ -3,17 +3,18 @@ package config // Config holds the configuration of the application // Use cmd.NewConfig to create a new instance type Config struct { - LLM LLM `mapstructure:"llm"` - NLP NLP `mapstructure:"nlp"` - Memory MemoryConfig `mapstructure:"memory"` - Extractors ExtractorsConfig `mapstructure:"extractors"` - Store StoreConfig `mapstructure:"store"` - Server ServerConfig `mapstructure:"server"` - Log LogConfig `mapstructure:"log"` - Auth AuthConfig `mapstructure:"auth"` - DataConfig DataConfig `mapstructure:"data"` - Development bool `mapstructure:"development"` - CustomPrompts CustomPromptsConfig `mapstructure:"custom_prompts"` + LLM LLM `mapstructure:"llm"` + EmbeddingsClient EmbeddingsClient `mapstructure:"embeddings_client"` + NLP NLP `mapstructure:"nlp"` + Memory MemoryConfig `mapstructure:"memory"` + Extractors ExtractorsConfig `mapstructure:"extractors"` + Store StoreConfig `mapstructure:"store"` + Server ServerConfig `mapstructure:"server"` + Log LogConfig `mapstructure:"log"` + Auth AuthConfig `mapstructure:"auth"` + DataConfig DataConfig `mapstructure:"data"` + Development bool `mapstructure:"development"` + CustomPrompts CustomPromptsConfig `mapstructure:"custom_prompts"` } type StoreConfig struct { @@ -32,6 +33,16 @@ type LLM struct { OpenAIOrgID string `mapstructure:"openai_org_id"` } +type EmbeddingsClient struct { + Enabled bool `mapstructure:"enabled"` + Service string `mapstructure:"service"` + OpenAIAPIKey string `mapstructure:"openai_api_key"` + AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"` + AzureOpenAIModel AzureOpenAIConfig `mapstructure:"azure_openai"` + OpenAIEndpoint string `mapstructure:"openai_endpoint"` + OpenAIOrgID string `mapstructure:"openai_org_id"` +} + type AzureOpenAIConfig struct { LLMDeployment string `mapstructure:"llm_deployment"` EmbeddingDeployment string `mapstructure:"embedding_deployment"` diff --git a/pkg/llms/embeddings.go b/pkg/llms/embeddings.go index 44d18eed..196a80f4 100644 --- a/pkg/llms/embeddings.go +++ b/pkg/llms/embeddings.go @@ -20,14 +20,23 @@ func EmbedTexts( return nil, errors.New("no text to embed") } - if appState.LLMClient == nil { + llmClient := appState.LLMClient + + if llmClient == nil { return nil, errors.New(InvalidLLMModelError) } if model.Service == "local" { return embedTextsLocal(ctx, appState, documentType, text) } - return appState.LLMClient.EmbedTexts(ctx, text) + + embeddingsClient := appState.EmbeddingsClient + + if embeddingsClient != nil { + return embeddingsClient.EmbedTexts(ctx, text) + } + + return llmClient.EmbedTexts(ctx, text) } func GetEmbeddingModel( diff --git a/pkg/llms/embeddings_openai.go b/pkg/llms/embeddings_openai.go new file mode 100644 index 00000000..c0d8687b --- /dev/null +++ b/pkg/llms/embeddings_openai.go @@ -0,0 +1,88 @@ +package llms + +import ( + "context" + + "github.com/getzep/zep/config" + "github.com/getzep/zep/pkg/models" + "github.com/tmc/langchaingo/llms/openai" +) + +const EmbeddingsOpenAIAPIKeyNotSetError = "ZEP_EMBEDDINGS_OPENAI_API_KEY is not set" //nolint:gosec + +var _ models.ZepEmbeddingsClient = &ZepOpenAIEmbeddingsClient{} + +func NewOpenAIEmbeddingsClient(ctx context.Context, cfg *config.Config) (*ZepOpenAIEmbeddingsClient, error) { + zembeddings := &ZepOpenAIEmbeddingsClient{} + err := zembeddings.Init(ctx, cfg) + if err != nil { + return nil, err + } + return zembeddings, nil +} + +type ZepOpenAIEmbeddingsClient struct { + client *openai.Chat +} + +func (zembeddings *ZepOpenAIEmbeddingsClient) Init(_ context.Context, cfg *config.Config) error { + + options, err := zembeddings.configureClient(cfg) + if err != nil { + return err + } + + // Create a new client instance with options. + // Even if it will just used for embeddings, + // it uses the same langchain openai chat client builder + client, err := NewOpenAIChatClient(options...) + zembeddings.client = client + + return nil +} + +func (zembeddings *ZepOpenAIEmbeddingsClient) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + return EmbedTextsWithOpenAIClient(ctx, texts, zembeddings.client) +} + +func getValidOpenAIModel() string { + for k := range ValidOpenAILLMs { + return k + } + return "gpt-3.5-turbo" +} + +func (zembeddings *ZepOpenAIEmbeddingsClient) configureClient(cfg *config.Config) ([]openai.Option, error) { + // Retrieve the OpenAIAPIKey from configuration + apiKey := GetOpenAIAPIKey(cfg, "embeddings") + + if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" && cfg.EmbeddingsClient.OpenAIEndpoint != "" { + log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set") + } + + // Even if it will only be used for embeddings, we should pass a valid openai llm model + // to avoid any errors + validOpenaiLLMModel := getValidOpenAIModel() + + options := GetBaseOpenAIClientOptions(apiKey, validOpenaiLLMModel) + + switch { + case cfg.EmbeddingsClient.AzureOpenAIEndpoint != "": + options = append( + options, + openai.WithAPIType(openai.APITypeAzure), + openai.WithBaseURL(cfg.EmbeddingsClient.AzureOpenAIEndpoint), + openai.WithEmbeddingModel(cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment), + ) + case cfg.EmbeddingsClient.OpenAIEndpoint != "": + // If an alternate OpenAI-compatible endpoint Path is set, use this as the base Path for requests + options = append( + options, + openai.WithBaseURL(cfg.EmbeddingsClient.OpenAIEndpoint), + ) + case cfg.EmbeddingsClient.OpenAIOrgID != "": + options = append(options, openai.WithOrganization(cfg.EmbeddingsClient.OpenAIOrgID)) + } + + return options, nil +} diff --git a/pkg/llms/embeddings_openai_test.go b/pkg/llms/embeddings_openai_test.go new file mode 100644 index 00000000..4087bb14 --- /dev/null +++ b/pkg/llms/embeddings_openai_test.go @@ -0,0 +1,118 @@ +package llms + +import ( + "context" + "testing" + + "github.com/getzep/zep/pkg/testutils" + + "github.com/stretchr/testify/assert" + + "github.com/getzep/zep/config" +) + +func TestZepOpenAIEmbeddings_Init(t *testing.T) { + cfg := &config.Config{ + EmbeddingsClient: config.EmbeddingsClient{ + OpenAIAPIKey: "test-key", + }, + } + + zembeddings := &ZepOpenAIEmbeddingsClient{} + + err := zembeddings.Init(context.Background(), cfg) + assert.NoError(t, err, "Expected no error from Init") + assert.NotNil(t, zembeddings.client, "Expected client to be initialized") +} + +func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { + zembeddings := &ZepOpenAIEmbeddingsClient{} + + t.Run("Test with OpenAIAPIKey", func(t *testing.T) { + cfg := &config.Config{ + EmbeddingsClient: config.EmbeddingsClient{ + OpenAIAPIKey: "test-key", + }, + } + + options, err := zembeddings.configureClient(cfg) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(options) != 3 { + t.Errorf("Expected 2 options, got %d", len(options)) + } + }) + + t.Run("Test with AzureOpenAIEmbeddingModel", func(t *testing.T) { + cfg := &config.Config{ + EmbeddingsClient: config.EmbeddingsClient{ + OpenAIAPIKey: "test-key", + AzureOpenAIEndpoint: "https://azure.openai.com", + AzureOpenAIModel: config.AzureOpenAIConfig{ + EmbeddingDeployment: "test-embedding-deployment", + }, + }, + } + + options, err := zembeddings.configureClient(cfg) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(options) != 6 { + t.Errorf("Expected 6 options, got %d", len(options)) + } + }) + + t.Run("Test with OpenAIEndpoint", func(t *testing.T) { + cfg := &config.Config{ + EmbeddingsClient: config.EmbeddingsClient{ + OpenAIAPIKey: "test-key", + OpenAIEndpoint: "https://openai.com", + }, + } + + options, err := zembeddings.configureClient(cfg) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(options) != 4 { + t.Errorf("Expected 3 options, got %d", len(options)) + } + }) + + t.Run("Test with OpenAIOrgID", func(t *testing.T) { + cfg := &config.Config{ + EmbeddingsClient: config.EmbeddingsClient{ + OpenAIAPIKey: "test-key", + OpenAIOrgID: "org-id", + }, + } + + options, err := zembeddings.configureClient(cfg) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(options) != 4 { + t.Errorf("Expected 3 options, got %d", len(options)) + } + }) +} + +func TestZepOpenAIEmbeddings_EmbedTexts(t *testing.T) { + cfg := testutils.NewTestConfig() + + zembeddings, err := NewOpenAIEmbeddingsClient(context.Background(), cfg) + assert.NoError(t, err, "Expected no error from NewOpenAIEmbeddingsClient") + + texts := []string{"Hello, world!", "Another text"} + embeddings, err := zembeddings.EmbedTexts(context.Background(), texts) + assert.NoError(t, err, "Expected no error from EmbedTexts") + assert.Equal(t, len(texts), len(embeddings), "Expected embeddings to have same length as texts") + assert.NotZero(t, embeddings[0], "Expected embeddings to be non-zero") + assert.NotZero(t, embeddings[1], "Expected embeddings to be non-zero") +} diff --git a/pkg/llms/llm_base.go b/pkg/llms/llm_base.go index 5be35ba9..f7c7137f 100644 --- a/pkg/llms/llm_base.go +++ b/pkg/llms/llm_base.go @@ -7,6 +7,7 @@ import ( "time" "github.com/getzep/zep/pkg/models" + "github.com/tmc/langchaingo/llms/openai" "github.com/hashicorp/go-retryablehttp" @@ -17,6 +18,11 @@ import ( const DefaultTemperature = 0.0 const InvalidLLMModelError = "llm model is not set or is invalid" +const InvalidEmbeddingsClientError = "embeddings client is not set or is invalid" + +var InvalidEmbeddingsDeploymentError = func(service string) error { + return fmt.Errorf("invalid embeddings deployment for %s, deployment name is required", service) +} var log = internal.GetLogger() @@ -43,10 +49,8 @@ func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error // EmbeddingsDeployment is only required if Zep is also configured to use // OpenAI embeddings for document or message extractors if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment == "" && useOpenAIEmbeddings(cfg) { - return nil, fmt.Errorf( - "invalid embeddings deployment for %s, deployment name is required", - cfg.LLM.Service, - ) + err := InvalidEmbeddingsDeploymentError(cfg.EmbeddingsClient.Service) + return nil, err } return NewOpenAILLM(ctx, cfg) } @@ -80,6 +84,25 @@ func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error } } +func NewEmbeddingsClient(ctx context.Context, cfg *config.Config) (models.ZepEmbeddingsClient, error) { + switch cfg.EmbeddingsClient.Service { + // For now we only support OpenAI embeddings + case "openai": + // EmbeddingsDeployment is required if using external embeddings with AzureOpenAI + if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" && cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment == "" { + err := InvalidEmbeddingsDeploymentError(cfg.EmbeddingsClient.Service) + return nil, err + } + // The logic is the same if custom OpenAI Endpoint is set or not + // since the model name will be set automatically in this case + return NewOpenAIEmbeddingsClient(ctx, cfg) + case "": + return NewOpenAIEmbeddingsClient(ctx, cfg) + default: + return nil, fmt.Errorf("invalid embeddings service: %s", cfg.EmbeddingsClient.Service) + } +} + type LLMError struct { message string originalError error @@ -180,3 +203,60 @@ func useOpenAIEmbeddings(cfg *config.Config) bool { return false } + +func NewOpenAIChatClient(options ...openai.Option) (*openai.Chat, error) { + client, err := openai.NewChat(options...) + if err != nil { + return nil, err + } + return client, nil +} + +func GetOpenAIAPIKey(cfg *config.Config, clientType string) string { + var apiKey string + + if clientType == "embeddings" { + apiKey = cfg.EmbeddingsClient.OpenAIAPIKey + // If the key is not set, log a fatal error and exit + if apiKey == "" { + log.Fatal(EmbeddingsOpenAIAPIKeyNotSetError) + } + } else { + apiKey = cfg.LLM.OpenAIAPIKey + if apiKey == "" { + log.Fatal(EmbeddingsOpenAIAPIKeyNotSetError) + } + } + return apiKey +} + +func EmbedTextsWithOpenAIClient(ctx context.Context, texts []string, openAIClient *openai.Chat) ([][]float32, error) { + // If the LLM is not initialized, return an error + if openAIClient == nil { + return nil, NewLLMError(InvalidLLMModelError, nil) + } + + thisCtx, cancel := context.WithTimeout(ctx, OpenAIAPITimeout) + defer cancel() + + embeddings, err := openAIClient.CreateEmbedding(thisCtx, texts) + if err != nil { + return nil, NewLLMError("error while creating embedding", err) + } + + return embeddings, nil +} + +func GetBaseOpenAIClientOptions(apiKey, validModel string) []openai.Option { + retryableHTTPClient := NewRetryableHTTPClient(MaxOpenAIAPIRequestAttempts, OpenAIAPITimeout) + + options := make([]openai.Option, 0) + options = append( + options, + openai.WithHTTPClient(retryableHTTPClient.StandardClient()), + openai.WithModel(validModel), + openai.WithToken(apiKey), + ) + + return options +} diff --git a/pkg/llms/llm_openai.go b/pkg/llms/llm_openai.go index ee77bfbf..8d4a9607 100644 --- a/pkg/llms/llm_openai.go +++ b/pkg/llms/llm_openai.go @@ -49,10 +49,7 @@ func (zllm *ZepOpenAILLM) Init(_ context.Context, cfg *config.Config) error { } // Create a new client instance with options - llm, err := openai.NewChat(options...) - if err != nil { - return err - } + llm, err := NewOpenAIChatClient(options...) zllm.llm = llm return nil @@ -85,20 +82,7 @@ func (zllm *ZepOpenAILLM) Call(ctx context.Context, } func (zllm *ZepOpenAILLM) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { - // If the LLM is not initialized, return an error - if zllm.llm == nil { - return nil, NewLLMError(InvalidLLMModelError, nil) - } - - thisCtx, cancel := context.WithTimeout(ctx, OpenAIAPITimeout) - defer cancel() - - embeddings, err := zllm.llm.CreateEmbedding(thisCtx, texts) - if err != nil { - return nil, NewLLMError("error while creating embedding", err) - } - - return embeddings, nil + return EmbedTextsWithOpenAIClient(ctx, texts, zllm.llm) } // GetTokenCount returns the number of tokens in the text @@ -108,24 +92,13 @@ func (zllm *ZepOpenAILLM) GetTokenCount(text string) (int, error) { func (zllm *ZepOpenAILLM) configureClient(cfg *config.Config) ([]openai.Option, error) { // Retrieve the OpenAIAPIKey from configuration - apiKey := cfg.LLM.OpenAIAPIKey - // If the key is not set, log a fatal error and exit - if apiKey == "" { - log.Fatal(OpenAIAPIKeyNotSetError) - } + apiKey := GetOpenAIAPIKey(cfg, "llm") + if cfg.LLM.AzureOpenAIEndpoint != "" && cfg.LLM.OpenAIEndpoint != "" { log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set") } - retryableHTTPClient := NewRetryableHTTPClient(MaxOpenAIAPIRequestAttempts, OpenAIAPITimeout) - - options := make([]openai.Option, 0) - options = append( - options, - openai.WithHTTPClient(retryableHTTPClient.StandardClient()), - openai.WithModel(cfg.LLM.Model), - openai.WithToken(apiKey), - ) + options := GetBaseOpenAIClientOptions(apiKey, cfg.LLM.Model) switch { case cfg.LLM.AzureOpenAIEndpoint != "": diff --git a/pkg/models/appstate.go b/pkg/models/appstate.go index 2496d67d..bc326238 100644 --- a/pkg/models/appstate.go +++ b/pkg/models/appstate.go @@ -7,11 +7,12 @@ import ( // AppState is a struct that holds the state of the application // Use cmd.NewAppState to create a new instance type AppState struct { - LLMClient ZepLLM - MemoryStore MemoryStore[any] - DocumentStore DocumentStore[any] - UserStore UserStore - TaskRouter TaskRouter - TaskPublisher TaskPublisher - Config *config.Config + LLMClient ZepLLM + EmbeddingsClient ZepEmbeddingsClient + MemoryStore MemoryStore[any] + DocumentStore DocumentStore[any] + UserStore UserStore + TaskRouter TaskRouter + TaskPublisher TaskPublisher + Config *config.Config } diff --git a/pkg/models/zembeddings.go b/pkg/models/zembeddings.go new file mode 100644 index 00000000..5fa542fe --- /dev/null +++ b/pkg/models/zembeddings.go @@ -0,0 +1,14 @@ +package models + +import ( + "context" + + "github.com/getzep/zep/config" +) + +type ZepEmbeddingsClient interface { + // EmbedTexts embeds the given texts + EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) + // Init initializes the Client + Init(ctx context.Context, cfg *config.Config) error +} diff --git a/pkg/server/webhandlers/settings.go b/pkg/server/webhandlers/settings.go index d9882b5a..741a5172 100644 --- a/pkg/server/webhandlers/settings.go +++ b/pkg/server/webhandlers/settings.go @@ -23,6 +23,7 @@ func redactHTMLEncodeConfig(cfg *config.Config) (*config.Config, error) { redactedConfig := *cfg redactedConfig.LLM.AnthropicAPIKey = "**redacted**" redactedConfig.LLM.OpenAIAPIKey = "**redacted**" + redactedConfig.EmbeddingsClient.OpenAIAPIKey = "**redacted**" redactedConfig.Auth.Secret = "**redacted**" re := regexp.MustCompile(`(?i:postgres://[^:]+:)([^@]+)`) diff --git a/pkg/testutils/utils.go b/pkg/testutils/utils.go index cbc75e76..6d519aa1 100644 --- a/pkg/testutils/utils.go +++ b/pkg/testutils/utils.go @@ -100,6 +100,8 @@ func testConfigDefaults() (*config.Config, error) { testConfig.LLM.AnthropicAPIKey = os.Getenv(envVar) case "llm.openai_api_key": testConfig.LLM.OpenAIAPIKey = os.Getenv(envVar) + case "embeddings_client.openai_api_key": + testConfig.EmbeddingsClient.OpenAIAPIKey = os.Getenv(envVar) case "auth.secret": testConfig.Auth.Secret = os.Getenv(envVar) case "development": From 9e49d9cd4121ef03e11a26288cfd1a30dbfa5857 Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Mon, 30 Oct 2023 13:14:20 -0600 Subject: [PATCH 2/8] refactoring --- config.yaml | 2 +- pkg/llms/embeddings_base.go | 42 ++++++++++ pkg/llms/embeddings_openai.go | 27 +----- pkg/llms/llm_base.go | 78 ----------------- pkg/llms/llm_openai.go | 36 +------- pkg/llms/openai_base.go | 153 ++++++++++++++++++++++++++++++++++ 6 files changed, 204 insertions(+), 134 deletions(-) create mode 100644 pkg/llms/embeddings_base.go create mode 100644 pkg/llms/openai_base.go diff --git a/config.yaml b/config.yaml index d3fa94f4..5e2bed28 100644 --- a/config.yaml +++ b/config.yaml @@ -26,7 +26,7 @@ embeddings_client: service: "openai" azure_openai_endpoint: azure_openai: - #embeddings deployment is required when using azure deployment for embeddings + # embeddings deployment is required when using azure deployment for embeddings embedding_deployment: "text-embedding-ada-002-customname" openai_endpoint: openai_org_id: diff --git a/pkg/llms/embeddings_base.go b/pkg/llms/embeddings_base.go new file mode 100644 index 00000000..910e0d8e --- /dev/null +++ b/pkg/llms/embeddings_base.go @@ -0,0 +1,42 @@ +package llms + +import ( + "context" + "fmt" + + "github.com/getzep/zep/pkg/models" + + "github.com/getzep/zep/config" +) + +const InvalidEmbeddingsClientError = "embeddings client is not set or is invalid" + +type EmbeddingsClientError struct { + message string + originalError error +} + +func (e *EmbeddingsClientError) Error() string { + return fmt.Sprintf("embeddings client error: %s (original error: %v)", e.message, e.originalError) +} + +func NewEmbeddingsClientError(message string, originalError error) *EmbeddingsClientError { + return &EmbeddingsClientError{message: message, originalError: originalError} +} + +func NewEmbeddingsClient(ctx context.Context, cfg *config.Config) (models.ZepEmbeddingsClient, error) { + switch cfg.EmbeddingsClient.Service { + // For now we only support OpenAI embeddings + case "openai": + // EmbeddingsDeployment is required if using external embeddings with AzureOpenAI + if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" && cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment == "" { + err := InvalidEmbeddingsDeploymentError(cfg.EmbeddingsClient.Service) + return nil, err + } + // The logic is the same if custom OpenAI Endpoint is set or not + // since the model name will be set automatically in this case + return NewOpenAIEmbeddingsClient(ctx, cfg) + default: + return nil, fmt.Errorf("invalid embeddings service: %s", cfg.EmbeddingsClient.Service) + } +} diff --git a/pkg/llms/embeddings_openai.go b/pkg/llms/embeddings_openai.go index c0d8687b..7c1098cc 100644 --- a/pkg/llms/embeddings_openai.go +++ b/pkg/llms/embeddings_openai.go @@ -26,7 +26,6 @@ type ZepOpenAIEmbeddingsClient struct { } func (zembeddings *ZepOpenAIEmbeddingsClient) Init(_ context.Context, cfg *config.Config) error { - options, err := zembeddings.configureClient(cfg) if err != nil { return err @@ -42,7 +41,7 @@ func (zembeddings *ZepOpenAIEmbeddingsClient) Init(_ context.Context, cfg *confi } func (zembeddings *ZepOpenAIEmbeddingsClient) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { - return EmbedTextsWithOpenAIClient(ctx, texts, zembeddings.client) + return EmbedTextsWithOpenAIClient(ctx, texts, zembeddings.client, EmbeddingsClientType) } func getValidOpenAIModel() string { @@ -54,11 +53,9 @@ func getValidOpenAIModel() string { func (zembeddings *ZepOpenAIEmbeddingsClient) configureClient(cfg *config.Config) ([]openai.Option, error) { // Retrieve the OpenAIAPIKey from configuration - apiKey := GetOpenAIAPIKey(cfg, "embeddings") + apiKey := GetOpenAIAPIKey(cfg, EmbeddingsClientType) - if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" && cfg.EmbeddingsClient.OpenAIEndpoint != "" { - log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set") - } + validateOpenAIConfig(cfg, EmbeddingsClientType) // Even if it will only be used for embeddings, we should pass a valid openai llm model // to avoid any errors @@ -66,23 +63,7 @@ func (zembeddings *ZepOpenAIEmbeddingsClient) configureClient(cfg *config.Config options := GetBaseOpenAIClientOptions(apiKey, validOpenaiLLMModel) - switch { - case cfg.EmbeddingsClient.AzureOpenAIEndpoint != "": - options = append( - options, - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.EmbeddingsClient.AzureOpenAIEndpoint), - openai.WithEmbeddingModel(cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment), - ) - case cfg.EmbeddingsClient.OpenAIEndpoint != "": - // If an alternate OpenAI-compatible endpoint Path is set, use this as the base Path for requests - options = append( - options, - openai.WithBaseURL(cfg.EmbeddingsClient.OpenAIEndpoint), - ) - case cfg.EmbeddingsClient.OpenAIOrgID != "": - options = append(options, openai.WithOrganization(cfg.EmbeddingsClient.OpenAIOrgID)) - } + options = ConfigureOpenAIClientOptions(options, cfg, EmbeddingsClientType) return options, nil } diff --git a/pkg/llms/llm_base.go b/pkg/llms/llm_base.go index f7c7137f..67d997cc 100644 --- a/pkg/llms/llm_base.go +++ b/pkg/llms/llm_base.go @@ -7,7 +7,6 @@ import ( "time" "github.com/getzep/zep/pkg/models" - "github.com/tmc/langchaingo/llms/openai" "github.com/hashicorp/go-retryablehttp" @@ -18,7 +17,6 @@ import ( const DefaultTemperature = 0.0 const InvalidLLMModelError = "llm model is not set or is invalid" -const InvalidEmbeddingsClientError = "embeddings client is not set or is invalid" var InvalidEmbeddingsDeploymentError = func(service string) error { return fmt.Errorf("invalid embeddings deployment for %s, deployment name is required", service) @@ -84,25 +82,6 @@ func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error } } -func NewEmbeddingsClient(ctx context.Context, cfg *config.Config) (models.ZepEmbeddingsClient, error) { - switch cfg.EmbeddingsClient.Service { - // For now we only support OpenAI embeddings - case "openai": - // EmbeddingsDeployment is required if using external embeddings with AzureOpenAI - if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" && cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment == "" { - err := InvalidEmbeddingsDeploymentError(cfg.EmbeddingsClient.Service) - return nil, err - } - // The logic is the same if custom OpenAI Endpoint is set or not - // since the model name will be set automatically in this case - return NewOpenAIEmbeddingsClient(ctx, cfg) - case "": - return NewOpenAIEmbeddingsClient(ctx, cfg) - default: - return nil, fmt.Errorf("invalid embeddings service: %s", cfg.EmbeddingsClient.Service) - } -} - type LLMError struct { message string originalError error @@ -203,60 +182,3 @@ func useOpenAIEmbeddings(cfg *config.Config) bool { return false } - -func NewOpenAIChatClient(options ...openai.Option) (*openai.Chat, error) { - client, err := openai.NewChat(options...) - if err != nil { - return nil, err - } - return client, nil -} - -func GetOpenAIAPIKey(cfg *config.Config, clientType string) string { - var apiKey string - - if clientType == "embeddings" { - apiKey = cfg.EmbeddingsClient.OpenAIAPIKey - // If the key is not set, log a fatal error and exit - if apiKey == "" { - log.Fatal(EmbeddingsOpenAIAPIKeyNotSetError) - } - } else { - apiKey = cfg.LLM.OpenAIAPIKey - if apiKey == "" { - log.Fatal(EmbeddingsOpenAIAPIKeyNotSetError) - } - } - return apiKey -} - -func EmbedTextsWithOpenAIClient(ctx context.Context, texts []string, openAIClient *openai.Chat) ([][]float32, error) { - // If the LLM is not initialized, return an error - if openAIClient == nil { - return nil, NewLLMError(InvalidLLMModelError, nil) - } - - thisCtx, cancel := context.WithTimeout(ctx, OpenAIAPITimeout) - defer cancel() - - embeddings, err := openAIClient.CreateEmbedding(thisCtx, texts) - if err != nil { - return nil, NewLLMError("error while creating embedding", err) - } - - return embeddings, nil -} - -func GetBaseOpenAIClientOptions(apiKey, validModel string) []openai.Option { - retryableHTTPClient := NewRetryableHTTPClient(MaxOpenAIAPIRequestAttempts, OpenAIAPITimeout) - - options := make([]openai.Option, 0) - options = append( - options, - openai.WithHTTPClient(retryableHTTPClient.StandardClient()), - openai.WithModel(validModel), - openai.WithToken(apiKey), - ) - - return options -} diff --git a/pkg/llms/llm_openai.go b/pkg/llms/llm_openai.go index 8d4a9607..40073d7d 100644 --- a/pkg/llms/llm_openai.go +++ b/pkg/llms/llm_openai.go @@ -2,7 +2,6 @@ package llms import ( "context" - "time" "github.com/tmc/langchaingo/schema" @@ -14,9 +13,7 @@ import ( "github.com/tmc/langchaingo/llms/openai" ) -const OpenAIAPITimeout = 90 * time.Second const OpenAIAPIKeyNotSetError = "ZEP_OPENAI_API_KEY is not set" //nolint:gosec -const MaxOpenAIAPIRequestAttempts = 5 var _ models.ZepLLM = &ZepOpenAILLM{} @@ -82,7 +79,7 @@ func (zllm *ZepOpenAILLM) Call(ctx context.Context, } func (zllm *ZepOpenAILLM) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { - return EmbedTextsWithOpenAIClient(ctx, texts, zllm.llm) + return EmbedTextsWithOpenAIClient(ctx, texts, zllm.llm, LLMClientType) } // GetTokenCount returns the number of tokens in the text @@ -92,38 +89,13 @@ func (zllm *ZepOpenAILLM) GetTokenCount(text string) (int, error) { func (zllm *ZepOpenAILLM) configureClient(cfg *config.Config) ([]openai.Option, error) { // Retrieve the OpenAIAPIKey from configuration - apiKey := GetOpenAIAPIKey(cfg, "llm") + apiKey := GetOpenAIAPIKey(cfg, LLMClientType) - if cfg.LLM.AzureOpenAIEndpoint != "" && cfg.LLM.OpenAIEndpoint != "" { - log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set") - } + validateOpenAIConfig(cfg, LLMClientType) options := GetBaseOpenAIClientOptions(apiKey, cfg.LLM.Model) - switch { - case cfg.LLM.AzureOpenAIEndpoint != "": - // Check configuration for AzureOpenAIEndpoint; if it's set, use the DefaultAzureConfig - // and provided endpoint Path - options = append( - options, - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint), - ) - if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "" { - options = append( - options, - openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), - ) - } - case cfg.LLM.OpenAIEndpoint != "": - // If an alternate OpenAI-compatible endpoint Path is set, use this as the base Path for requests - options = append( - options, - openai.WithBaseURL(cfg.LLM.OpenAIEndpoint), - ) - case cfg.LLM.OpenAIOrgID != "": - options = append(options, openai.WithOrganization(cfg.LLM.OpenAIOrgID)) - } + options = ConfigureOpenAIClientOptions(options, cfg, LLMClientType) return options, nil } diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go new file mode 100644 index 00000000..d99acfe1 --- /dev/null +++ b/pkg/llms/openai_base.go @@ -0,0 +1,153 @@ +package llms + +import ( + "context" + "time" + + "github.com/getzep/zep/config" + "github.com/tmc/langchaingo/llms/openai" +) + +const OpenAIAPITimeout = 90 * time.Second +const MaxOpenAIAPIRequestAttempts = 5 + +type ClientType string + +const ( + EmbeddingsClientType ClientType = "embeddings" + LLMClientType ClientType = "llm" +) + +func NewOpenAIChatClient(options ...openai.Option) (*openai.Chat, error) { + client, err := openai.NewChat(options...) + if err != nil { + return nil, err + } + return client, nil +} + +func GetOpenAIAPIKey(cfg *config.Config, clientType ClientType) string { + var apiKey string + + if clientType == EmbeddingsClientType { + apiKey = cfg.EmbeddingsClient.OpenAIAPIKey + // If the key is not set, log a fatal error and exit + if apiKey == "" { + log.Fatal(EmbeddingsOpenAIAPIKeyNotSetError) + } + } else { + apiKey = cfg.LLM.OpenAIAPIKey + if apiKey == "" { + log.Fatal(OpenAIAPIKeyNotSetError) + } + } + return apiKey +} + +func validateOpenAIConfig(cfg *config.Config, clientType ClientType) { + + var azureEndpoint string + var openAIEndpoint string + + if clientType == EmbeddingsClientType { + azureEndpoint = cfg.EmbeddingsClient.AzureOpenAIEndpoint + openAIEndpoint = cfg.EmbeddingsClient.OpenAIEndpoint + } else { + azureEndpoint = cfg.LLM.AzureOpenAIEndpoint + openAIEndpoint = cfg.LLM.OpenAIEndpoint + } + + if azureEndpoint != "" && openAIEndpoint != "" { + log.Fatal("only one of AzureOpenAIEndpoint or OpenAIEndpoint can be set") + } +} + +func EmbedTextsWithOpenAIClient(ctx context.Context, texts []string, openAIClient *openai.Chat, clientType ClientType) ([][]float32, error) { + // If the Client is not initialized, return an error + if openAIClient == nil { + if clientType == EmbeddingsClientType { + return nil, NewEmbeddingsClientError(InvalidEmbeddingsClientError, nil) + } + return nil, NewLLMError(InvalidLLMModelError, nil) + } + + thisCtx, cancel := context.WithTimeout(ctx, OpenAIAPITimeout) + defer cancel() + + embeddings, err := openAIClient.CreateEmbedding(thisCtx, texts) + if err != nil { + message := "error while creating embedding" + if clientType == EmbeddingsClientType { + return nil, NewEmbeddingsClientError(message, nil) + } + return nil, NewLLMError(message, err) + } + + return embeddings, nil +} + +func GetBaseOpenAIClientOptions(apiKey, validModel string) []openai.Option { + retryableHTTPClient := NewRetryableHTTPClient(MaxOpenAIAPIRequestAttempts, OpenAIAPITimeout) + + options := make([]openai.Option, 0) + options = append( + options, + openai.WithHTTPClient(retryableHTTPClient.StandardClient()), + openai.WithModel(validModel), + openai.WithToken(apiKey), + ) + + return options +} + +func ConfigureOpenAIClientOptions(options []openai.Option, cfg *config.Config, clientType ClientType) []openai.Option { + applyOption := func(cond bool, opts ...openai.Option) []openai.Option { + if cond { + return append(options, opts...) + } + return options + } + + var openAIEndpoint string + var openAIOrgID string + + if clientType == EmbeddingsClientType { + openAIEndpoint = cfg.EmbeddingsClient.OpenAIEndpoint + openAIOrgID = cfg.EmbeddingsClient.OpenAIOrgID + + // Check configuration for AzureOpenAIEndpoint; if it's set, use the DefaultAzureConfig + // and provided endpoint Path. + // WithEmbeddings is always required in case of embeddings client + options = applyOption(cfg.EmbeddingsClient.AzureOpenAIEndpoint != "", + openai.WithAPIType(openai.APITypeAzure), + openai.WithBaseURL(cfg.EmbeddingsClient.AzureOpenAIEndpoint), + openai.WithEmbeddingModel(cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment), + ) + } else { + openAIEndpoint = cfg.LLM.OpenAIEndpoint + openAIOrgID = cfg.LLM.OpenAIOrgID + + options = append( + options, + openai.WithAPIType(openai.APITypeAzure), + openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint), + ) + if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "" { + options = append( + options, + openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), + ) + } + + } + + options = applyOption(openAIEndpoint != "", + openai.WithBaseURL(openAIEndpoint), + ) + + options = applyOption(openAIOrgID != "", + openai.WithOrganization(openAIOrgID), + ) + + return options +} From 8737e6d1d9339059a5a8e44e2f904f86d3571f0f Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Mon, 30 Oct 2023 14:22:24 -0600 Subject: [PATCH 3/8] refactoring test cases --- pkg/llms/embeddings_openai.go | 2 +- pkg/llms/embeddings_openai_test.go | 43 ++++------------------ pkg/llms/llm_openai.go | 2 +- pkg/llms/llm_openai_test.go | 43 ++++------------------ pkg/llms/openai_base.go | 2 +- pkg/llms/openai_base_test.go | 57 ++++++++++++++++++++++++++++++ 6 files changed, 74 insertions(+), 75 deletions(-) create mode 100644 pkg/llms/openai_base_test.go diff --git a/pkg/llms/embeddings_openai.go b/pkg/llms/embeddings_openai.go index 7c1098cc..33152c8a 100644 --- a/pkg/llms/embeddings_openai.go +++ b/pkg/llms/embeddings_openai.go @@ -55,7 +55,7 @@ func (zembeddings *ZepOpenAIEmbeddingsClient) configureClient(cfg *config.Config // Retrieve the OpenAIAPIKey from configuration apiKey := GetOpenAIAPIKey(cfg, EmbeddingsClientType) - validateOpenAIConfig(cfg, EmbeddingsClientType) + ValidateOpenAIConfig(cfg, EmbeddingsClientType) // Even if it will only be used for embeddings, we should pass a valid openai llm model // to avoid any errors diff --git a/pkg/llms/embeddings_openai_test.go b/pkg/llms/embeddings_openai_test.go index 4087bb14..e025bd41 100644 --- a/pkg/llms/embeddings_openai_test.go +++ b/pkg/llms/embeddings_openai_test.go @@ -21,8 +21,7 @@ func TestZepOpenAIEmbeddings_Init(t *testing.T) { zembeddings := &ZepOpenAIEmbeddingsClient{} err := zembeddings.Init(context.Background(), cfg) - assert.NoError(t, err, "Expected no error from Init") - assert.NotNil(t, zembeddings.client, "Expected client to be initialized") + TestOpenAIClient_Init(t, err, zembeddings.client, EmbeddingsClientType) } func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { @@ -36,13 +35,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 3 { - t.Errorf("Expected 2 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIAPIKeyTestCase) }) t.Run("Test with AzureOpenAIEmbeddingModel", func(t *testing.T) { @@ -57,13 +50,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 6 { - t.Errorf("Expected 6 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) }) t.Run("Test with OpenAIEndpoint", func(t *testing.T) { @@ -75,13 +62,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 4 { - t.Errorf("Expected 3 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIEndpointTestCase) }) t.Run("Test with OpenAIOrgID", func(t *testing.T) { @@ -93,13 +74,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 4 { - t.Errorf("Expected 3 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIOrgIDTestCase) }) } @@ -109,10 +84,6 @@ func TestZepOpenAIEmbeddings_EmbedTexts(t *testing.T) { zembeddings, err := NewOpenAIEmbeddingsClient(context.Background(), cfg) assert.NoError(t, err, "Expected no error from NewOpenAIEmbeddingsClient") - texts := []string{"Hello, world!", "Another text"} - embeddings, err := zembeddings.EmbedTexts(context.Background(), texts) - assert.NoError(t, err, "Expected no error from EmbedTexts") - assert.Equal(t, len(texts), len(embeddings), "Expected embeddings to have same length as texts") - assert.NotZero(t, embeddings[0], "Expected embeddings to be non-zero") - assert.NotZero(t, embeddings[1], "Expected embeddings to be non-zero") + embeddings, err := zembeddings.EmbedTexts(context.Background(), EmbeddingsTestTexts) + TestOpenAIClient_EmbedText(t, embeddings, err) } diff --git a/pkg/llms/llm_openai.go b/pkg/llms/llm_openai.go index 40073d7d..8a9e5cea 100644 --- a/pkg/llms/llm_openai.go +++ b/pkg/llms/llm_openai.go @@ -91,7 +91,7 @@ func (zllm *ZepOpenAILLM) configureClient(cfg *config.Config) ([]openai.Option, // Retrieve the OpenAIAPIKey from configuration apiKey := GetOpenAIAPIKey(cfg, LLMClientType) - validateOpenAIConfig(cfg, LLMClientType) + ValidateOpenAIConfig(cfg, LLMClientType) options := GetBaseOpenAIClientOptions(apiKey, cfg.LLM.Model) diff --git a/pkg/llms/llm_openai_test.go b/pkg/llms/llm_openai_test.go index 4d4f15d8..b2d988ca 100644 --- a/pkg/llms/llm_openai_test.go +++ b/pkg/llms/llm_openai_test.go @@ -22,8 +22,7 @@ func TestZepOpenAILLM_Init(t *testing.T) { zllm := &ZepOpenAILLM{} err := zllm.Init(context.Background(), cfg) - assert.NoError(t, err, "Expected no error from Init") - assert.NotNil(t, zllm.llm, "Expected llm to be initialized") + TestOpenAIClient_Init(t, err, zllm.llm, LLMClientType) assert.NotNil(t, zllm.tkm, "Expected tkm to be initialized") } @@ -38,13 +37,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 3 { - t.Errorf("Expected 2 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIAPIKeyTestCase) }) t.Run("Test with AzureOpenAIEndpoint", func(t *testing.T) { @@ -79,13 +72,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 6 { - t.Errorf("Expected 6 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) }) t.Run("Test with OpenAIEndpointAndCustomModelName", func(t *testing.T) { @@ -98,13 +85,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 4 { - t.Errorf("Expected 3 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIEndpointTestCase) }) t.Run("Test with OpenAIOrgID", func(t *testing.T) { @@ -116,13 +97,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - if len(options) != 4 { - t.Errorf("Expected 3 options, got %d", len(options)) - } + TestOpenAIClient_ConfigureClient(t, options, err, OpenAIOrgIDTestCase) }) } @@ -146,12 +121,8 @@ func TestZepOpenAILLM_EmbedTexts(t *testing.T) { zllm, err := NewOpenAILLM(context.Background(), cfg) assert.NoError(t, err, "Expected no error from NewOpenAILLM") - texts := []string{"Hello, world!", "Another text"} - embeddings, err := zllm.EmbedTexts(context.Background(), texts) - assert.NoError(t, err, "Expected no error from EmbedTexts") - assert.Equal(t, len(texts), len(embeddings), "Expected embeddings to have same length as texts") - assert.NotZero(t, embeddings[0], "Expected embeddings to be non-zero") - assert.NotZero(t, embeddings[1], "Expected embeddings to be non-zero") + embeddings, err := zllm.EmbedTexts(context.Background(), EmbeddingsTestTexts) + TestOpenAIClient_EmbedText(t, embeddings, err) } func TestZepOpenAILLM_GetTokenCount(t *testing.T) { diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go index d99acfe1..e5f13b35 100644 --- a/pkg/llms/openai_base.go +++ b/pkg/llms/openai_base.go @@ -44,7 +44,7 @@ func GetOpenAIAPIKey(cfg *config.Config, clientType ClientType) string { return apiKey } -func validateOpenAIConfig(cfg *config.Config, clientType ClientType) { +func ValidateOpenAIConfig(cfg *config.Config, clientType ClientType) { var azureEndpoint string var openAIEndpoint string diff --git a/pkg/llms/openai_base_test.go b/pkg/llms/openai_base_test.go new file mode 100644 index 00000000..763bc5c3 --- /dev/null +++ b/pkg/llms/openai_base_test.go @@ -0,0 +1,57 @@ +package llms + +import ( + "testing" + + "github.com/tmc/langchaingo/llms/openai" + + "github.com/stretchr/testify/assert" +) + +type TestCaseType string + +const ( + OpenAIAPIKeyTestCase TestCaseType = "OpenAIAPIKeyTestCase" + AzureOpenAIEmbeddingModelTestCase TestCaseType = "AzureOpenAIEmbeddingModelTestCase" + OpenAIEndpointTestCase TestCaseType = "OpenAIEndpointTestCase" + OpenAIOrgIDTestCase TestCaseType = "OpenAIOrgIDTestCase" +) + +var EmbeddingsTestTexts = []string{"Hello, world!", "Another text"} + +func TestOpenAIClient_Init(t *testing.T, err error, openAIClient *openai.Chat, clientType ClientType) { + assert.NoError(t, err, "Expected no error from Init") + switch clientType { + case EmbeddingsClientType: + assert.NotNil(t, openAIClient, "Expected client to be initialized") + default: + assert.NotNil(t, openAIClient, "Expected llm to be initialized") + } +} + +func TestOpenAIClient_ConfigureClient(t *testing.T, options []openai.Option, err error, testCase TestCaseType) { + assert.NoError(t, err, "Unexpected error") + expectedOptions := map[TestCaseType]int{ + OpenAIAPIKeyTestCase: 3, + AzureOpenAIEmbeddingModelTestCase: 6, + OpenAIEndpointTestCase: 4, + OpenAIOrgIDTestCase: 4, + } + expected, ok := expectedOptions[testCase] + if !ok { + t.Errorf("Unexpected test case: %s", testCase) + return + } + //? assert.Len(t, options, expected, "Unexpected number of options") + if len(options) != expected { + t.Errorf("Expected %e options, got %d", expected, len(options)) + } +} + +func TestOpenAIClient_EmbedText(t *testing.T, embeddings [][]float32, err error) { + assert.NoError(t, err, "Expected no error from EmbedTexts") + assert.Equal(t, len(EmbeddingsTestTexts), len(embeddings), "Expected embeddings to have same length as texts") + assert.NotZero(t, embeddings[0], "Expected embeddings to be non-zero") + assert.NotZero(t, embeddings[1], "Expected embeddings to be non-zero") + assert.NoError(t, err, "Unexpected error") +} From c9b930d11aaa7d5c91990563bb224b5dc9eaf0f7 Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Mon, 30 Oct 2023 15:42:45 -0600 Subject: [PATCH 4/8] fixes --- config.yaml | 6 +++--- pkg/llms/embeddings_openai_test.go | 12 ++++++------ pkg/llms/llm_openai_test.go | 12 ++++++------ pkg/llms/openai_base.go | 13 ++++--------- pkg/llms/openai_base_test.go | 11 +++++++---- 5 files changed, 26 insertions(+), 28 deletions(-) diff --git a/config.yaml b/config.yaml index 5e2bed28..5c67b362 100644 --- a/config.yaml +++ b/config.yaml @@ -108,9 +108,6 @@ custom_prompts: # {{.MessagesJoined}} # Response without preamble. # - # For Open Source models compatible with the OpenAI client, - # follow the prompt guidelines provided by the model owners. - # # If left empty, the default Anthropic summary prompt from zep/pkg/extractors/prompts.go will be used. anthropic: | @@ -127,6 +124,9 @@ custom_prompts: # Current summary: {{.PrevSummary}} # New lines of conversation: {{.MessagesJoined}} # New summary:` + # + # For Open Source models compatible with the OpenAI client, + # follow the prompt guidelines provided by the model owners. # # If left empty, the default OpenAI summary prompt from zep/pkg/extractors/prompts.go will be used. openai: | diff --git a/pkg/llms/embeddings_openai_test.go b/pkg/llms/embeddings_openai_test.go index e025bd41..66b408c6 100644 --- a/pkg/llms/embeddings_openai_test.go +++ b/pkg/llms/embeddings_openai_test.go @@ -21,7 +21,7 @@ func TestZepOpenAIEmbeddings_Init(t *testing.T) { zembeddings := &ZepOpenAIEmbeddingsClient{} err := zembeddings.Init(context.Background(), cfg) - TestOpenAIClient_Init(t, err, zembeddings.client, EmbeddingsClientType) + assertInit(t, err, zembeddings.client, EmbeddingsClientType) } func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { @@ -35,7 +35,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIAPIKeyTestCase) + assertConfigureClient(t, options, err, OpenAIAPIKeyTestCase) }) t.Run("Test with AzureOpenAIEmbeddingModel", func(t *testing.T) { @@ -50,7 +50,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) + assertConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) }) t.Run("Test with OpenAIEndpoint", func(t *testing.T) { @@ -62,7 +62,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIEndpointTestCase) + assertConfigureClient(t, options, err, OpenAIEndpointTestCase) }) t.Run("Test with OpenAIOrgID", func(t *testing.T) { @@ -74,7 +74,7 @@ func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) { } options, err := zembeddings.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIOrgIDTestCase) + assertConfigureClient(t, options, err, OpenAIOrgIDTestCase) }) } @@ -85,5 +85,5 @@ func TestZepOpenAIEmbeddings_EmbedTexts(t *testing.T) { assert.NoError(t, err, "Expected no error from NewOpenAIEmbeddingsClient") embeddings, err := zembeddings.EmbedTexts(context.Background(), EmbeddingsTestTexts) - TestOpenAIClient_EmbedText(t, embeddings, err) + assertEmbeddings(t, embeddings, err) } diff --git a/pkg/llms/llm_openai_test.go b/pkg/llms/llm_openai_test.go index b2d988ca..92cd4163 100644 --- a/pkg/llms/llm_openai_test.go +++ b/pkg/llms/llm_openai_test.go @@ -22,7 +22,7 @@ func TestZepOpenAILLM_Init(t *testing.T) { zllm := &ZepOpenAILLM{} err := zllm.Init(context.Background(), cfg) - TestOpenAIClient_Init(t, err, zllm.llm, LLMClientType) + assertInit(t, err, zllm.llm, LLMClientType) assert.NotNil(t, zllm.tkm, "Expected tkm to be initialized") } @@ -37,7 +37,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIAPIKeyTestCase) + assertConfigureClient(t, options, err, OpenAIAPIKeyTestCase) }) t.Run("Test with AzureOpenAIEndpoint", func(t *testing.T) { @@ -72,7 +72,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) + assertConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase) }) t.Run("Test with OpenAIEndpointAndCustomModelName", func(t *testing.T) { @@ -85,7 +85,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIEndpointTestCase) + assertConfigureClient(t, options, err, OpenAIEndpointTestCase) }) t.Run("Test with OpenAIOrgID", func(t *testing.T) { @@ -97,7 +97,7 @@ func TestZepOpenAILLM_TestConfigureClient(t *testing.T) { } options, err := zllm.configureClient(cfg) - TestOpenAIClient_ConfigureClient(t, options, err, OpenAIOrgIDTestCase) + assertConfigureClient(t, options, err, OpenAIOrgIDTestCase) }) } @@ -122,7 +122,7 @@ func TestZepOpenAILLM_EmbedTexts(t *testing.T) { assert.NoError(t, err, "Expected no error from NewOpenAILLM") embeddings, err := zllm.EmbedTexts(context.Background(), EmbeddingsTestTexts) - TestOpenAIClient_EmbedText(t, embeddings, err) + assertEmbeddings(t, embeddings, err) } func TestZepOpenAILLM_GetTokenCount(t *testing.T) { diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go index e5f13b35..8df7eb2b 100644 --- a/pkg/llms/openai_base.go +++ b/pkg/llms/openai_base.go @@ -127,18 +127,13 @@ func ConfigureOpenAIClientOptions(options []openai.Option, cfg *config.Config, c openAIEndpoint = cfg.LLM.OpenAIEndpoint openAIOrgID = cfg.LLM.OpenAIOrgID - options = append( - options, + options = applyOption(cfg.LLM.AzureOpenAIEndpoint != "", openai.WithAPIType(openai.APITypeAzure), openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint), ) - if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "" { - options = append( - options, - openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), - ) - } - + options = applyOption(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "", + openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), + ) } options = applyOption(openAIEndpoint != "", diff --git a/pkg/llms/openai_base_test.go b/pkg/llms/openai_base_test.go index 763bc5c3..763c2b61 100644 --- a/pkg/llms/openai_base_test.go +++ b/pkg/llms/openai_base_test.go @@ -19,7 +19,8 @@ const ( var EmbeddingsTestTexts = []string{"Hello, world!", "Another text"} -func TestOpenAIClient_Init(t *testing.T, err error, openAIClient *openai.Chat, clientType ClientType) { +func assertInit(t *testing.T, err error, openAIClient *openai.Chat, clientType ClientType) { + t.Helper() assert.NoError(t, err, "Expected no error from Init") switch clientType { case EmbeddingsClientType: @@ -29,7 +30,8 @@ func TestOpenAIClient_Init(t *testing.T, err error, openAIClient *openai.Chat, c } } -func TestOpenAIClient_ConfigureClient(t *testing.T, options []openai.Option, err error, testCase TestCaseType) { +func assertConfigureClient(t *testing.T, options []openai.Option, err error, testCase TestCaseType) { + t.Helper() assert.NoError(t, err, "Unexpected error") expectedOptions := map[TestCaseType]int{ OpenAIAPIKeyTestCase: 3, @@ -44,11 +46,12 @@ func TestOpenAIClient_ConfigureClient(t *testing.T, options []openai.Option, err } //? assert.Len(t, options, expected, "Unexpected number of options") if len(options) != expected { - t.Errorf("Expected %e options, got %d", expected, len(options)) + t.Errorf("Expected %d options, got %d", expected, len(options)) } } -func TestOpenAIClient_EmbedText(t *testing.T, embeddings [][]float32, err error) { +func assertEmbeddings(t *testing.T, embeddings [][]float32, err error) { + t.Helper() assert.NoError(t, err, "Expected no error from EmbedTexts") assert.Equal(t, len(EmbeddingsTestTexts), len(embeddings), "Expected embeddings to have same length as texts") assert.NotZero(t, embeddings[0], "Expected embeddings to be non-zero") From e59a4150a823a54346956ce683002a83492901cf Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Mon, 30 Oct 2023 16:47:18 -0600 Subject: [PATCH 5/8] readability --- pkg/llms/openai_base.go | 59 ++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go index 8df7eb2b..a0cd07e9 100644 --- a/pkg/llms/openai_base.go +++ b/pkg/llms/openai_base.go @@ -101,13 +101,6 @@ func GetBaseOpenAIClientOptions(apiKey, validModel string) []openai.Option { } func ConfigureOpenAIClientOptions(options []openai.Option, cfg *config.Config, clientType ClientType) []openai.Option { - applyOption := func(cond bool, opts ...openai.Option) []openai.Option { - if cond { - return append(options, opts...) - } - return options - } - var openAIEndpoint string var openAIOrgID string @@ -118,31 +111,47 @@ func ConfigureOpenAIClientOptions(options []openai.Option, cfg *config.Config, c // Check configuration for AzureOpenAIEndpoint; if it's set, use the DefaultAzureConfig // and provided endpoint Path. // WithEmbeddings is always required in case of embeddings client - options = applyOption(cfg.EmbeddingsClient.AzureOpenAIEndpoint != "", - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.EmbeddingsClient.AzureOpenAIEndpoint), - openai.WithEmbeddingModel(cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment), - ) + if cfg.EmbeddingsClient.AzureOpenAIEndpoint != "" { + options = append( + options, + openai.WithAPIType(openai.APITypeAzure), + openai.WithBaseURL(cfg.EmbeddingsClient.AzureOpenAIEndpoint), + openai.WithEmbeddingModel(cfg.EmbeddingsClient.AzureOpenAIModel.EmbeddingDeployment), + ) + } } else { openAIEndpoint = cfg.LLM.OpenAIEndpoint openAIOrgID = cfg.LLM.OpenAIOrgID - options = applyOption(cfg.LLM.AzureOpenAIEndpoint != "", - openai.WithAPIType(openai.APITypeAzure), - openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint), - ) - options = applyOption(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "", - openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), - ) + if cfg.LLM.AzureOpenAIEndpoint != "" { + options = append( + options, + openai.WithAPIType(openai.APITypeAzure), + openai.WithBaseURL(cfg.LLM.AzureOpenAIEndpoint), + ) + + if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment != "" { + options = append( + options, + openai.WithEmbeddingModel(cfg.LLM.AzureOpenAIModel.EmbeddingDeployment), + ) + } + } } - options = applyOption(openAIEndpoint != "", - openai.WithBaseURL(openAIEndpoint), - ) + if openAIEndpoint != "" { + options = append( + options, + openai.WithBaseURL(openAIEndpoint), + ) + } - options = applyOption(openAIOrgID != "", - openai.WithOrganization(openAIOrgID), - ) + if openAIOrgID != "" { + options = append( + options, + openai.WithBaseURL(openAIOrgID), + ) + } return options } From a117ce0cea08bc1c31f6552bc384d1f07da212cf Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Wed, 1 Nov 2023 20:00:54 -0600 Subject: [PATCH 6/8] clean --- cmd/zep/run.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cmd/zep/run.go b/cmd/zep/run.go index e77cd061..5f99ee32 100644 --- a/cmd/zep/run.go +++ b/cmd/zep/run.go @@ -62,19 +62,20 @@ func NewAppState(cfg *config.Config) *models.AppState { ctx := context.Background() // Create a new LLM client - llmClient, err := llms.NewLLMClient(ctx, cfg) - if err != nil { - log.Fatal(err) + llmClient, llmClientErr := llms.NewLLMClient(ctx, cfg) + if llmClientErr != nil { + log.Fatal(llmClientErr) } var embeddingsClient models.ZepEmbeddingsClient = nil + var embeddingsClientClientErr error = nil + // If enabled, create a new Embeddings client if cfg.EmbeddingsClient.Enabled { - embeddingsClient, err = llms.NewEmbeddingsClient(ctx, cfg) - } - - if err != nil { - log.Fatal(err) + embeddingsClient, embeddingsClientClientErr = llms.NewEmbeddingsClient(ctx, cfg) + if embeddingsClientClientErr != nil { + log.Fatal(embeddingsClientClientErr) + } } appState := &models.AppState{ From a2851b29f940fad2cf8b894fa1210a3c4f7fe246 Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Wed, 1 Nov 2023 20:20:08 -0600 Subject: [PATCH 7/8] fix (lint) --- cmd/zep/run.go | 7 ++++--- pkg/llms/embeddings_openai.go | 6 +++++- pkg/llms/llm_openai.go | 6 +++++- pkg/llms/openai_base.go | 8 -------- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/cmd/zep/run.go b/cmd/zep/run.go index 5f99ee32..e8580585 100644 --- a/cmd/zep/run.go +++ b/cmd/zep/run.go @@ -73,9 +73,10 @@ func NewAppState(cfg *config.Config) *models.AppState { // If enabled, create a new Embeddings client if cfg.EmbeddingsClient.Enabled { embeddingsClient, embeddingsClientClientErr = llms.NewEmbeddingsClient(ctx, cfg) - if embeddingsClientClientErr != nil { - log.Fatal(embeddingsClientClientErr) - } + } + + if embeddingsClientClientErr != nil { + log.Fatal(embeddingsClientClientErr) } appState := &models.AppState{ diff --git a/pkg/llms/embeddings_openai.go b/pkg/llms/embeddings_openai.go index 33152c8a..4f144e9a 100644 --- a/pkg/llms/embeddings_openai.go +++ b/pkg/llms/embeddings_openai.go @@ -34,7 +34,11 @@ func (zembeddings *ZepOpenAIEmbeddingsClient) Init(_ context.Context, cfg *confi // Create a new client instance with options. // Even if it will just used for embeddings, // it uses the same langchain openai chat client builder - client, err := NewOpenAIChatClient(options...) + client, err := openai.NewChat(options...) + if err != nil { + return err + } + zembeddings.client = client return nil diff --git a/pkg/llms/llm_openai.go b/pkg/llms/llm_openai.go index 8a9e5cea..ddc406da 100644 --- a/pkg/llms/llm_openai.go +++ b/pkg/llms/llm_openai.go @@ -46,7 +46,11 @@ func (zllm *ZepOpenAILLM) Init(_ context.Context, cfg *config.Config) error { } // Create a new client instance with options - llm, err := NewOpenAIChatClient(options...) + llm, err := openai.NewChat(options...) + if err != nil { + return err + } + zllm.llm = llm return nil diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go index a0cd07e9..8726b917 100644 --- a/pkg/llms/openai_base.go +++ b/pkg/llms/openai_base.go @@ -18,14 +18,6 @@ const ( LLMClientType ClientType = "llm" ) -func NewOpenAIChatClient(options ...openai.Option) (*openai.Chat, error) { - client, err := openai.NewChat(options...) - if err != nil { - return nil, err - } - return client, nil -} - func GetOpenAIAPIKey(cfg *config.Config, clientType ClientType) string { var apiKey string From e9caac611535d09566e45a9efdf806e987c1e99b Mon Sep 17 00:00:00 2001 From: Brice Macias Date: Thu, 9 Nov 2023 18:09:17 -0600 Subject: [PATCH 8/8] fix openai org id case --- pkg/llms/openai_base.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/llms/openai_base.go b/pkg/llms/openai_base.go index 8726b917..77971bd1 100644 --- a/pkg/llms/openai_base.go +++ b/pkg/llms/openai_base.go @@ -141,7 +141,7 @@ func ConfigureOpenAIClientOptions(options []openai.Option, cfg *config.Config, c if openAIOrgID != "" { options = append( options, - openai.WithBaseURL(openAIOrgID), + openai.WithOrganization(openAIOrgID), ) }