Skip to content

Commit

Permalink
refactor embedding extractor and putMessages code (#38)
Browse files Browse the repository at this point in the history
Simplify putMessages by removing embedding code. Simplify embedding extractor by using messages provided in MessageEvent.
  • Loading branch information
danielchalef authored May 13, 2023
1 parent 754c4f4 commit 6b94545
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 146 deletions.
35 changes: 10 additions & 25 deletions pkg/extractors/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,17 @@ func (ee *EmbeddingExtractor) Extract(
sessionMutex.Lock()
defer sessionMutex.Unlock()

unembeddedMessages, err := appState.MemoryStore.GetMessageVectors(
ctx,
appState,
messageEvent.SessionID,
false,
)
if err != nil {
return NewExtractorError("EmbeddingExtractor get message vectors failed", err)
}

if len(unembeddedMessages) == 0 {
return nil
}

texts := embeddingsToTextSlice(unembeddedMessages, false)
texts := messageToStringSlice(messageEvent.Messages, false)

embeddings, err := llms.EmbedMessages(ctx, appState, texts)
if err != nil {
return NewExtractorError("EmbeddingExtractor embed messages failed", err)
}

embeddingRecords := make([]models.Embeddings, len(unembeddedMessages))
for i, r := range unembeddedMessages {
embeddingRecords := make([]models.Embeddings, len(messageEvent.Messages))
for i, r := range messageEvent.Messages {
embeddingRecords[i] = models.Embeddings{
TextUUID: r.TextUUID,
TextUUID: r.UUID,
Embedding: (*embeddings)[i].Embedding,
}
}
Expand All @@ -59,35 +45,34 @@ func (ee *EmbeddingExtractor) Extract(
appState,
messageEvent.SessionID,
embeddingRecords,
true,
)
if err != nil {
return NewExtractorError("EmbeddingExtractor put message vectors failed", err)
}
return nil
}

// embeddingsToTextSlice converts a slice of Embeddings to a slice of strings.
// messageToStringSlice converts a slice of Embeddings to a slice of strings.
// If enrich is true, the text slice will include the prior and subsequent
// messages text to the slice item.
func embeddingsToTextSlice(messages []models.Embeddings, enrich bool) []string {
func messageToStringSlice(messages []models.Message, enrich bool) []string {
texts := make([]string, len(messages))
for i, r := range messages {
if !enrich {
texts[i] = r.Text
texts[i] = r.Content
continue
}

var parts []string

if i > 0 {
parts = append(parts, messages[i-1].Text)
parts = append(parts, messages[i-1].Content)
}

parts = append(parts, r.Text)
parts = append(parts, r.Content)

if i < len(messages)-1 {
parts = append(parts, messages[i+1].Text)
parts = append(parts, messages[i+1].Content)
}

texts[i] = strings.Join(parts, " ")
Expand Down
13 changes: 7 additions & 6 deletions pkg/extractors/embedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,20 @@ func TestEmbeddingExtractor_Extract(t *testing.T) {
assert.NoError(t, err)

// Get messages that are missing embeddings using appState.MemoryStore.GetMessageVectors
unembeddedMessages, err := store.GetMessageVectors(ctx, appState, sessionID, false)
memories, err := store.GetMemory(ctx, appState, sessionID, 0)
assert.NoError(t, err)
assert.True(t, len(unembeddedMessages) == len(testMessages))
assert.True(t, len(memories.Messages) == len(testMessages))

unembeddedMessages := memories.Messages
// Create messageEvent. We only need to pass the sessionID
messageEvent := &models.MessageEvent{
SessionID: sessionID,
Messages: unembeddedMessages,
}

texts := make([]string, len(unembeddedMessages))
for i, r := range unembeddedMessages {
texts[i] = r.Text
texts[i] = r.Content
}

embeddings, err := llms.EmbedMessages(ctx, appState, texts)
Expand All @@ -60,8 +62,8 @@ func TestEmbeddingExtractor_Extract(t *testing.T) {
expectedEmbeddingRecords := make([]models.Embeddings, len(unembeddedMessages))
for i, r := range unembeddedMessages {
expectedEmbeddingRecords[i] = models.Embeddings{
TextUUID: r.TextUUID,
Text: r.Text,
TextUUID: r.UUID,
Text: r.Content,
Embedding: (*embeddings)[i].Embedding,
}
}
Expand All @@ -74,7 +76,6 @@ func TestEmbeddingExtractor_Extract(t *testing.T) {
ctx,
appState,
messageEvent.SessionID,
true,
)
assert.NoError(t, err)

Expand Down
72 changes: 9 additions & 63 deletions pkg/memorystore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ func (pms *PostgresMemoryStore) PutMemory(
ctx,
pms.Client,
sessionID,
appState.Config.Extractors.Embeddings.Enabled,
memoryMessages.Messages,
)
if err != nil {
Expand Down Expand Up @@ -386,7 +385,6 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context,
_ *models.AppState,
sessionID string,
embeddings []models.Embeddings,
isEmbedded bool,
) error {
if embeddings == nil {
return NewStorageError("nil embeddings received", nil)
Expand All @@ -395,7 +393,7 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context,
return NewStorageError("no embeddings received", nil)
}

err := putEmbeddings(ctx, pms.Client, sessionID, embeddings, isEmbedded)
err := putEmbeddings(ctx, pms.Client, sessionID, embeddings)
if err != nil {
return NewStorageError("failed to put embeddings", err)
}
Expand All @@ -406,9 +404,8 @@ func (pms *PostgresMemoryStore) PutMessageVectors(ctx context.Context,
func (pms *PostgresMemoryStore) GetMessageVectors(ctx context.Context,
_ *models.AppState,
sessionID string,
isEmbedded bool,
) ([]models.Embeddings, error) {
embeddings, err := getMessageVectors(ctx, pms.Client, sessionID, isEmbedded)
embeddings, err := getMessageVectors(ctx, pms.Client, sessionID)
if err != nil {
return nil, NewStorageError("GetMessageVectors failed to get embeddings", err)
}
Expand All @@ -418,8 +415,7 @@ func (pms *PostgresMemoryStore) GetMessageVectors(ctx context.Context,

func getMessageVectors(ctx context.Context,
db *bun.DB,
sessionID string,
isEmbedded bool) ([]models.Embeddings, error) {
sessionID string) ([]models.Embeddings, error) {
var results []struct {
PgMessageStore
PgMessageVectorStore
Expand All @@ -431,7 +427,7 @@ func getMessageVectors(ctx context.Context,
JoinOn("message_embedding.message_uuid = message.uuid").
ColumnExpr("message.content").
ColumnExpr("message_embedding.*").
Where("message_embedding.is_embedded = ?", isEmbedded).
Where("message_embedding.is_embedded = ?", true).
Where("message_embedding.session_id = ?", sessionID).
Exec(ctx, &results)
if err != nil {
Expand Down Expand Up @@ -510,7 +506,6 @@ func putMessages(
ctx context.Context,
db *bun.DB,
sessionID string,
embeddingEnabled bool,
messages []models.Message,
) ([]models.Message, error) {
if len(messages) == 0 {
Expand All @@ -534,54 +529,11 @@ func putMessages(
pgMessages[i].SessionID = sessionID
}

// wrap in a transaction, so we can roll back if any of the inserts fail. We
// don't want to partially save messages without vectorstore records.
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func(tx bun.Tx) {
_ = tx.Rollback()
}(tx)

_, err = tx.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx)
_, err = db.NewInsert().Model(&pgMessages).On("CONFLICT (uuid) DO UPDATE").Exec(ctx)
if err != nil {
return nil, NewStorageError("failed to save memories to store", err)
}

// If embeddings are enabled, store the new messages for future embedding.
// The Embedded field will be false until we run the embedding extractor out of band.
if embeddingEnabled {
zeroVector := make(
[]float32,
1536,
) // TODO: use config. will need to drill appState down to here

// Extract new messages. We use the original messages slice to filter
// out messages that already have UUIDs.
var embedRecords []PgMessageVectorStore
for i, msg := range messages {
if msg.UUID == uuid.Nil {
e := PgMessageVectorStore{
SessionID: sessionID,
MessageUUID: pgMessages[i].UUID,
Embedding: pgvector.NewVector(zeroVector), // Vector fields can't be null
}
embedRecords = append(embedRecords, e)
}
}
if len(embedRecords) > 0 {
_, err = tx.NewInsert().Model(&embedRecords).Exec(ctx)
if err != nil {
return nil, NewStorageError("failed to save memory vector records", err)
}
}
}

if err := tx.Commit(); err != nil {
return nil, err
}

retMessages := make([]models.Message, len(messages))
err = copier.Copy(&retMessages, &pgMessages)
if err != nil {
Expand Down Expand Up @@ -778,7 +730,6 @@ func putEmbeddings(
db *bun.DB,
sessionID string,
embeddings []models.Embeddings,
isEmbedded bool,
) error {
if embeddings == nil {
return NewStorageError("nil embeddings received", nil)
Expand All @@ -793,19 +744,14 @@ func putEmbeddings(
SessionID: sessionID,
Embedding: pgvector.NewVector(e.Embedding),
MessageUUID: e.TextUUID,
IsEmbedded: true,
}
}

values := db.NewValues(&embeddingVectors)
_, err := db.NewUpdate().
With("_data", values).
Model((*PgMessageVectorStore)(nil)).
TableExpr("_data").
Set("embedding = _data.embedding").
Set("is_embedded = ?", isEmbedded).
Where("me.message_uuid = _data.message_uuid").
OmitZero().
_, err := db.NewInsert().
Model(&embeddingVectors).
Exec(ctx)

if err != nil {
return NewStorageError("failed to insert message vectors", err)
}
Expand Down
Loading

0 comments on commit 6b94545

Please sign in to comment.