Skip to content

Commit

Permalink
chore: update midjourney
Browse files Browse the repository at this point in the history
  • Loading branch information
soulteary committed Jun 8, 2023
1 parent fb42165 commit a4700d0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
coverage.out
.DS_Store
34 changes: 20 additions & 14 deletions connectors/mid-journey/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"github.com/soulteary/sparrow/internal/define"
)

func BuildMessage(conversationID string, parentMessageID string, messageID string, nextMessageID string, userPrompt string) []byte {
return []byte(fmt.Sprintf("%s\n%s\n%s\n%s\n%s", conversationID, parentMessageID, messageID, nextMessageID, userPrompt))
}

func PostMessage(conn *websocket.Conn, message []byte) error {
err := conn.WriteMessage(websocket.TextMessage, message)
if err != nil {
Expand Down Expand Up @@ -41,7 +45,8 @@ func CreateReceiver(done *chan bool, conn *websocket.Conn, brokerPool *eb.Broker
}
}

func ParseMessage(payload []byte) (parentMessageID string, conversationID string, response string, done bool, err error) {
// conversationID string, parentMessageID string, messageID string
func ParseMessage(payload []byte) (conversationID string, parentMessageID string, messageID string, nextMessageID string, response string, done bool, err error) {
input := string(payload)
texts := strings.Split(input, "\n")

Expand All @@ -50,19 +55,22 @@ func ParseMessage(payload []byte) (parentMessageID string, conversationID string
fmt.Println(input)
}

if len(texts) < 2 {
if len(texts) < 5 {
fmt.Println(texts)
return "", "", "", false, fmt.Errorf("the format of the message sent back by the server is incorrect")
return "", "", "", "", "", false, fmt.Errorf("the format of the message sent back by the server is incorrect")
}

parentMessageID = strings.TrimSpace(texts[0])
conversationID = ""
ControlText := strings.TrimSpace(texts[1])
response = strings.Join(texts[2:], "\n")
ControlText := strings.TrimSpace(texts[0])
conversationID = strings.TrimSpace(texts[1])
parentMessageID = strings.TrimSpace(texts[2])
messageID = strings.TrimSpace(texts[3])
nextMessageID = strings.TrimSpace(texts[4])

response = strings.Join(texts[5:], "\n")
if ControlText == "[MESSAGE:CLOSE]" {
return parentMessageID, conversationID, response, true, nil
return conversationID, parentMessageID, messageID, nextMessageID, response, true, nil
}
return parentMessageID, conversationID, response, false, nil
return conversationID, parentMessageID, messageID, nextMessageID, response, false, nil
}

func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) {
Expand All @@ -72,7 +80,7 @@ func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) {
return
}

parentMessageID, conversationID, response, done, err := ParseMessage(p)
conversationID, parentMessageID, messageID, nextMessageID, response, done, err := ParseMessage(p)
if err != nil {
fmt.Println("Error parsing message", err)
return
Expand All @@ -87,12 +95,10 @@ func FnReceiver() func(err error, p []byte, brokerPool *eb.BrokersPool) {
return
}

nextMessageID := define.GenerateUUID()
// TODO bind message id
if !done {
sr.StreamBuilder(user, conversationID, parentMessageID, conversationID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_CONTINUE)
sr.StreamBuilder(user, conversationID, parentMessageID, messageID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_CONTINUE)
} else {
sr.StreamBuilder(user, conversationID, parentMessageID, conversationID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_DONE)
sr.StreamBuilder(user, conversationID, parentMessageID, messageID, nextMessageID, modelSlug, broker, response, sr.MSG_STATUS_DONE)
}
}
}
2 changes: 1 addition & 1 deletion internal/api/conversation/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func CreateConversation(brokerPool *eb.BrokersPool) func(c *gin.Context) {

switch userModel {
case datatypes.MODEL_MIDJOURNEY.Slug:
message := []byte(fmt.Sprintf("%s\n%s", parentMessageID, userPrompt))
message := midjourney.BuildMessage(data.ConversationID, parentMessageID, messageID, nextMessageID, userPrompt)
midjourney.PostMessage(midjourney.GetConn(), message)
broker.Serve(c, messageChan)
return
Expand Down

0 comments on commit a4700d0

Please sign in to comment.