Skip to content

Commit

Permalink
Merge pull request #2165 from MahtabBukhari/Update_the_Hive_Chat_Endp…
Browse files Browse the repository at this point in the history
…oints_to_Correctly_Process_New_Chats

Update the Hive Chat Endpoints to Correctly Process New Chats
  • Loading branch information
humansinstitute authored Dec 13, 2024
2 parents e6d9940 + e349434 commit 87be43e
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 212 deletions.
29 changes: 29 additions & 0 deletions db/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,35 @@ func (db database) AddChatMessage(chatMessage *ChatMessage) (ChatMessage, error)
return *chatMessage, nil
}

func (db database) UpdateChatMessage(chatMessage *ChatMessage) (ChatMessage, error) {
if chatMessage.ID == "" {
return ChatMessage{}, errors.New("message ID is required")
}

var existingMessage ChatMessage
if err := db.db.First(&existingMessage, "id = ?", chatMessage.ID).Error; err != nil {
return ChatMessage{}, fmt.Errorf("message not found: %w", err)
}

if chatMessage.Message != "" {
existingMessage.Message = chatMessage.Message
}
if chatMessage.Status != "" {
existingMessage.Status = chatMessage.Status
}
if chatMessage.Role != "" {
existingMessage.Role = chatMessage.Role
}

existingMessage.Timestamp = time.Now()

if err := db.db.Save(&existingMessage).Error; err != nil {
return ChatMessage{}, fmt.Errorf("failed to update chat message: %w", err)
}

return existingMessage, nil
}

func (db database) GetChatMessagesForChatID(chatID string) ([]ChatMessage, error) {
var chatMessages []ChatMessage

Expand Down
1 change: 1 addition & 0 deletions db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ type Database interface {
AddChat(chat *Chat) (Chat, error)
GetChatByChatID(chatID string) (Chat, error)
AddChatMessage(message *ChatMessage) (ChatMessage, error)
UpdateChatMessage(message *ChatMessage) (ChatMessage, error)
GetChatMessagesForChatID(chatID string) ([]ChatMessage, error)
GetCodeGraphByUUID(uuid string) (WorkspaceCodeGraph, error)
GetCodeGraphsByWorkspaceUuid(workspace_uuid string) ([]WorkspaceCodeGraph, error)
Expand Down
132 changes: 131 additions & 1 deletion handlers/chat.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package handlers

import (
"bytes"
"encoding/json"
"fmt"
"github.com/rs/xid"
"github.com/stakwork/sphinx-tribes/websocket"
"io"
"log"
"net/http"
"os"
"time"

"github.com/go-chi/chi"
Expand All @@ -27,6 +32,23 @@ type HistoryChatResponse struct {
Data interface{} `json:"data,omitempty"`
}

type ChatHistoryResponse struct {
Messages []db.ChatMessage `json:"messages"`
}

type StakworkChatPayload struct {
Name string `json:"name"`
WorkflowID int `json:"workflow_id"`
WorkflowParams map[string]interface{} `json:"workflow_params"`
}

type ChatWebhookResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
ChatID string `json:"chat_id"`
History []db.ChatMessage `json:"history"`
}

func NewChatHandler(httpClient *http.Client, database db.Database) *ChatHandler {
return &ChatHandler{
httpClient: httpClient,
Expand Down Expand Up @@ -74,13 +96,15 @@ func (ch *ChatHandler) CreateChat(w http.ResponseWriter, r *http.Request) {
}

func (ch *ChatHandler) SendMessage(w http.ResponseWriter, r *http.Request) {

var request struct {
ChatID string `json:"chatId"`
Message string `json:"message"`
ContextTags []struct {
Type string `json:"type"`
ID string `json:"id"`
} `json:"contextTags"`
SourceWebsocketID string `json:"sourceWebsocketId"`
}

if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
Expand All @@ -92,25 +116,94 @@ func (ch *ChatHandler) SendMessage(w http.ResponseWriter, r *http.Request) {
return
}

history, err := ch.db.GetChatMessagesForChatID(request.ChatID)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(ChatResponse{
Success: false,
Message: fmt.Sprintf("Failed to fetch chat history: %v", err),
})
return
}

start := 0
if len(history) > 20 {
start = len(history) - 20
}
recentHistory := history[start:]

messageHistory := make([]map[string]string, len(recentHistory))
for i, msg := range recentHistory {
messageHistory[i] = map[string]string{
string(msg.Role): msg.Message,
}
}

message := &db.ChatMessage{
ID: xid.New().String(),
ChatID: request.ChatID,
Message: request.Message,
Role: "user",
Timestamp: time.Now(),
Status: "sending",
Source: "user",
}

createdMessage, err := ch.db.AddChatMessage(message)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(ChatResponse{
Success: false,
Message: fmt.Sprintf("Failed to send message: %v", err),
Message: fmt.Sprintf("Failed to save message: %v", err),
})
return
}

stakworkPayload := StakworkChatPayload{
Name: "Hive Chat Processor",
WorkflowID: 38842,
WorkflowParams: map[string]interface{}{
"set_var": map[string]interface{}{
"attributes": map[string]interface{}{
"vars": map[string]interface{}{
"chatId": request.ChatID,
"messageId": createdMessage.ID,
"message": request.Message,
"history": messageHistory,
"contextTags": "This is a project with Typescript frontend and Go Backend",
"sourceWebsocketId": request.SourceWebsocketID,
"webhook_url": fmt.Sprintf("%s/hivechat/process", os.Getenv("HOST")),
},
},
},
},
}

if err := ch.sendToStakwork(stakworkPayload); err != nil {

createdMessage.Status = "error"
ch.db.UpdateChatMessage(&createdMessage)

w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(ChatResponse{
Success: false,
Message: fmt.Sprintf("Failed to process message: %v", err),
})
return
}

wsMessage := websocket.TicketMessage{
BroadcastType: "direct",
SourceSessionID: request.SourceWebsocketID,
Message: "Message sent",
Action: "process",
ChatMessage: createdMessage,
}

if err := websocket.WebsocketPool.SendTicketMessage(wsMessage); err != nil {
log.Printf("Failed to send websocket message: %v", err)
}

w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(ChatResponse{
Success: true,
Expand All @@ -119,6 +212,43 @@ func (ch *ChatHandler) SendMessage(w http.ResponseWriter, r *http.Request) {
})
}

func (ch *ChatHandler) sendToStakwork(payload StakworkChatPayload) error {
payloadJSON, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("error marshaling payload: %v", err)
}

req, err := http.NewRequest(
http.MethodPost,
"https://api.stakwork.com/api/v1/projects",
bytes.NewBuffer(payloadJSON),
)
if err != nil {
return fmt.Errorf("error creating request: %v", err)
}

apiKey := os.Getenv("SWWFKEY")
if apiKey == "" {
return fmt.Errorf("SWWFKEY environment variable not set")
}

req.Header.Set("Authorization", "Token token="+apiKey)
req.Header.Set("Content-Type", "application/json")

resp, err := ch.httpClient.Do(req)
if err != nil {
return fmt.Errorf("error sending request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("stakwork API error: %s", string(body))
}

return nil
}

func (ch *ChatHandler) GetChatHistory(w http.ResponseWriter, r *http.Request) {
chatID := chi.URLParam(r, "uuid")
if chatID == "" {
Expand Down
Loading

0 comments on commit 87be43e

Please sign in to comment.