Skip to content

Commit

Permalink
增加对于异常消息content为object类型时候的兼容
Browse files Browse the repository at this point in the history
  • Loading branch information
fruitbars committed Jul 13, 2024
1 parent f190884 commit 4cbfa33
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 4 deletions.
44 changes: 40 additions & 4 deletions pkg/handler/openai_handler.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package handler

import (
"bytes"
"context"
"errors"
"github.com/gin-gonic/gin"
"github.com/sashabaranov/go-openai"
"go.uber.org/zap"
"io"
"net/http"
"simple-one-api/pkg/adapter"
"simple-one-api/pkg/config"
Expand Down Expand Up @@ -58,6 +60,22 @@ func LogRequestDetails(c *gin.Context) {
)
}

func getBodyDataCopy(c *gin.Context) ([]byte, error) {
body, err := c.GetRawData()
if err != nil {
//c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Unable to read request body"})
return nil, err
}

// 将原始数据保存到上下文
c.Set("rawData", body)

// 重新设置请求体,以便后续能够读取
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))

return body, nil
}

// OpenAIHandler handles POST requests on /v1/chat/completions path
func OpenAIHandler(c *gin.Context) {
if !validateRequestMethod(c, "POST") {
Expand All @@ -71,11 +89,29 @@ func OpenAIHandler(c *gin.Context) {
return
}

bodyData, getBodyerr := getBodyDataCopy(c)

var oaiReq openai.ChatCompletionRequest
if err := c.ShouldBindJSON(&oaiReq); err != nil {
mylog.Logger.Error(err.Error())
sendErrorResponse(c, http.StatusBadRequest, err.Error())
return
// 尝试重新解析请求体

if getBodyerr != nil {
mylog.Logger.Error(err.Error())
sendErrorResponse(c, http.StatusBadRequest, err.Error())
return
}

mylog.Logger.Debug(string(bodyData))
parsedReq, parseErr := mycommon.ParseChatCompletionRequest(bodyData)
if parseErr != nil {
mylog.Logger.Error("ParseChatCompletionRequest error: " + parseErr.Error())
sendErrorResponse(c, http.StatusBadRequest, parseErr.Error())
return
}

// 将重新解析的结果赋值给 oaiReq
oaiReq = *parsedReq
}

mycommon.LogChatCompletionRequest(oaiReq)
Expand Down Expand Up @@ -170,8 +206,8 @@ func HandleOpenAIRequest(c *gin.Context, oaiReq *openai.ChatCompletionRequest) {
// Log a message if the request could not obtain a token within the specified timeout period.
// 假设 logger 是一个已经配置好的 zap.Logger 实例
mylog.Logger.Error("Failed to obtain token within the specified time",
zap.Error(err), // 记录错误对象
zap.Int("timeout", timeout), // 假设 timeout 是 time.Duration 类型
zap.Error(err), // 记录错误对象
zap.Int("timeout", timeout), // 假设 timeout 是 time.Duration 类型
zap.Duration("elapsed", elapsed)) // 假设 elapsed 是 time.Duration 类型

} else if errors.Is(err, context.Canceled) {
Expand Down
68 changes: 68 additions & 0 deletions pkg/mycommon/oai_message_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,71 @@ func LogChatCompletionRequest(request openai.ChatCompletionRequest) {
mylog.Logger.Info("LogChatCompletionRequest", zap.String("request", string(jsonData)))

}

func ParseChatCompletionRequest(data []byte) (*openai.ChatCompletionRequest, error) {
var rawRequest struct {
Model string `json:"model"`
Messages []json.RawMessage `json:"messages"`
Temperature float32 `json:"temperature,omitempty"`
Stream bool `json:"stream,omitempty"`
}

if err := json.Unmarshal(data, &rawRequest); err != nil {
return nil, err
}

request := &openai.ChatCompletionRequest{
Model: rawRequest.Model,
Temperature: rawRequest.Temperature,
Stream: rawRequest.Stream,
}

for _, rawMsg := range rawRequest.Messages {
var rawMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
}
if err := json.Unmarshal(rawMsg, &rawMessage); err != nil {
return nil, err
}

message := openai.ChatCompletionMessage{
Role: rawMessage.Role,
}

// 尝试将 Content 解析为字符串
var contentStr string
if err := json.Unmarshal(rawMessage.Content, &contentStr); err == nil {
message.Content = contentStr
} else {
// 尝试将 Content 解析为对象,并取出其中的 text 字段
var contentObj struct {
Type string `json:"type"`
Text string `json:"text"`
}
if err := json.Unmarshal(rawMessage.Content, &contentObj); err == nil {
if contentObj.Type == string(openai.ChatMessagePartTypeText) {
message.Content = contentObj.Text
} else {
return nil, fmt.Errorf("unexpected content type: %s", contentObj.Type)
}
} else {
// 尝试将 Content 解析为数组并保留原样
var contentArr []openai.ChatMessagePart
if err := json.Unmarshal(rawMessage.Content, &contentArr); err == nil {
messageBytes, err := json.Marshal(contentArr)
if err != nil {
return nil, fmt.Errorf("failed to marshal content array")
}
message.Content = string(messageBytes)
} else {
return nil, fmt.Errorf("failed to unmarshal content")
}
}
}

request.Messages = append(request.Messages, message)
}

return request, nil
}

0 comments on commit 4cbfa33

Please sign in to comment.