Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/Anthropic LLM support; API request improvements #163

Merged
merged 5 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
--health-timeout 5s
--health-retries 5
nlp:
image: ghcr.io/getzep/zep-nlp-server:0.3.0-beta.0
image: ghcr.io/getzep/zep-nlp-server:latest
env:
ENABLE_EMBEDDINGS: true
options: >-
Expand All @@ -52,5 +52,6 @@ jobs:
run: CGO_ENABLED=1 go test -tags=testutils -race -p 1 -v ./...
env:
ZEP_OPENAI_API_KEY: ${{ secrets.ZEP_OPENAI_API_KEY }}
ZEP_ANTHROPIC_API_KEY: ${{ secrets.ZEP_ANTHROPIC_API_KEY }}
ZEP_STORE_POSTGRES_DSN: 'postgres://postgres:postgres@postgres:5432/?sslmode=disable'
ZEP_NLP_SERVER_URL: 'http://nlp:5557'
2 changes: 1 addition & 1 deletion .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: v1.52.2
version: latest

# Optional: working directory, useful for monorepos
# working-directory: somedir
Expand Down
10 changes: 8 additions & 2 deletions cmd/zep/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,15 @@ func run() {
func NewAppState(cfg *config.Config) *models.AppState {
ctx := context.Background()

// Create a new LLM client
llmClient, err := llms.NewLLMClient(ctx, cfg)
if err != nil {
log.Fatal(err)
}

appState := &models.AppState{
OpenAIClient: llms.NewOpenAIRetryClient(cfg),
Config: cfg,
LLMClient: llmClient,
Config: cfg,
}

initializeStores(appState)
Expand Down
5 changes: 4 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
llm:
# gpt-3.5-turbo or gpt-4
# openai or anthropic
service: "openai"
# OpenAI: gpt-3.5-turbo, gpt-4, gpt-3.5-turbo-16k, gpt-4-32k; Anthropic: claude-instant-1 or claude-2
model: "gpt-3.5-turbo"
## OpenAI-specific settings
# Only used for Azure OpenAI API
azure_openai_endpoint:
# Use only with an alternate OpenAI-compatible API endpoint
Expand Down
5 changes: 3 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ var log = logrus.New()

// EnvVars is a set of secrets that should be stored in the environment, not config file
var EnvVars = map[string]string{
"llm.openai_api_key": "ZEP_OPENAI_API_KEY",
"auth.secret": "ZEP_AUTH_SECRET",
"llm.anthropic_api_key": "ZEP_ANTHROPIC_API_KEY",
"llm.openai_api_key": "ZEP_OPENAI_API_KEY",
"auth.secret": "ZEP_AUTH_SECRET",
}

// LoadConfig loads the config file and ENV variables into a Config struct
Expand Down
5 changes: 3 additions & 2 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ type StoreConfig struct {
}

type LLM struct {
Model string `mapstructure:"model"`
// OpenAIAPIKey is loaded from ENV not config file.
Service string `mapstructure:"service"`
Model string `mapstructure:"model"`
AnthropicAPIKey string `mapstructure:"anthropic_api_key"`
OpenAIAPIKey string `mapstructure:"openai_api_key"`
AzureOpenAIEndpoint string `mapstructure:"azure_openai_endpoint"`
OpenAIEndpoint string `mapstructure:"openai_endpoint"`
Expand Down
15 changes: 9 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
dario.cat/mergo v1.0.0
github.com/alitto/pond v1.8.3
github.com/avast/retry-go/v4 v4.3.4
github.com/brianvoe/gofakeit/v6 v6.23.0
github.com/chi-middleware/logrus-logger v0.2.0
github.com/go-chi/chi/v5 v5.0.8
github.com/go-chi/jwtauth/v5 v5.1.0
Expand All @@ -16,8 +17,8 @@ require (
github.com/joho/godotenv v1.5.1
github.com/oiime/logrusbun v0.1.1
github.com/pgvector/pgvector-go v0.1.1
github.com/pkoukk/tiktoken-go v0.1.1
github.com/sashabaranov/go-openai v1.11.3
github.com/pkoukk/tiktoken-go v0.1.5
github.com/sashabaranov/go-openai v1.14.2
github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.7.0
github.com/spf13/viper v1.15.0
Expand All @@ -27,15 +28,15 @@ require (
github.com/uptrace/bun/dialect/pgdialect v1.1.12
github.com/uptrace/bun/driver/pgdriver v1.1.12
gonum.org/v1/gonum v0.13.0
github.com/brianvoe/gofakeit/v6 v6.23.0
)

require github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537 // indirect

require (
github.com/KyleBanks/depth v1.2.1 // indirect

github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
Expand All @@ -45,6 +46,8 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.4 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgx/v5 v5.4.2 // indirect
Expand All @@ -64,7 +67,7 @@ require (
github.com/pelletier/go-toml/v2 v2.0.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.0 // indirect
Expand Down
19 changes: 19 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/alitto/pond v1.8.3 h1:ydIqygCLVPqIX/USe5EaV/aSRXTRXDEI9JwuDdu+/xs=
github.com/alitto/pond v1.8.3/go.mod h1:CmvIIGd5jKLasGI3D87qDkQxjzChdKMmnXMg3fG6M6Q=
github.com/avast/retry-go/v4 v4.3.4 h1:pHLkL7jvCvP317I8Ge+Km2Yhntv3SdkJm7uekkqbKhM=
Expand All @@ -68,6 +70,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etly
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI=
github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
Expand Down Expand Up @@ -170,6 +174,11 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA=
github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
Expand Down Expand Up @@ -239,16 +248,24 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.14.2 h1:5DPTtR9JBjKPJS008/A409I5ntFhUPPGCmaAihcPRyo=
github.com/sashabaranov/go-openai v1.14.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
Expand Down Expand Up @@ -287,6 +304,8 @@ github.com/sv-tools/openapi v0.2.1 h1:ES1tMQMJFGibWndMagvdoo34T1Vllxr1Nlm5wz6b1a
github.com/sv-tools/openapi v0.2.1/go.mod h1:k5VuZamTw1HuiS9p2Wl5YIDWzYnHG6/FgPOSFXLAhGg=
github.com/swaggo/swag/v2 v2.0.0-rc3 h1:cIkbddJ9ftgRenDaDzyvg+2TUDLFCDffZ40yZE1r0vU=
github.com/swaggo/swag/v2 v2.0.0-rc3/go.mod h1:mfTZJmxpXWA3JQ9V381+cRlutUCo7OXd/VyIRcMhByc=
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537 h1:vkeNjlW+0Xiw2XizMHoQuLG8pg6AN1hU8zJuMV9GQBc=
github.com/tmc/langchaingo v0.0.0-20230811231558-fd8b7f099537/go.mod h1:8T+nNIGBr3nYQEYFmF/YaT8t8YTKLvFYZBuVZOAYn5E=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v0.3.9/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
Expand Down
10 changes: 10 additions & 0 deletions internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ func StructToMap(item interface{}) map[string]interface{} {

return out
}

func MergeMaps[T any](maps ...map[string]T) map[string]T {
result := make(map[string]T)
for _, m := range maps {
for k, v := range m {
result[k] = v
}
}
return result
}
21 changes: 21 additions & 0 deletions internal/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,27 @@ type testData struct {
Name string
}

func TestMergeMaps(t *testing.T) {
map1 := map[string]int{"one": 1, "two": 2}
map2 := map[string]int{"three": 3, "four": 4}
map3 := map[string]int{"five": 5, "six": 6}

expected := map[string]int{
"one": 1,
"two": 2,
"three": 3,
"four": 4,
"five": 5,
"six": 6,
}

result := MergeMaps(map1, map2, map3)

if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, but got %v", expected, result)
}
}

func TestParsePrompt(t *testing.T) {
testCases := []struct {
name string
Expand Down
2 changes: 1 addition & 1 deletion pkg/extractors/doc_embedding_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (ep *DocEmbeddingProcessor) Run(
ctx context.Context,
) error {
ep.documentType = "document"
model, err := llms.GetMessageEmbeddingModel(ep.appState, ep.documentType)
model, err := llms.GetEmbeddingModel(ep.appState, ep.documentType)
if err != nil {
return fmt.Errorf("failed to get embedding model: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/extractors/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (ee *EmbeddingExtractor) Extract(

texts := messageToStringSlice(messageEvent.Messages, false)

model, err := llms.GetMessageEmbeddingModel(appState, messageType)
model, err := llms.GetEmbeddingModel(appState, messageType)
if err != nil {
return NewExtractorError("EmbeddingExtractor get message embedding model failed", err)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/extractors/embedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ import (
"github.com/stretchr/testify/assert"
)

func TestEmbeddingExtractor_Extract(t *testing.T) {
func TestEmbeddingExtractor_Extract_OpenAI(t *testing.T) {
appState.Config.LLM.Service = "openai"
appState.Config.LLM.Model = "gpt-3.5-turbo"
llmClient, err := llms.NewOpenAILLM(testCtx, appState.Config)
assert.NoError(t, err)
appState.LLMClient = llmClient

store := appState.MemoryStore

documentType := "message"
Expand Down
8 changes: 7 additions & 1 deletion pkg/extractors/initialize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ func setup() {

appState = &models.AppState{}
cfg := testutils.NewTestConfig()
appState.OpenAIClient = llms.NewOpenAIRetryClient(cfg)

llmClient, err := llms.NewLLMClient(context.Background(), cfg)
if err != nil {
panic(err)
}

appState.LLMClient = llmClient
appState.Config = cfg
appState.Config.Store.Postgres.DSN = testutils.GetDSN()

Expand Down
17 changes: 13 additions & 4 deletions pkg/extractors/intent_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ package extractors
import (
"context"
"errors"
"regexp"
"strings"
"sync"

"github.com/tmc/langchaingo/llms"

"github.com/getzep/zep/internal"
"github.com/getzep/zep/pkg/llms"
"github.com/getzep/zep/pkg/models"
)

var _ models.Extractor = &IntentExtractor{}

const intentMaxTokens = 512

var IntentStringRegex = regexp.MustCompile(`(?i)^\s*intent\W+\s+`)

type IntentExtractor struct {
BaseExtractor
}
Expand Down Expand Up @@ -90,15 +94,20 @@ func (ee *IntentExtractor) processMessage(
}

// Send the populated prompt to the language model
resp, err := llms.RunChatCompletion(ctx, appState, intentMaxTokens, prompt)
intentContent, err := appState.LLMClient.Call(
ctx,
prompt,
llms.WithMaxTokens(intentMaxTokens),
)
if err != nil {
errs <- NewExtractorError("IntentExtractor: "+err.Error(), err)
return
}

// Get the intent from the response
intentContent := resp.Choices[0].Message.Content
intentContent = strings.TrimPrefix(intentContent, "Intent: ")
intentContent = IntentStringRegex.ReplaceAllStringFunc(intentContent, func(s string) string {
return ""
})

// if we don't have an intent, just return
if intentContent == "" {
Expand Down
Loading