diff --git a/.gitignore b/.gitignore index 2d83068..3afd804 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ coverage.out +.DS_Store \ No newline at end of file diff --git a/connectors/mid-journey/message.go b/connectors/mid-journey/message.go index 214f42e..27603cd 100644 --- a/connectors/mid-journey/message.go +++ b/connectors/mid-journey/message.go @@ -11,6 +11,10 @@ import ( "github.com/soulteary/sparrow/internal/define" ) +func BuildMessage(conversationID string, parentMessageID string, messageID string, nextMessageID string, userPrompt string) []byte { + return []byte(fmt.Sprintf("%s\n%s\n%s\n%s\n%s", conversationID, parentMessageID, messageID, nextMessageID, userPrompt)) +} + func PostMessage(conn *websocket.Conn, message []byte) error { err := conn.WriteMessage(websocket.TextMessage, message) if err != nil { @@ -41,7 +45,8 @@ func CreateReceiver(done *chan bool, conn *websocket.Conn, brokerPool *eb.Broker } } -func ParseMessage(payload []byte) (parentMessageID string, conversationID string, response string, done bool, err error) { +// conversationID string, parentMessageID string, messageID string +func ParseMessage(payload []byte) (conversationID string, parentMessageID string, messageID string, nextMessageID string, response string, done bool, err error) { input := string(payload) texts := strings.Split(input, "\n") @@ -50,19 +55,22 @@ func ParseMessage(payload []byte) (parentMessageID string, conversationID string fmt.Println(input) } - if len(texts) < 2 { + if len(texts) < 5 { fmt.Println(texts) - return "", "", "", false, fmt.Errorf("the format of the message sent back by the server is incorrect") + return "", "", "", "", "", false, fmt.Errorf("the format of the message sent back by the server is incorrect") } - parentMessageID = strings.TrimSpace(texts[0]) - conversationID = "" - ControlText := strings.TrimSpace(texts[1]) - response = strings.Join(texts[2:], "\n") + ControlText := strings.TrimSpace(texts[0]) + conversationID = strings.TrimSpace(texts[1]) + parentMessageID = strings.TrimSpace(texts[2]) + messageID = strings.TrimSpace(texts[3]) + nextMessageID = strings.TrimSpace(texts[4]) + + response = strings.Join(texts[5:], "\n") if ControlText == "[MESSAGE:CLOSE]" { - return parentMessageID, conversationID, response, true, nil + return conversationID, parentMessageID, messageID, nextMessageID, response, true, nil } - return parentMessageID, conversationID, response, false, nil + return conversationID, parentMessageID, messageID, nextMessageID, response, false, nil } func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) { @@ -72,7 +80,7 @@ func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) { return } - parentMessageID, conversationID, response, done, err := ParseMessage(p) + conversationID, parentMessageID, messageID, nextMessageID, response, done, err := ParseMessage(p) if err != nil { fmt.Println("Error parsing message", err) return @@ -87,12 +95,10 @@ func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) { return } - nextMessageID := define.GenerateUUID() - // TODO bind message id if !done { - sr.StreamBuilder(user, conversationID, parentMessageID, conversationID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_CONTINUE) + sr.StreamBuilder(user, conversationID, parentMessageID, messageID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_CONTINUE) } else { - sr.StreamBuilder(user, conversationID, parentMessageID, conversationID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_DONE) + sr.StreamBuilder(user, conversationID, parentMessageID, messageID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_DONE) } } } diff --git a/internal/api/conversation/conversation.go b/internal/api/conversation/conversation.go index 508cd24..5f2c3d9 100644 --- a/internal/api/conversation/conversation.go +++ b/internal/api/conversation/conversation.go @@ -117,7 +117,7 @@ func CreateConversation(brokerPool *eb.BrokersPool) func(c *gin.Context) { switch userModel { case datatypes.MODEL_MIDJOURNEY.Slug: - message := []byte(fmt.Sprintf("%s\n%s", parentMessageID, userPrompt)) + message := midjourney.BuildMessage(data.ConversationID, parentMessageID, messageID, nextMessageID, userPrompt) midjourney.PostMessage(midjourney.GetConn(), message) broker.Serve(c, messageChan) return