diff --git a/db/chat.go b/db/chat.go index 1659397bc..5956a8f12 100644 --- a/db/chat.go +++ b/db/chat.go @@ -24,6 +24,31 @@ func (db database) AddChat(chat *Chat) (Chat, error) { return *chat, nil } +func (db database) UpdateChat(chat *Chat) (Chat, error) { + if chat.ID == "" { + return Chat{}, errors.New("chat ID is required") + } + + var existingChat Chat + if err := db.db.First(&existingChat, "id = ?", chat.ID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return Chat{}, fmt.Errorf("chat not found") + } + return Chat{}, fmt.Errorf("failed to fetch chat: %w", err) + } + + if chat.Title != "" { + existingChat.Title = chat.Title + } + existingChat.UpdatedAt = time.Now() + + if err := db.db.Save(&existingChat).Error; err != nil { + return Chat{}, fmt.Errorf("failed to update chat: %w", err) + } + + return existingChat, nil +} + func (db database) GetChatByChatID(chatID string) (Chat, error) { var chat Chat result := db.db.Where("id = ?", chatID).First(&chat) diff --git a/db/interface.go b/db/interface.go index 571fab083..cf636a36f 100644 --- a/db/interface.go +++ b/db/interface.go @@ -202,6 +202,7 @@ type Database interface { GetFeatureBrief(featureUuid string) (string, error) GetTicketsByPhaseUUID(featureUUID string, phaseUUID string) ([]Tickets, error) AddChat(chat *Chat) (Chat, error) + UpdateChat(chat *Chat) (Chat, error) GetChatByChatID(chatID string) (Chat, error) AddChatMessage(message *ChatMessage) (ChatMessage, error) UpdateChatMessage(message *ChatMessage) (ChatMessage, error) diff --git a/handlers/chat.go b/handlers/chat.go index 9fc13c677..af676666c 100644 --- a/handlers/chat.go +++ b/handlers/chat.go @@ -96,6 +96,62 @@ func (ch *ChatHandler) CreateChat(w http.ResponseWriter, r *http.Request) { }) } +func (ch *ChatHandler) UpdateChat(w http.ResponseWriter, r *http.Request) { + chatID := chi.URLParam(r, "chat_id") + if chatID == "" { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ChatResponse{ + Success: false, + Message: "Chat ID is required", + }) + return + } + + var request struct { + WorkspaceID string `json:"workspaceId"` + Title string `json:"title"` + } + + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ChatResponse{ + Success: false, + Message: "Invalid request body", + }) + return + } + + existingChat, err := ch.db.GetChatByChatID(chatID) + if err != nil { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(ChatResponse{ + Success: false, + Message: "Chat not found", + }) + return + } + + updatedChat := existingChat + updatedChat.Title = request.Title + + updatedChat, err = ch.db.UpdateChat(&updatedChat) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(ChatResponse{ + Success: false, + Message: fmt.Sprintf("Failed to update chat: %v", err), + }) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(ChatResponse{ + Success: true, + Message: "Chat updated successfully", + Data: updatedChat, + }) +} + func (ch *ChatHandler) SendMessage(w http.ResponseWriter, r *http.Request) { var request struct { diff --git a/handlers/chat_test.go b/handlers/chat_test.go new file mode 100644 index 000000000..cebbb9172 --- /dev/null +++ b/handlers/chat_test.go @@ -0,0 +1,151 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/google/uuid" + "github.com/stakwork/sphinx-tribes/db" + "github.com/stretchr/testify/assert" +) + +type Chat struct { + ID string `json:"id"` + WorkspaceID string `json:"workspaceId"` + Title string `json:"title"` +} + +func TestUpdateChat(t *testing.T) { + teardownSuite := SetupSuite(t) + defer teardownSuite(t) + + chatHandler := NewChatHandler(&http.Client{}, db.TestDB) + + t.Run("should successfully update chat when valid data is provided", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(chatHandler.UpdateChat) + + chat := &db.Chat{ + ID: uuid.New().String(), + WorkspaceID: uuid.New().String(), + Title: "Old Title", + } + db.TestDB.AddChat(chat) + + requestBody := map[string]string{ + "workspaceId": chat.WorkspaceID, + "title": "New Title", + } + bodyBytes, _ := json.Marshal(requestBody) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("chat_id", chat.ID) + req, err := http.NewRequestWithContext( + context.WithValue(context.Background(), chi.RouteCtxKey, rctx), + http.MethodPut, + "/hivechat/"+chat.ID, + bytes.NewReader(bodyBytes), + ) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + var response ChatResponse + _ = json.Unmarshal(rr.Body.Bytes(), &response) + assert.Equal(t, http.StatusOK, rr.Code) + assert.True(t, response.Success) + assert.Equal(t, "Chat updated successfully", response.Message) + responseData, ok := response.Data.(map[string]interface{}) + assert.True(t, ok, "Response data should be a map") + assert.Equal(t, chat.ID, responseData["id"]) + assert.Equal(t, "New Title", responseData["title"]) + }) + + t.Run("should return bad request when chat_id is missing", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(chatHandler.UpdateChat) + + rctx := chi.NewRouteContext() + req, err := http.NewRequestWithContext( + context.WithValue(context.Background(), chi.RouteCtxKey, rctx), + http.MethodPut, + "/hivechat/", + nil, + ) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + var response ChatResponse + _ = json.Unmarshal(rr.Body.Bytes(), &response) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.False(t, response.Success) + assert.Equal(t, "Chat ID is required", response.Message) + }) + + t.Run("should return bad request when request body is invalid", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(chatHandler.UpdateChat) + + invalidJson := []byte(`{"title": "New Title"`) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("chat_id", uuid.New().String()) + req, err := http.NewRequestWithContext( + context.WithValue(context.Background(), chi.RouteCtxKey, rctx), + http.MethodPut, + "/hivechat/"+uuid.New().String(), + bytes.NewReader(invalidJson), + ) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + var response ChatResponse + _ = json.Unmarshal(rr.Body.Bytes(), &response) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.False(t, response.Success) + assert.Equal(t, "Invalid request body", response.Message) + }) + + t.Run("should return not found when chat doesn't exist", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(chatHandler.UpdateChat) + + requestBody := map[string]string{ + "workspaceId": uuid.New().String(), + "title": "New Title", + } + bodyBytes, _ := json.Marshal(requestBody) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("chat_id", uuid.New().String()) + req, err := http.NewRequestWithContext( + context.WithValue(context.Background(), chi.RouteCtxKey, rctx), + http.MethodPut, + "/hivechat/"+uuid.New().String(), + bytes.NewReader(bodyBytes), + ) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + var response ChatResponse + _ = json.Unmarshal(rr.Body.Bytes(), &response) + assert.Equal(t, http.StatusNotFound, rr.Code) + assert.False(t, response.Success) + assert.Equal(t, "Chat not found", response.Message) + }) +} diff --git a/mocks/Database.go b/mocks/Database.go index 81cf482a4..b0bc90eed 100644 --- a/mocks/Database.go +++ b/mocks/Database.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.3. DO NOT EDIT. +// Code generated by mockery v2.50.0. DO NOT EDIT. package db @@ -611,7 +611,7 @@ func (_c *Database_ChangeWorkspaceDeleteStatus_Call) RunAndReturn(run func(strin return _c } -// CountBounties provides a mock function with given fields: +// CountBounties provides a mock function with no fields func (_m *Database) CountBounties() uint64 { ret := _m.Called() @@ -656,7 +656,7 @@ func (_c *Database_CountBounties_Call) RunAndReturn(run func() uint64) *Database return _c } -// CountDevelopers provides a mock function with given fields: +// CountDevelopers provides a mock function with no fields func (_m *Database) CountDevelopers() int64 { ret := _m.Called() @@ -2489,7 +2489,7 @@ func (_c *Database_GetAllBounties_Call) RunAndReturn(run func(*http.Request) []d return _c } -// GetAllTribes provides a mock function with given fields: +// GetAllTribes provides a mock function with no fields func (_m *Database) GetAllTribes() []db.Tribe { ret := _m.Called() @@ -3080,7 +3080,7 @@ func (_c *Database_GetBountiesCountByFeatureAndPhaseUuid_Call) RunAndReturn(run return _c } -// GetBountiesLeaderboard provides a mock function with given fields: +// GetBountiesLeaderboard provides a mock function with no fields func (_m *Database) GetBountiesLeaderboard() []db.LeaderData { ret := _m.Called() @@ -3440,7 +3440,7 @@ func (_c *Database_GetBountyIndexById_Call) RunAndReturn(run func(string) int64) return _c } -// GetBountyRoles provides a mock function with given fields: +// GetBountyRoles provides a mock function with no fields func (_m *Database) GetBountyRoles() []db.BountyRoles { ret := _m.Called() @@ -3695,6 +3695,64 @@ func (_c *Database_GetChatMessagesForChatID_Call) RunAndReturn(run func(string) return _c } +// GetChatsForWorkspace provides a mock function with given fields: workspaceID +func (_m *Database) GetChatsForWorkspace(workspaceID string) ([]db.Chat, error) { + ret := _m.Called(workspaceID) + + if len(ret) == 0 { + panic("no return value specified for GetChatsForWorkspace") + } + + var r0 []db.Chat + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]db.Chat, error)); ok { + return rf(workspaceID) + } + if rf, ok := ret.Get(0).(func(string) []db.Chat); ok { + r0 = rf(workspaceID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]db.Chat) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(workspaceID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Database_GetChatsForWorkspace_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChatsForWorkspace' +type Database_GetChatsForWorkspace_Call struct { + *mock.Call +} + +// GetChatsForWorkspace is a helper method to define mock.On call +// - workspaceID string +func (_e *Database_Expecter) GetChatsForWorkspace(workspaceID interface{}) *Database_GetChatsForWorkspace_Call { + return &Database_GetChatsForWorkspace_Call{Call: _e.mock.On("GetChatsForWorkspace", workspaceID)} +} + +func (_c *Database_GetChatsForWorkspace_Call) Run(run func(workspaceID string)) *Database_GetChatsForWorkspace_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Database_GetChatsForWorkspace_Call) Return(_a0 []db.Chat, _a1 error) *Database_GetChatsForWorkspace_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Database_GetChatsForWorkspace_Call) RunAndReturn(run func(string) ([]db.Chat, error)) *Database_GetChatsForWorkspace_Call { + _c.Call.Return(run) + return _c +} + // GetCodeGraphByUUID provides a mock function with given fields: uuid func (_m *Database) GetCodeGraphByUUID(uuid string) (db.WorkspaceCodeGraph, error) { ret := _m.Called(uuid) @@ -3809,7 +3867,7 @@ func (_c *Database_GetCodeGraphsByWorkspaceUuid_Call) RunAndReturn(run func(stri return _c } -// GetConnectionCode provides a mock function with given fields: +// GetConnectionCode provides a mock function with no fields func (_m *Database) GetConnectionCode() db.ConnectionCodesShort { ret := _m.Called() @@ -4282,7 +4340,7 @@ func (_c *Database_GetFeaturesByWorkspaceUuid_Call) RunAndReturn(run func(string return _c } -// GetFilterStatusCount provides a mock function with given fields: +// GetFilterStatusCount provides a mock function with no fields func (_m *Database) GetFilterStatusCount() db.FilterStattuCount { ret := _m.Called() @@ -5222,7 +5280,7 @@ func (_c *Database_GetPaymentHistoryByCreated_Call) RunAndReturn(run func(*time. return _c } -// GetPendingPaymentHistory provides a mock function with given fields: +// GetPendingPaymentHistory provides a mock function with no fields func (_m *Database) GetPendingPaymentHistory() []db.NewPaymentHistory { ret := _m.Called() @@ -6249,7 +6307,7 @@ func (_c *Database_GetTicketsByPhaseUUID_Call) RunAndReturn(run func(string, str return _c } -// GetTicketsWithoutGroup provides a mock function with given fields: +// GetTicketsWithoutGroup provides a mock function with no fields func (_m *Database) GetTicketsWithoutGroup() ([]db.Tickets, error) { ret := _m.Called() @@ -6541,7 +6599,7 @@ func (_c *Database_GetTribesByOwner_Call) RunAndReturn(run func(string) []db.Tri return _c } -// GetTribesTotal provides a mock function with given fields: +// GetTribesTotal provides a mock function with no fields func (_m *Database) GetTribesTotal() int64 { ret := _m.Called() @@ -6586,7 +6644,7 @@ func (_c *Database_GetTribesTotal_Call) RunAndReturn(run func() int64) *Database return _c } -// GetUnconfirmedGithub provides a mock function with given fields: +// GetUnconfirmedGithub provides a mock function with no fields func (_m *Database) GetUnconfirmedGithub() []db.Person { ret := _m.Called() @@ -6633,7 +6691,7 @@ func (_c *Database_GetUnconfirmedGithub_Call) RunAndReturn(run func() []db.Perso return _c } -// GetUnconfirmedTwitter provides a mock function with given fields: +// GetUnconfirmedTwitter provides a mock function with no fields func (_m *Database) GetUnconfirmedTwitter() []db.Person { ret := _m.Called() @@ -8016,7 +8074,7 @@ func (_c *Database_GetWorkspaces_Call) RunAndReturn(run func(*http.Request) []db return _c } -// GetWorkspacesCount provides a mock function with given fields: +// GetWorkspacesCount provides a mock function with no fields func (_m *Database) GetWorkspacesCount() int64 { ret := _m.Called() @@ -8240,7 +8298,7 @@ func (_c *Database_ProcessAlerts_Call) Return() *Database_ProcessAlerts_Call { } func (_c *Database_ProcessAlerts_Call) RunAndReturn(run func(db.Person)) *Database_ProcessAlerts_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -8476,7 +8534,7 @@ func (_c *Database_ProcessUpdateBudget_Call) RunAndReturn(run func(db.NewInvoice return _c } -// ProcessUpdateTicketsWithoutGroup provides a mock function with given fields: +// ProcessUpdateTicketsWithoutGroup provides a mock function with no fields func (_m *Database) ProcessUpdateTicketsWithoutGroup() { _m.Called() } @@ -8504,7 +8562,7 @@ func (_c *Database_ProcessUpdateTicketsWithoutGroup_Call) Return() *Database_Pro } func (_c *Database_ProcessUpdateTicketsWithoutGroup_Call) RunAndReturn(run func()) *Database_ProcessUpdateTicketsWithoutGroup_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -9480,6 +9538,62 @@ func (_c *Database_UpdateChannel_Call) RunAndReturn(run func(uint, map[string]in return _c } +// UpdateChat provides a mock function with given fields: chat +func (_m *Database) UpdateChat(chat *db.Chat) (db.Chat, error) { + ret := _m.Called(chat) + + if len(ret) == 0 { + panic("no return value specified for UpdateChat") + } + + var r0 db.Chat + var r1 error + if rf, ok := ret.Get(0).(func(*db.Chat) (db.Chat, error)); ok { + return rf(chat) + } + if rf, ok := ret.Get(0).(func(*db.Chat) db.Chat); ok { + r0 = rf(chat) + } else { + r0 = ret.Get(0).(db.Chat) + } + + if rf, ok := ret.Get(1).(func(*db.Chat) error); ok { + r1 = rf(chat) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Database_UpdateChat_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChat' +type Database_UpdateChat_Call struct { + *mock.Call +} + +// UpdateChat is a helper method to define mock.On call +// - chat *db.Chat +func (_e *Database_Expecter) UpdateChat(chat interface{}) *Database_UpdateChat_Call { + return &Database_UpdateChat_Call{Call: _e.mock.On("UpdateChat", chat)} +} + +func (_c *Database_UpdateChat_Call) Run(run func(chat *db.Chat)) *Database_UpdateChat_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*db.Chat)) + }) + return _c +} + +func (_c *Database_UpdateChat_Call) Return(_a0 db.Chat, _a1 error) *Database_UpdateChat_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Database_UpdateChat_Call) RunAndReturn(run func(*db.Chat) (db.Chat, error)) *Database_UpdateChat_Call { + _c.Call.Return(run) + return _c +} + // UpdateChatMessage provides a mock function with given fields: message func (_m *Database) UpdateChatMessage(message *db.ChatMessage) (db.ChatMessage, error) { ret := _m.Called(message) @@ -9566,7 +9680,7 @@ func (_c *Database_UpdateGithubConfirmed_Call) Return() *Database_UpdateGithubCo } func (_c *Database_UpdateGithubConfirmed_Call) RunAndReturn(run func(uint, bool)) *Database_UpdateGithubConfirmed_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -9600,7 +9714,7 @@ func (_c *Database_UpdateGithubIssues_Call) Return() *Database_UpdateGithubIssue } func (_c *Database_UpdateGithubIssues_Call) RunAndReturn(run func(uint, map[string]interface{})) *Database_UpdateGithubIssues_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -9970,7 +10084,7 @@ func (_c *Database_UpdateTribeUniqueName_Call) Return() *Database_UpdateTribeUni } func (_c *Database_UpdateTribeUniqueName_Call) RunAndReturn(run func(string, string)) *Database_UpdateTribeUniqueName_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -10004,7 +10118,7 @@ func (_c *Database_UpdateTwitterConfirmed_Call) Return() *Database_UpdateTwitter } func (_c *Database_UpdateTwitterConfirmed_Call) RunAndReturn(run func(uint, bool)) *Database_UpdateTwitterConfirmed_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -10320,7 +10434,7 @@ func (_c *Database_WithdrawBudget_Call) Return() *Database_WithdrawBudget_Call { } func (_c *Database_WithdrawBudget_Call) RunAndReturn(run func(string, string, uint)) *Database_WithdrawBudget_Call { - _c.Call.Return(run) + _c.Run(run) return _c } @@ -10337,58 +10451,3 @@ func NewDatabase(t interface { return mock } - -// GetChatsForWorkspace provides a mock function with given fields: workspaceID -func (_m *Database) GetChatsForWorkspace(workspaceID string) ([]db.Chat, error) { - ret := _m.Called(workspaceID) - - if len(ret) == 0 { - panic("no return value specified for GetChatsForWorkspace") - } - - var r0 []db.Chat - var r1 error - if rf, ok := ret.Get(0).(func(string) ([]db.Chat, error)); ok { - return rf(workspaceID) - } - if rf, ok := ret.Get(0).(func(string) []db.Chat); ok { - r0 = rf(workspaceID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]db.Chat) - } - } - - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(workspaceID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type Database_GetChatsForWorkspace_Call struct { - *mock.Call -} - -func (_e *Database_Expecter) GetChatsForWorkspace(workspaceID interface{}) *Database_GetChatsForWorkspace_Call { - return &Database_GetChatsForWorkspace_Call{Call: _e.mock.On("GetChatsForWorkspace", workspaceID)} -} - -func (_c *Database_GetChatsForWorkspace_Call) Run(run func(workspaceID string)) *Database_GetChatsForWorkspace_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} - -func (_c *Database_GetChatsForWorkspace_Call) Return(_a0 []db.Chat, _a1 error) *Database_GetChatsForWorkspace_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Database_GetChatsForWorkspace_Call) RunAndReturn(run func(string) ([]db.Chat, error)) *Database_GetChatsForWorkspace_Call { - _c.Call.Return(run) - return _c -} \ No newline at end of file diff --git a/routes/chat.go b/routes/chat.go index ab8365706..332f91dcf 100644 --- a/routes/chat.go +++ b/routes/chat.go @@ -19,6 +19,7 @@ func ChatRoutes() chi.Router { r.Use(auth.PubKeyContext) r.Get("/", chatHandler.GetChat) r.Post("/", chatHandler.CreateChat) + r.Put("/{chat_id}", chatHandler.UpdateChat) r.Post("/send", chatHandler.SendMessage) r.Get("/history/{uuid}", chatHandler.GetChatHistory) })