Skip to content

Commit

Permalink
event embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-brennan2005 committed Jun 20, 2024
1 parent fd48e64 commit fa1ba59
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
55 changes: 55 additions & 0 deletions backend/background/jobs/event_embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package jobs

import (
"context"
"fmt"
"log/slog"
"time"

"github.com/GenerateNU/sac/backend/background"
"github.com/GenerateNU/sac/backend/entities/models"
"github.com/GenerateNU/sac/backend/search"

"github.com/GenerateNU/sac/backend/constants"
)

// Generate event embeddings for events that did not receive them when being created or updated. This could occur in the case of
// mock data (which is uploaded to postgres directly, doesn't go through the app), or in the case OpenAI API goes down (service outage, bad api key, etc)
func (j *Jobs) EventEmbeddings(ctx context.Context) background.JobFunc {
return func() {
t := time.NewTicker(constants.EMBEDDINGS_GENERATION_INTERVAL)

for range t.C {
func() {
tx := j.db.WithContext(ctx).Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()

var event models.Event
if err := tx.Raw("SELECT * FROM events WHERE embedding IS NULL FOR UPDATE SKIP LOCKED LIMIT 1").Scan(&event).Error; err != nil {
tx.Rollback()
return
}

if event.Name == "" && event.Preview == "" && event.Description == "" { // empty club
tx.Rollback()
return
}

slog.Info(fmt.Sprintf("Generating embeddings for event '%s' (%s)", event.Name, event.ID.String()))

if err := search.UpsertEventEmbedding(tx, j.search, &event); err != nil {
tx.Rollback()
return
}

if err := tx.Commit().Error; err != nil {
return
}
}()
}
}
}
1 change: 1 addition & 0 deletions backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func startBackgroundJobs(ctx context.Context, db *gorm.DB, settings *config.Sett
jobs := jobs.New(db, settings)
background.Go(jobs.WelcomeSender(ctx))
background.Go(jobs.ClubEmbeddings(ctx))
background.Go(jobs.EventEmbeddings(ctx))
}

func configureIntegrations(config *config.Integrations) *integrations.Integrations {
Expand Down
11 changes: 9 additions & 2 deletions backend/search/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

type CreateEmbeddingRequestBody struct {
Expand Down Expand Up @@ -41,7 +42,10 @@ func UpsertClubEmbedding(db *gorm.DB, s *config.SearchSettings, club *models.Clu
queryString := fmt.Sprintf(
"UPDATE clubs SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, club.ID.String())

if err := db.Exec(queryString).Error; err != nil {
// Keep stdout/logs clean, don't output 512 floats
session := db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Error)})

if err := session.Exec(queryString).Error; err != nil {
return err
}

Expand All @@ -60,7 +64,10 @@ func UpsertEventEmbedding(db *gorm.DB, s *config.SearchSettings, event *models.E
queryString := fmt.Sprintf(
"UPDATE events SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, event.ID.String())

if err := db.Exec(queryString).Error; err != nil {
// Keep stdout/logs clean, don't output 512 floats
session := db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Error)})

if err := session.Exec(queryString).Error; err != nil {
return err
}

Expand Down

0 comments on commit fa1ba59

Please sign in to comment.