From 6b945459ffb0fb3dd1b6e58beb599b2b03049840 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Sat, 13 May 2023 07:49:49 -0700 Subject: [PATCH] refactor embedding extractor and putMessages code (#38) Simplify putMessages by removing embedding code. Simplify embedding extractor by using messages provided in MessageEvent. --- pkg/extractors/embedder.go | 35 +++++----------- pkg/extractors/embedder_test.go | 13 +++--- pkg/memorystore/postgres.go | 72 ++++---------------------------- pkg/memorystore/postgres_test.go | 58 +++++-------------------- pkg/models/memorystore.go | 6 +-- 5 files changed, 38 insertions(+), 146 deletions(-) diff --git a/pkg/extractors/embedder.go b/pkg/extractors/embedder.go index c7aff161..c0292d32 100644 --- a/pkg/extractors/embedder.go +++ b/pkg/extractors/embedder.go @@ -26,31 +26,17 @@ func (ee *EmbeddingExtractor) Extract( sessionMutex.Lock() defer sessionMutex.Unlock() - unembeddedMessages, err := appState.MemoryStore.GetMessageVectors( - ctx, - appState, - messageEvent.SessionID, - false, - ) - if err != nil { - return NewExtractorError("EmbeddingExtractor get message vectors failed", err) - } - - if len(unembeddedMessages) == 0 { - return nil - } - - texts := embeddingsToTextSlice(unembeddedMessages, false) + texts := messageToStringSlice(messageEvent.Messages, false) embeddings, err := llms.EmbedMessages(ctx, appState, texts) if err != nil { return NewExtractorError("EmbeddingExtractor embed messages failed", err) } - embeddingRecords := make([]models.Embeddings, len(unembeddedMessages)) - for i, r := range unembeddedMessages { + embeddingRecords := make([]models.Embeddings, len(messageEvent.Messages)) + for i, r := range messageEvent.Messages { embeddingRecords[i] = models.Embeddings{ - TextUUID: r.TextUUID, + TextUUID: r.UUID, Embedding: (*embeddings)[i].Embedding, } } @@ -59,7 +45,6 @@ func (ee *EmbeddingExtractor) Extract( appState, messageEvent.SessionID, embeddingRecords, - true, ) 
if err != nil { return NewExtractorError("EmbeddingExtractor put message vectors failed", err) @@ -67,27 +52,27 @@ func (ee *EmbeddingExtractor) Extract( return nil } -// embeddingsToTextSlice converts a slice of Embeddings to a slice of strings. +// messageToStringSlice converts a slice of Messages to a slice of strings. // If enrich is true, the text slice will include the prior and subsequent // messages text to the slice item. -func embeddingsToTextSlice(messages []models.Embeddings, enrich bool) []string { +func messageToStringSlice(messages []models.Message, enrich bool) []string { texts := make([]string, len(messages)) for i, r := range messages { if !enrich { - texts[i] = r.Text + texts[i] = r.Content continue } var parts []string if i > 0 { - parts = append(parts, messages[i-1].Text) + parts = append(parts, messages[i-1].Content) } - parts = append(parts, r.Text) + parts = append(parts, r.Content) if i < len(messages)-1 { - parts = append(parts, messages[i+1].Text) + parts = append(parts, messages[i+1].Content) } texts[i] = strings.Join(parts, " ") diff --git a/pkg/extractors/embedder_test.go b/pkg/extractors/embedder_test.go index bfcce158..a43bc142 100644 --- a/pkg/extractors/embedder_test.go +++ b/pkg/extractors/embedder_test.go @@ -40,18 +40,20 @@ func TestEmbeddingExtractor_Extract(t *testing.T) { assert.NoError(t, err) // Get messages that are missing embeddings using appState.MemoryStore.GetMessageVectors - unembeddedMessages, err := store.GetMessageVectors(ctx, appState, sessionID, false) + memories, err := store.GetMemory(ctx, appState, sessionID, 0) assert.NoError(t, err) - assert.True(t, len(unembeddedMessages) == len(testMessages)) + assert.True(t, len(memories.Messages) == len(testMessages)) + unembeddedMessages := memories.Messages // Create messageEvent.
We only need to pass the sessionID messageEvent := &models.MessageEvent{ SessionID: sessionID, + Messages: unembeddedMessages, } texts := make([]string, len(unembeddedMessages)) for i, r := range unembeddedMessages { - texts[i] = r.Text + texts[i] = r.Content } embeddings, err := llms.EmbedMessages(ctx, appState, texts) @@ -60,8 +62,8 @@ func TestEmbeddingExtractor_Extract(t *testing.T) { expectedEmbeddingRecords := make([]models.Embeddings, len(unembeddedMessages)) for i, r := range unembeddedMessages { expectedEmbeddingRecords[i] = models.Embeddings{ - TextUUID: r.TextUUID, - Text: r.Text, + TextUUID: r.UUID, + Text: r.Content, Embedding: (*embeddings)[i].Embedding, } } @@ -74,7 +76,6 @@ func TestEmbeddingExtractor_Extract(t *testing.T) { ctx, appState, messageEvent.SessionID, - true, ) assert.NoError(t, err) diff --git a/pkg/memorystore/postgres.go b/pkg/memorystore/postgres.go index a7e3e40f..78c86ae8 100644 --- a/pkg/memorystore/postgres.go +++ b/pkg/memorystore/postgres.go @@ -272,7 +272,6 @@ func (pms *PostgresMemoryStore) PutMemory( ctx, pms.Client, sessionID, - appState.Config.Extractors.Embeddings.Enabled, memoryMessages.Messages, ) if err != nil { @@ -386,7 +385,6 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context, _ *models.AppState, sessionID string, embeddings []models.Embeddings, - isEmbedded bool, ) error { if embeddings == nil { return NewStorageError("nil embeddings received", nil) @@ -395,7 +393,7 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context, return NewStorageError("no embeddings received", nil) } - err := putEmbeddings(ctx, pms.Client, sessionID, embeddings, isEmbedded) + err := putEmbeddings(ctx, pms.Client, sessionID, embeddings) if err != nil { return NewStorageError("failed to put embeddings", err) } @@ -406,9 +404,8 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context, func (pms *PostgresMemoryStore) GetMessageVectors(ctx context.Context, _ *models.AppState, sessionID 
string, - isEmbedded bool, ) ([]models.Embeddings, error) { - embeddings, err := getMessageVectors(ctx, pms.Client, sessionID, isEmbedded) + embeddings, err := getMessageVectors(ctx, pms.Client, sessionID) if err != nil { return nil, NewStorageError("GetMessageVectors failed to get embeddings", err) } @@ -418,8 +415,7 @@ func (pms *PostgresMemoryStore) GetMessageVectors(ctx context.Context, func getMessageVectors(ctx context.Context, db *bun.DB, - sessionID string, - isEmbedded bool) ([]models.Embeddings, error) { + sessionID string) ([]models.Embeddings, error) { var results []struct { PgMessageStore PgMessageVectorStore @@ -431,7 +427,7 @@ func getMessageVectors(ctx context.Context, JoinOn("message_embedding.message_uuid = message.uuid"). ColumnExpr("message.content"). ColumnExpr("message_embedding.*"). - Where("message_embedding.is_embedded = ?", isEmbedded). + Where("message_embedding.is_embedded = ?", true). Where("message_embedding.session_id = ?", sessionID). Exec(ctx, &results) if err != nil { @@ -510,7 +506,6 @@ func putMessages( ctx context.Context, db *bun.DB, sessionID string, - embeddingEnabled bool, messages []models.Message, ) ([]models.Message, error) { if len(messages) == 0 { @@ -534,54 +529,11 @@ func putMessages( pgMessages[i].SessionID = sessionID } - // wrap in a transaction, so we can roll back if any of the inserts fail. We - // don't want to partially save messages without vectorstore records. - tx, err := db.Begin() - if err != nil { - return nil, err - } - defer func(tx bun.Tx) { - _ = tx.Rollback() - }(tx) - - _, err = tx.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx) + _, err = db.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx) if err != nil { return nil, NewStorageError("failed to save memories to store", err) } - // If embeddings are enabled, store the new messages for future embedding. - // The Embedded field will be false until we run the embedding extractor out of band. 
- if embeddingEnabled { - zeroVector := make( - []float32, - 1536, - ) // TODO: use config. will need to drill appState down to here - - // Extract new messages. We use the original messages slice to filter - // out messages that already have UUIDs. - var embedRecords []PgMessageVectorStore - for i, msg := range messages { - if msg.UUID == uuid.Nil { - e := PgMessageVectorStore{ - SessionID: sessionID, - MessageUUID: pgMessages[i].UUID, - Embedding: pgvector.NewVector(zeroVector), // Vector fields can't be null - } - embedRecords = append(embedRecords, e) - } - } - if len(embedRecords) > 0 { - _, err = tx.NewInsert().Model(&embedRecords).Exec(ctx) - if err != nil { - return nil, NewStorageError("failed to save memory vector records", err) - } - } - } - - if err := tx.Commit(); err != nil { - return nil, err - } - retMessages := make([]models.Message, len(messages)) err = copier.Copy(&retMessages, &pgMessages) if err != nil { @@ -778,7 +730,6 @@ func putEmbeddings( db *bun.DB, sessionID string, embeddings []models.Embeddings, - isEmbedded bool, ) error { if embeddings == nil { return NewStorageError("nil embeddings received", nil) @@ -793,19 +744,14 @@ func putEmbeddings( SessionID: sessionID, Embedding: pgvector.NewVector(e.Embedding), MessageUUID: e.TextUUID, + IsEmbedded: true, } } - values := db.NewValues(&embeddingVectors) - _, err := db.NewUpdate(). - With("_data", values). - Model((*PgMessageVectorStore)(nil)). - TableExpr("_data"). - Set("embedding = _data.embedding"). - Set("is_embedded = ?", isEmbedded). - Where("me.message_uuid = _data.message_uuid"). - OmitZero(). + _, err := db.NewInsert(). + Model(&embeddingVectors). 
Exec(ctx) + if err != nil { return NewStorageError("failed to insert message vectors", err) } diff --git a/pkg/memorystore/postgres_test.go b/pkg/memorystore/postgres_test.go index a52f25d1..6e7892d9 100644 --- a/pkg/memorystore/postgres_test.go +++ b/pkg/memorystore/postgres_test.go @@ -217,7 +217,7 @@ func TestPgDeleteSession(t *testing.T) { } // Call putMessages function - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") // Put a summary @@ -275,7 +275,7 @@ func TestPutMessages(t *testing.T) { viper.Set("extractor.embeddings.enabled", true) t.Run("insert messages", func(t *testing.T) { - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") // Verify the inserted messages in the database @@ -292,7 +292,7 @@ func TestPutMessages(t *testing.T) { } // Call putMessages function to upsert the messages - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") // Verify the upserted messages in the database @@ -374,7 +374,7 @@ func TestGetMessages(t *testing.T) { _, err = putSession(testCtx, testDB, sessionID, metadata) assert.NoError(t, err) - messages, err := putMessages(testCtx, testDB, sessionID, true, test.TestMessages) + messages, err := putMessages(testCtx, testDB, sessionID, test.TestMessages) assert.NoError(t, err) // Explicitly set the message window to 10 @@ -458,44 +458,6 @@ func TestGetMessages(t *testing.T) { } } -func TestGetMessageVectorsWhereIsEmbeddedFalse(t *testing.T) { - // Create a test session - sessionID, err := test.GenerateRandomSessionID(16) - 
assert.NoError(t, err, "GenerateRandomSessionID should not return an error") - metadata := map[string]interface{}{ - "key": "value", - } - _, err = putSession(testCtx, testDB, sessionID, metadata) - assert.NoError(t, err) - - messages := []models.Message{ - { - Role: "user", - Content: "Hello", - Metadata: map[string]interface{}{"timestamp": "1629462540"}, - }, - { - Role: "bot", - Content: "Hi there!", - Metadata: map[string]interface{}{"something": "good"}, - }, - } - - addedMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) - assert.NoError(t, err) - - // getMessageVectors only for isEmbedded = false - embeddings, err := getMessageVectors(testCtx, testDB, sessionID, false) - assert.NoError(t, err) - assert.Equal(t, len(messages), len(embeddings)) - - for i, emb := range embeddings { - assert.NotNil(t, emb.TextUUID) - assert.NotEmpty(t, emb.Embedding) - assert.Equal(t, addedMessages[i].UUID, emb.TextUUID) - } -} - func TestPutSummary(t *testing.T) { sessionID, err := test.GenerateRandomSessionID(16) assert.NoError(t, err, "GenerateRandomSessionID should not return an error") @@ -517,7 +479,7 @@ func TestPutSummary(t *testing.T) { } // Call putMessages function - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") tests := []struct { @@ -620,7 +582,7 @@ func TestGetSummary(t *testing.T) { } // Call putMessages function - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") summary.SummaryPointUUID = resultMessages[0].UUID @@ -687,7 +649,7 @@ func TestPutEmbeddings(t *testing.T) { viper.Set("extractor.embeddings.enabled", true) // Call putMessages function - resultMessages, err := putMessages(testCtx, testDB, 
sessionID, true, messages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, messages) assert.NoError(t, err, "putMessages should not return an error") vector := make([]float32, 1536) @@ -706,7 +668,7 @@ func TestPutEmbeddings(t *testing.T) { }, } - err = putEmbeddings(testCtx, testDB, sessionID, embeddings, true) + err = putEmbeddings(testCtx, testDB, sessionID, embeddings) assert.NoError(t, err, "putEmbeddings should not return an error") // Check for the creation of PgMessageVectorStore values @@ -756,7 +718,7 @@ func TestLastSummaryPointIndex(t *testing.T) { assert.NoError(t, err, "putSession should not return an error") // Call putMessages function using internal.TestMessages - resultMessages, err := putMessages(testCtx, testDB, sessionID, true, test.TestMessages) + resultMessages, err := putMessages(testCtx, testDB, sessionID, test.TestMessages) assert.NoError(t, err, "putMessages should not return an error") configuredMessageWindow := 30 @@ -818,7 +780,7 @@ func TestSearch(t *testing.T) { assert.NoError(t, err, "GenerateRandomSessionID should not return an error") // Call putMessages function - msgs, err := putMessages(testCtx, testDB, sessionID, true, test.TestMessages) + msgs, err := putMessages(testCtx, testDB, sessionID, test.TestMessages) assert.NoError(t, err, "putMessages should not return an error") appState.MemoryStore.NotifyExtractors( diff --git a/pkg/models/memorystore.go b/pkg/models/memorystore.go index 6b80ce06..daaa08c4 100644 --- a/pkg/models/memorystore.go +++ b/pkg/models/memorystore.go @@ -33,15 +33,13 @@ type MemoryStore[T any] interface { PutMessageVectors(ctx context.Context, appState *AppState, sessionID string, - embeddings []Embeddings, - isEmbedded bool) error + embeddings []Embeddings) error // GetMessageVectors retrieves a collection of Embeddings for a given sessionID. isEmbedded is a flag that // whether the Embeddings records have been embedded. 
The Embeddings extractor uses this internally to determine // which records still need to be embedded. GetMessageVectors(ctx context.Context, appState *AppState, - sessionID string, - isEmbedded bool) ([]Embeddings, error) + sessionID string) ([]Embeddings, error) // SearchMemory retrieves a collection of SearchResults for a given sessionID and query. Currently, the query // is a simple string, but this could be extended to support more complex queries in the future. The SearchResult // structure can include both Messages and Summaries. Currently, we only search Messages.