Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable separate Embeddings Client #258

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions cmd/zep/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,27 @@ func NewAppState(cfg *config.Config) *models.AppState {
ctx := context.Background()

// Create a new LLM client
llmClient, err := llms.NewLLMClient(ctx, cfg)
if err != nil {
log.Fatal(err)
llmClient, llmClientErr := llms.NewLLMClient(ctx, cfg)
if llmClientErr != nil {
log.Fatal(llmClientErr)
}

var embeddingsClient models.ZepEmbeddingsClient = nil
var embeddingsClientClientErr error = nil

// If enabled, create a new Embeddings client
if cfg.EmbeddingsClient.Enabled {
embeddingsClient, embeddingsClientClientErr = llms.NewEmbeddingsClient(ctx, cfg)
}

if embeddingsClientClientErr != nil {
log.Fatal(embeddingsClientClientErr)
}

appState := &models.AppState{
LLMClient: llmClient,
Config: cfg,
LLMClient: llmClient,
EmbeddingsClient: embeddingsClient,
Config: cfg,
}

initializeStores(ctx, appState)
Expand Down
15 changes: 15 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ llm:
# Use only with an alternate OpenAI-compatible API endpoint
openai_endpoint:
openai_org_id:
embeddings_client:
# Enable this to use a client distinct from your LLM client for embeddings (when not using local embeddings).
# Currently, OpenAI is the only supported external embeddings service.
enabled: false
# When enabled, ensure the ZEP_EMBEDDINGS_OPENAI_API_KEY environment variable is set
service: "openai"
azure_openai_endpoint:
azure_openai:
# embeddings deployment is required when using azure deployment for embeddings
embedding_deployment: "text-embedding-ada-002-customname"
openai_endpoint:
openai_org_id:
nlp:
server_url: "http://localhost:5557"
memory:
Expand Down Expand Up @@ -112,6 +124,9 @@ custom_prompts:
# Current summary: {{.PrevSummary}}
# New lines of conversation: {{.MessagesJoined}}
# New summary:`
#
# For Open Source models compatible with the OpenAI client,
# follow the prompt guidelines provided by the model owners.
#
# If left empty, the default OpenAI summary prompt from zep/pkg/extractors/prompts.go will be used.
openai: |
9 changes: 5 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ var log = logrus.New()

// EnvVars is a set of secrets that should be stored in the environment, not config file
var EnvVars = map[string]string{
"llm.anthropic_api_key": "ZEP_ANTHROPIC_API_KEY",
"llm.openai_api_key": "ZEP_OPENAI_API_KEY",
"auth.secret": "ZEP_AUTH_SECRET",
"development": "ZEP_DEVELOPMENT",
"llm.anthropic_api_key": "ZEP_ANTHROPIC_API_KEY",
"llm.openai_api_key": "ZEP_OPENAI_API_KEY",
"embeddings_client.openai_api_key": "ZEP_EMBEDDINGS_OPENAI_API_KEY",
"auth.secret": "ZEP_AUTH_SECRET",
"development": "ZEP_DEVELOPMENT",
}

// LoadConfig loads the config file and ENV variables into a Config struct
Expand Down
33 changes: 22 additions & 11 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ package config
// Config holds the configuration of the application
// Use cmd.NewConfig to create a new instance
type Config struct {
LLM LLM `mapstructure:"llm"`
NLP NLP `mapstructure:"nlp"`
Memory MemoryConfig `mapstructure:"memory"`
Extractors ExtractorsConfig `mapstructure:"extractors"`
Store StoreConfig `mapstructure:"store"`
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
Auth AuthConfig `mapstructure:"auth"`
DataConfig DataConfig `mapstructure:"data"`
Development bool `mapstructure:"development"`
CustomPrompts CustomPromptsConfig `mapstructure:"custom_prompts"`
LLM LLM `mapstructure:"llm"`
EmbeddingsClient EmbeddingsClient `mapstructure:"embeddings_client"`
NLP NLP `mapstructure:"nlp"`
Memory MemoryConfig `mapstructure:"memory"`
Extractors ExtractorsConfig `mapstructure:"extractors"`
Store StoreConfig `mapstructure:"store"`
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
Auth AuthConfig `mapstructure:"auth"`
DataConfig DataConfig `mapstructure:"data"`
Development bool `mapstructure:"development"`
CustomPrompts CustomPromptsConfig `mapstructure:"custom_prompts"`
}

type StoreConfig struct {
Expand All @@ -32,6 +33,16 @@ type LLM struct {
OpenAIOrgID string `mapstructure:"openai_org_id"`
}

// EmbeddingsClient holds the configuration for a dedicated embeddings client,
// used when embeddings should be served by a separate provider/account than
// the main LLM client.
type EmbeddingsClient struct {
	// Enabled toggles use of the separate embeddings client.
	Enabled bool `mapstructure:"enabled"`
	// Service selects the embeddings provider; only "openai" is accepted
	// by NewEmbeddingsClient.
	Service string `mapstructure:"service"`
	// OpenAIAPIKey is sourced from the ZEP_EMBEDDINGS_OPENAI_API_KEY
	// environment variable (see config.EnvVars).
	OpenAIAPIKey string `mapstructure:"openai_api_key"`
	// AzureOpenAIEndpoint, when set, routes embeddings through Azure OpenAI.
	AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"`
	// AzureOpenAIModel carries Azure deployment names; EmbeddingDeployment
	// is required when AzureOpenAIEndpoint is set.
	AzureOpenAIModel AzureOpenAIConfig `mapstructure:"azure_openai"`
	// OpenAIEndpoint optionally points at an OpenAI-compatible API endpoint.
	OpenAIEndpoint string `mapstructure:"openai_endpoint"`
	// OpenAIOrgID optionally scopes requests to an OpenAI organization.
	OpenAIOrgID string `mapstructure:"openai_org_id"`
}

type AzureOpenAIConfig struct {
LLMDeployment string `mapstructure:"llm_deployment"`
EmbeddingDeployment string `mapstructure:"embedding_deployment"`
Expand Down
13 changes: 11 additions & 2 deletions pkg/llms/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,23 @@ func EmbedTexts(
return nil, errors.New("no text to embed")
}

if appState.LLMClient == nil {
llmClient := appState.LLMClient

if llmClient == nil {
return nil, errors.New(InvalidLLMModelError)
}

if model.Service == "local" {
return embedTextsLocal(ctx, appState, documentType, text)
}
return appState.LLMClient.EmbedTexts(ctx, text)

embeddingsClient := appState.EmbeddingsClient

if embeddingsClient != nil {
return embeddingsClient.EmbedTexts(ctx, text)
}

return llmClient.EmbedTexts(ctx, text)
}

func GetEmbeddingModel(
Expand Down
42 changes: 42 additions & 0 deletions pkg/llms/embeddings_base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package llms

import (
"context"
"fmt"

"github.com/getzep/zep/pkg/models"

"github.com/getzep/zep/config"
)

// InvalidEmbeddingsClientError is the message used when the embeddings client
// is missing or misconfigured.
const InvalidEmbeddingsClientError = "embeddings client is not set or is invalid"

// EmbeddingsClientError wraps a lower-level error with an
// embeddings-client-specific message.
type EmbeddingsClientError struct {
	message       string
	originalError error
}

// Error implements the error interface, combining the client message with the
// wrapped cause.
func (e *EmbeddingsClientError) Error() string {
	return fmt.Sprintf("embeddings client error: %s (original error: %v)", e.message, e.originalError)
}

// Unwrap returns the wrapped cause so callers can inspect it with
// errors.Is / errors.As.
func (e *EmbeddingsClientError) Unwrap() error {
	return e.originalError
}

// NewEmbeddingsClientError wraps originalError with an embeddings-client
// message.
func NewEmbeddingsClientError(message string, originalError error) *EmbeddingsClientError {
	return &EmbeddingsClientError{message: message, originalError: originalError}
}

// NewEmbeddingsClient constructs the embeddings client selected by
// cfg.EmbeddingsClient.Service. Only the "openai" service is currently
// supported; any other value produces an error.
func NewEmbeddingsClient(ctx context.Context, cfg *config.Config) (models.ZepEmbeddingsClient, error) {
	ec := cfg.EmbeddingsClient

	// For now we only support OpenAI embeddings.
	if ec.Service != "openai" {
		return nil, fmt.Errorf("invalid embeddings service: %s", ec.Service)
	}

	// EmbeddingsDeployment is required if using external embeddings with AzureOpenAI.
	if ec.AzureOpenAIEndpoint != "" && ec.AzureOpenAIModel.EmbeddingDeployment == "" {
		return nil, InvalidEmbeddingsDeploymentError(ec.Service)
	}

	// The logic is the same whether or not a custom OpenAI endpoint is set,
	// since the model name will be set automatically in this case.
	return NewOpenAIEmbeddingsClient(ctx, cfg)
}
73 changes: 73 additions & 0 deletions pkg/llms/embeddings_openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package llms

import (
"context"

"github.com/getzep/zep/config"
"github.com/getzep/zep/pkg/models"
"github.com/tmc/langchaingo/llms/openai"
)

// EmbeddingsOpenAIAPIKeyNotSetError is the message used when the dedicated
// embeddings OpenAI API key is missing from the environment.
const EmbeddingsOpenAIAPIKeyNotSetError = "ZEP_EMBEDDINGS_OPENAI_API_KEY is not set" //nolint:gosec

// Compile-time check that ZepOpenAIEmbeddingsClient implements
// models.ZepEmbeddingsClient.
var _ models.ZepEmbeddingsClient = &ZepOpenAIEmbeddingsClient{}

// NewOpenAIEmbeddingsClient creates a ZepOpenAIEmbeddingsClient and
// initializes its underlying OpenAI client from cfg.
func NewOpenAIEmbeddingsClient(ctx context.Context, cfg *config.Config) (*ZepOpenAIEmbeddingsClient, error) {
	client := new(ZepOpenAIEmbeddingsClient)
	if err := client.Init(ctx, cfg); err != nil {
		return nil, err
	}
	return client, nil
}

// ZepOpenAIEmbeddingsClient is an OpenAI-backed implementation of
// models.ZepEmbeddingsClient, dedicated to generating embeddings.
type ZepOpenAIEmbeddingsClient struct {
	// client is the underlying langchaingo OpenAI chat client, reused here
	// for its embeddings capability.
	client *openai.Chat
}

// Init builds the OpenAI client options from cfg and instantiates the
// underlying langchaingo chat client. The context argument is currently
// unused.
func (zembeddings *ZepOpenAIEmbeddingsClient) Init(_ context.Context, cfg *config.Config) error {
	opts, err := zembeddings.configureClient(cfg)
	if err != nil {
		return err
	}

	// Even though this client is only used for embeddings, it is built with
	// the same langchain OpenAI chat client constructor.
	chatClient, err := openai.NewChat(opts...)
	if err != nil {
		return err
	}
	zembeddings.client = chatClient

	return nil
}

// EmbedTexts embeds the given texts via the shared OpenAI embeddings helper,
// using this client's underlying OpenAI connection.
func (zembeddings *ZepOpenAIEmbeddingsClient) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
	return EmbedTextsWithOpenAIClient(ctx, texts, zembeddings.client, EmbeddingsClientType)
}

// getValidOpenAIModel returns the name of a valid OpenAI LLM model. The client
// builder requires one even when the client is only used for embeddings.
//
// Go map iteration order is randomized, so returning the first key ranged over
// ValidOpenAILLMs would yield a different model name on different runs. Pick
// deterministically instead: the preferred default if it is valid, otherwise
// the lexicographically smallest valid model name.
func getValidOpenAIModel() string {
	const preferred = "gpt-3.5-turbo"
	if _, ok := ValidOpenAILLMs[preferred]; ok {
		return preferred
	}
	model := ""
	for name := range ValidOpenAILLMs {
		if model == "" || name < model {
			model = name
		}
	}
	if model == "" {
		// ValidOpenAILLMs is empty; fall back to the known-good default.
		return preferred
	}
	return model
}

// configureClient assembles the OpenAI client options (API key, model,
// endpoint, organization, Azure settings) for the embeddings client.
func (zembeddings *ZepOpenAIEmbeddingsClient) configureClient(cfg *config.Config) ([]openai.Option, error) {
	// Retrieve the OpenAIAPIKey from configuration.
	apiKey := GetOpenAIAPIKey(cfg, EmbeddingsClientType)

	ValidateOpenAIConfig(cfg, EmbeddingsClientType)

	// Even if this client is only used for embeddings, a valid OpenAI LLM
	// model name must be supplied to avoid errors from the client builder.
	opts := GetBaseOpenAIClientOptions(apiKey, getValidOpenAIModel())
	opts = ConfigureOpenAIClientOptions(opts, cfg, EmbeddingsClientType)

	return opts, nil
}
89 changes: 89 additions & 0 deletions pkg/llms/embeddings_openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package llms

import (
"context"
"testing"

"github.com/getzep/zep/pkg/testutils"

"github.com/stretchr/testify/assert"

"github.com/getzep/zep/config"
)

func TestZepOpenAIEmbeddings_Init(t *testing.T) {
cfg := &config.Config{
EmbeddingsClient: config.EmbeddingsClient{
OpenAIAPIKey: "test-key",
},
}

zembeddings := &ZepOpenAIEmbeddingsClient{}

err := zembeddings.Init(context.Background(), cfg)
assertInit(t, err, zembeddings.client, EmbeddingsClientType)
}

// TestZepOpenAIEmbeddings_TestConfigureClient checks that configureClient
// produces the expected option set for each supported configuration shape:
// plain API key, Azure OpenAI with an embedding deployment, a custom
// OpenAI-compatible endpoint, and an organization ID.
func TestZepOpenAIEmbeddings_TestConfigureClient(t *testing.T) {
	zembeddings := &ZepOpenAIEmbeddingsClient{}

	// Minimal configuration: API key only.
	t.Run("Test with OpenAIAPIKey", func(t *testing.T) {
		cfg := &config.Config{
			EmbeddingsClient: config.EmbeddingsClient{
				OpenAIAPIKey: "test-key",
			},
		}

		options, err := zembeddings.configureClient(cfg)
		assertConfigureClient(t, options, err, OpenAIAPIKeyTestCase)
	})

	// Azure OpenAI: endpoint plus the required embedding deployment name.
	t.Run("Test with AzureOpenAIEmbeddingModel", func(t *testing.T) {
		cfg := &config.Config{
			EmbeddingsClient: config.EmbeddingsClient{
				OpenAIAPIKey:        "test-key",
				AzureOpenAIEndpoint: "https://azure.openai.com",
				AzureOpenAIModel: config.AzureOpenAIConfig{
					EmbeddingDeployment: "test-embedding-deployment",
				},
			},
		}

		options, err := zembeddings.configureClient(cfg)
		assertConfigureClient(t, options, err, AzureOpenAIEmbeddingModelTestCase)
	})

	// Custom OpenAI-compatible endpoint.
	t.Run("Test with OpenAIEndpoint", func(t *testing.T) {
		cfg := &config.Config{
			EmbeddingsClient: config.EmbeddingsClient{
				OpenAIAPIKey:   "test-key",
				OpenAIEndpoint: "https://openai.com",
			},
		}

		options, err := zembeddings.configureClient(cfg)
		assertConfigureClient(t, options, err, OpenAIEndpointTestCase)
	})

	// OpenAI organization ID.
	t.Run("Test with OpenAIOrgID", func(t *testing.T) {
		cfg := &config.Config{
			EmbeddingsClient: config.EmbeddingsClient{
				OpenAIAPIKey: "test-key",
				OpenAIOrgID:  "org-id",
			},
		}

		options, err := zembeddings.configureClient(cfg)
		assertConfigureClient(t, options, err, OpenAIOrgIDTestCase)
	})
}

// TestZepOpenAIEmbeddings_EmbedTexts exercises an embeddings round trip using
// the shared test configuration.
func TestZepOpenAIEmbeddings_EmbedTexts(t *testing.T) {
	ctx := context.Background()
	cfg := testutils.NewTestConfig()

	zembeddings, err := NewOpenAIEmbeddingsClient(ctx, cfg)
	assert.NoError(t, err, "Expected no error from NewOpenAIEmbeddingsClient")

	embeddings, err := zembeddings.EmbedTexts(ctx, EmbeddingsTestTexts)
	assertEmbeddings(t, embeddings, err)
}
10 changes: 6 additions & 4 deletions pkg/llms/llm_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
const DefaultTemperature = 0.0
const InvalidLLMModelError = "llm model is not set or is invalid"

// InvalidEmbeddingsDeploymentError builds the error returned when an Azure
// OpenAI endpoint is configured for embeddings but no embedding deployment
// name is set for the given service.
var InvalidEmbeddingsDeploymentError = func(service string) error {
	return fmt.Errorf("invalid embeddings deployment for %s, deployment name is required", service)
}

var log = internal.GetLogger()

func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error) {
Expand All @@ -43,10 +47,8 @@ func NewLLMClient(ctx context.Context, cfg *config.Config) (models.ZepLLM, error
// EmbeddingsDeployment is only required if Zep is also configured to use
// OpenAI embeddings for document or message extractors
if cfg.LLM.AzureOpenAIModel.EmbeddingDeployment == "" && useOpenAIEmbeddings(cfg) {
return nil, fmt.Errorf(
"invalid embeddings deployment for %s, deployment name is required",
cfg.LLM.Service,
)
err := InvalidEmbeddingsDeploymentError(cfg.EmbeddingsClient.Service)
return nil, err
}
return NewOpenAILLM(ctx, cfg)
}
Expand Down
Loading
Loading