Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Club vector embeddings + Pinecone seeding/testing logic #292

Binary file added backend/src/__debug_bin3800770352
Binary file not shown.
1 change: 0 additions & 1 deletion backend/src/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package database
import (
"github.com/GenerateNU/sac/backend/src/config"
"github.com/GenerateNU/sac/backend/src/models"

"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
Expand Down
6 changes: 5 additions & 1 deletion backend/src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/GenerateNU/sac/backend/src/config"
"github.com/GenerateNU/sac/backend/src/database"
_ "github.com/GenerateNU/sac/backend/src/docs"
"github.com/GenerateNU/sac/backend/src/search"
"github.com/GenerateNU/sac/backend/src/server"
)

Expand Down Expand Up @@ -35,6 +36,9 @@ func main() {
panic(fmt.Sprintf("Error configuring database: %s", err.Error()))
}

openAi := search.NewOpenAIClient(config.OpenAISettings)
pinecone := search.NewPineconeClient(openAi, config.PineconeSettings)

if *onlyMigrate {
return
}
Expand All @@ -44,7 +48,7 @@ func main() {
panic(fmt.Sprintf("Error connecting to database: %s", err.Error()))
}

app := server.Init(db, *config)
app := server.Init(db, pinecone, *config)

err = app.Listen(fmt.Sprintf("%s:%d", config.Application.Host, config.Application.Port))
if err != nil {
Expand Down
3,576 changes: 2,621 additions & 955 deletions backend/src/migrations/data.sql

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions backend/src/models/club.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ type Club struct {

SoftDeletedAt gorm.DeletedAt `gorm:"type:timestamptz;default:NULL" json:"-" validate:"-"`

Name string `gorm:"type:varchar(255)" json:"name" validate:"required,max=255"`
Preview string `gorm:"type:varchar(255)" json:"preview" validate:"required,max=255"`
Description string `gorm:"type:varchar(255)" json:"description" validate:"required,http_url,mongo_url,max=255"` // MongoDB URL
Name string `gorm:"type:varchar(255)" json:"name" validate:"required,max=255"`
Preview string `gorm:"type:varchar(255)" json:"preview" validate:"required,max=255"`
// FIXME: make description a mongodb url again
/*Description string `gorm:"type:varchar(255)" json:"description" validate:"required,http_url,mongo_url,max=255"` // MongoDB URL*/
Description string `gorm:"type:text" json:"description" validate:"required"`
NumMembers int `gorm:"type:int" json:"num_members" validate:"required,min=1"`
IsRecruiting bool `gorm:"type:bool;default:false" json:"is_recruiting" validate:"required"`
RecruitmentCycle RecruitmentCycle `gorm:"type:varchar(255);default:always" json:"recruitment_cycle" validate:"required,max=255,oneof=fall spring fallSpring always"`
Expand Down
17 changes: 10 additions & 7 deletions backend/src/search/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ func NewOpenAIClient(settings config.OpenAISettings) *OpenAIClient {
}

type CreateEmbeddingRequestBody struct {
Input string `json:"input"`
Model string `json:"model"`
Input []string `json:"input"`
Model string `json:"model"`
}

type Embedding struct {
Expand All @@ -37,10 +37,15 @@ type CreateEmbeddingResponseBody struct {
Data []Embedding `json:"data"`
}

func (c *OpenAIClient) CreateEmbedding(payload string) ([]float32, *errors.Error) {
func (c *OpenAIClient) CreateEmbedding(items []Searchable) ([]Embedding, *errors.Error) {
embeddingStrings := []string{}
for _, item := range items {
embeddingStrings = append(embeddingStrings, item.EmbeddingString())
}

embeddingBody, err := json.Marshal(
CreateEmbeddingRequestBody{
Input: payload,
Input: embeddingStrings,
Model: "text-embedding-ada-002",
})
if err != nil {
Expand Down Expand Up @@ -81,7 +86,5 @@ func (c *OpenAIClient) CreateEmbedding(payload string) ([]float32, *errors.Error
return nil, &errors.FailedToCreateEmbedding
}

EMBEDDING_INDEX := 0

return embeddingResultBody.Data[EMBEDDING_INDEX].Embedding, nil
return embeddingResultBody.Data, nil
}
184 changes: 168 additions & 16 deletions backend/src/search/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,167 @@
"fmt"
"net/http"

"github.com/garrettladley/mattress"
"github.com/google/uuid"
"gorm.io/gorm"

"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"

"github.com/GenerateNU/sac/backend/src/config"
"github.com/GenerateNU/sac/backend/src/errors"
"github.com/GenerateNU/sac/backend/src/models"
"github.com/GenerateNU/sac/backend/src/utilities"

stdliberrors "errors"
)

type PineconeClientInterface interface {
Upsert(item Searchable) *errors.Error
Delete(item Searchable) *errors.Error
Upsert(items []Searchable) *errors.Error
Delete(items []Searchable) *errors.Error
Search(item Searchable, topK int) ([]string, *errors.Error)
}

type PineconeClient struct {
Settings config.PineconeSettings
IndexName *mattress.Secret[string]
openAIClient *OpenAIClient
}

// Connects to an existing Pinecone index, using the host and keys provided in settings.
func NewPineconeClient(openAIClient *OpenAIClient, settings config.PineconeSettings) *PineconeClient {
return &PineconeClient{
Settings: settings,
openAIClient: openAIClient,
}
}

type PineconePodRequest struct {
Environment string `json:"environment"`
PodType string `json:"pod_type"`
}

type PineconeSpecRequest struct {
Pod PineconePodRequest `json:"pod"`
}

type PineconeCreateIndexRequestBody struct {
Name string `json:"name"`
Dimension int32 `json:"dimension"`
Cosine string `json:"metric"`
Spec PineconeSpecRequest `json:"spec"`
}

type PineconeCreateIndexResponseBody struct {
Host string `json:"host"`
}

// Similar to NewPineconeClient, but instead of connecting to an existing index, creates a new one.
func NewPineconeClientCreateIndex(openAIClient *OpenAIClient, settings config.PineconeSettings) (*PineconeClient, error) {
newIndexUUID, err := uuid.NewUUID()
if err != nil {
return nil, err
}
newIndexName := fmt.Sprintf("dev-%s", newIndexUUID.String())

createIndexBody, err := json.Marshal(
PineconeCreateIndexRequestBody{
Name: newIndexName,
Dimension: 1536,
Cosine: "cosine",
Spec: PineconeSpecRequest{
Pod: PineconePodRequest{
Environment: "gcp-starter",
PodType: "p1.x1",
},
},
})

if err != nil {
return nil, err
}

req, err := http.NewRequest(fiber.MethodPost,
"https://api.pinecone.io/indexes",
bytes.NewBuffer(createIndexBody))
if err != nil {
return nil, err
}

req = utilities.ApplyModifiers(req,
utilities.HeaderKV("Api-Key", settings.APIKey.Expose()),
utilities.AcceptJSON(),
utilities.JSON())

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != fiber.StatusCreated {
return nil, nil
}

var body PineconeCreateIndexResponseBody
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
return nil, err
}

indexHostSecret, err := mattress.NewSecret(fmt.Sprintf("https://%s", body.Host))

Check failure on line 116 in backend/src/search/pinecone.go

View workflow job for this annotation

GitHub Actions / Lint

ineffectual assignment to err (ineffassign)

Check warning

Code scanning / CodeQL

Useless assignment to local variable Warning

This definition of err is never used.
indexNameSecret, err := mattress.NewSecret(newIndexName)

Check failure on line 117 in backend/src/search/pinecone.go

View workflow job for this annotation

GitHub Actions / Lint

ineffectual assignment to err (ineffassign)

Check warning

Code scanning / CodeQL

Useless assignment to local variable Warning

This definition of err is never used.
return &PineconeClient{
Settings: config.PineconeSettings{
IndexHost: indexHostSecret,
APIKey: settings.APIKey,
},
IndexName: indexNameSecret,
openAIClient: openAIClient,
}, nil
}

type PineconeDeleteIndexRequestBody struct {
IndexName string `json:"index_name"`
}

// Seeds the pinecone index with the clubs currently in the database.
func (c *PineconeClient) Seed(db *gorm.DB) error {
var clubs []models.Club

if err := db.Find(&clubs).Error; err != nil {
return err
}

var searchables []Searchable

Check failure on line 140 in backend/src/search/pinecone.go

View workflow job for this annotation

GitHub Actions / Lint

Consider pre-allocating `searchables` (prealloc)
for _, club := range clubs {
searchables = append(searchables, &club)

Check failure on line 142 in backend/src/search/pinecone.go

View workflow job for this annotation

GitHub Actions / Lint

G601: Implicit memory aliasing in for loop. (gosec)
}

var chunks [][]Searchable
chunkSize := 50

for i := 0; i < len(searchables); i += chunkSize {
end := i + chunkSize

if end > len(searchables) {
end = len(searchables)
}

chunks = append(chunks, searchables[i:end])
}

for i, chunk := range chunks {
print(fmt.Sprintf("Uploading chunk #%d (of %d) to pinecone...\n", i+1, len(chunks)))
err := c.Upsert(chunk)
if err != nil {
return stdliberrors.New("Club upsert failed...")
}
}

return nil
}

func (c *PineconeClient) pineconeRequest(req *http.Request) *http.Request {
return utilities.ApplyModifiers(req,
utilities.HeaderKV("Api-Key", c.Settings.APIKey.Expose()),
Expand All @@ -49,21 +184,28 @@
Namespace string `json:"namespace"`
}

func (c *PineconeClient) Upsert(item Searchable) *errors.Error {
values, embeddingErr := c.openAIClient.CreateEmbedding(item.EmbeddingString())
func (c *PineconeClient) Upsert(items []Searchable) *errors.Error {
if len(items) == 0 {
return nil
}

embeddings, embeddingErr := c.openAIClient.CreateEmbedding(items)
if embeddingErr != nil {
return &errors.FailedToUpsertToPinecone
}

vectors := []Vector{}
for i, item := range items {
vectors = append(vectors, Vector{
ID: item.SearchId(),
Values: embeddings[i].Embedding,
})
}

upsertBody, err := json.Marshal(
PineconeUpsertRequestBody{
Vectors: []Vector{
{
ID: item.SearchId(),
Values: values,
},
},
Namespace: item.Namespace(),
Vectors: vectors,
Namespace: items[0].Namespace(),
})
if err != nil {
return &errors.FailedToUpsertToPinecone
Expand All @@ -82,6 +224,7 @@
if err != nil {
return &errors.FailedToUpsertToPinecone
}
defer resp.Body.Close()

if resp.StatusCode != fiber.StatusOK {
return &errors.FailedToUpsertToPinecone
Expand All @@ -104,11 +247,20 @@
}
}

func (c *PineconeClient) Delete(item Searchable) *errors.Error {
func (c *PineconeClient) Delete(items []Searchable) *errors.Error {
if len(items) == 0 {
return nil
}

itemIds := []string{}
for _, item := range items {
itemIds = append(itemIds, item.SearchId())
}

deleteBody, err := json.Marshal(
PineconeDeleteRequestBody{
IDs: []string{item.SearchId()},
Namespace: item.Namespace(),
IDs: itemIds,
Namespace: items[0].Namespace(),
DeleteAll: false,
})
if err != nil {
Expand Down Expand Up @@ -156,7 +308,7 @@
}

func (c *PineconeClient) Search(item Searchable, topK int) ([]string, *errors.Error) {
values, embeddingErr := c.openAIClient.CreateEmbedding(item.EmbeddingString())
values, embeddingErr := c.openAIClient.CreateEmbedding([]Searchable{item})
if embeddingErr != nil {
return []string{}, embeddingErr
}
Expand All @@ -166,7 +318,7 @@
IncludeValues: false,
IncludeMetadata: false,
TopK: topK,
Vector: values,
Vector: values[0].Embedding,
Namespace: item.Namespace(),
})
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions backend/src/server/routes/club.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
p "github.com/GenerateNU/sac/backend/src/auth"
"github.com/GenerateNU/sac/backend/src/controllers"
"github.com/GenerateNU/sac/backend/src/middleware"
"github.com/GenerateNU/sac/backend/src/search"
"github.com/GenerateNU/sac/backend/src/services"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)

func ClubRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate, authMiddleware *middleware.AuthMiddlewareService) {
clubIDRouter := Club(router, services.NewClubService(db, validate), authMiddleware)
func ClubRoutes(router fiber.Router, db *gorm.DB, pinecone search.PineconeClientInterface, validate *validator.Validate, authMiddleware *middleware.AuthMiddlewareService) {
clubIDRouter := Club(router, services.NewClubService(db, pinecone, validate), authMiddleware)

ClubTag(clubIDRouter, services.NewClubTagService(db, validate), authMiddleware)
ClubFollower(clubIDRouter, services.NewClubFollowerService(db), authMiddleware)
Expand Down
6 changes: 3 additions & 3 deletions backend/src/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/GenerateNU/sac/backend/src/config"
"github.com/GenerateNU/sac/backend/src/middleware"
"github.com/GenerateNU/sac/backend/src/search"
"github.com/GenerateNU/sac/backend/src/server/routes"
"github.com/GenerateNU/sac/backend/src/services"
"github.com/GenerateNU/sac/backend/src/utilities"
Expand All @@ -25,8 +26,7 @@ import (
// @contact.email [email protected] and [email protected]
// @host 127.0.0.1:8080
// @BasePath /
// @schemes http https
func Init(db *gorm.DB, settings config.Settings) *fiber.App {
func Init(db *gorm.DB, pinecone search.PineconeClientInterface, settings config.Settings) *fiber.App {
app := newFiberApp(settings.Application)

validate, err := utilities.RegisterCustomValidators()
Expand All @@ -43,7 +43,7 @@ func Init(db *gorm.DB, settings config.Settings) *fiber.App {
routes.Auth(apiv1, services.NewAuthService(db, validate), settings.Auth, authMiddleware)
routes.UserRoutes(apiv1, db, validate, authMiddleware)
routes.Contact(apiv1, services.NewContactService(db, validate), authMiddleware)
routes.ClubRoutes(apiv1, db, validate, authMiddleware)
routes.ClubRoutes(apiv1, db, pinecone, validate, authMiddleware)
routes.Tag(apiv1, services.NewTagService(db, validate), authMiddleware)
routes.CategoryRoutes(apiv1, db, validate, authMiddleware)
routes.Event(apiv1, services.NewEventService(db, validate), authMiddleware)
Expand Down
Loading
Loading