Skip to content

Commit

Permalink
增加错误信息输出返回
Browse files Browse the repository at this point in the history
  • Loading branch information
fruitbars committed Jul 13, 2024
1 parent ac2352a commit e1796a1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
32 changes: 23 additions & 9 deletions pkg/simple_client/simple_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/sashabaranov/go-openai"
"io"
Expand Down Expand Up @@ -40,17 +41,26 @@ func (c *SimpleClient) CreateChatCompletion(
httpReq, _ := http.NewRequest("POST", "/v1/chat/completions", bytes.NewBuffer(reqBody))
httpReq.Header.Set("Content-Type", "application/json")

// 使用httptest创建一个响应记录器
// 创建Gin的实例和配置路由
ginc := gin.New()
ginc.POST("/v1/chat/completions", func(ctx *gin.Context) {
handler.HandleOpenAIRequest(ctx, &request)
})

// 创建响应记录器
w := httptest.NewRecorder()

// 使用Gin的默认上下文
ginc, _ := gin.CreateTestContext(w)
// 使用ServeHTTP处理请求
ginc.ServeHTTP(w, httpReq)

// 解析响应

// 使用上面创建的HTTP请求
ginc.Request = httpReq
handler.HandleOpenAIRequest(ginc, &request)
if w.Code >= http.StatusBadRequest {
err = errors.New(string(w.Body.Bytes()))
return
}

json.Unmarshal(w.Body.Bytes(), &response)
err = json.Unmarshal(w.Body.Bytes(), &response)

return
}
Expand All @@ -63,10 +73,14 @@ func (c *SimpleClient) CreateChatCompletionStream(
// 创建io.Pipe连接
reader, writer := io.Pipe()

recorder := httptest.NewRecorder()

// 配置gin的上下文和请求
ginc := gin.New()
ginc.Use(func(ctx *gin.Context) {
ctx.Writer = NewCustomResponseWriter(writer)
crw := NewCustomResponseWriter(writer)
ctx.Writer = crw
ctx.Next()
})
ginc.POST("/v1/chat/completions", func(ctx *gin.Context) {
handler.HandleOpenAIRequest(ctx, &request)
Expand All @@ -78,7 +92,7 @@ func (c *SimpleClient) CreateChatCompletionStream(
requestData, _ := json.Marshal(request)
httpReq, _ := http.NewRequest("POST", "/v1/chat/completions", bytes.NewBuffer(requestData))
httpReq.Header.Set("Content-Type", "application/json")
ginc.ServeHTTP(httptest.NewRecorder(), httpReq)
ginc.ServeHTTP(recorder, httpReq)
}()

return NewSimpleChatCompletionStream(reader), nil
Expand Down
7 changes: 6 additions & 1 deletion pkg/simple_client/simple_stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ type CustomResponseWriter struct {
gin.ResponseWriter
writer io.Writer
status int
header http.Header
body *bytes.Buffer
}

func NewCustomResponseWriter(w io.Writer) *CustomResponseWriter {
return &CustomResponseWriter{
writer: w,
header: http.Header{},
body: bytes.NewBuffer([]byte{}),
}
}
Expand All @@ -33,6 +35,7 @@ func (crw *CustomResponseWriter) Write(data []byte) (int, error) {

func (crw *CustomResponseWriter) WriteHeader(statusCode int) {
crw.status = statusCode // Store status code
crw.writer.Write([]byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))))
}

func (crw *CustomResponseWriter) WriteString(s string) (int, error) {
Expand Down Expand Up @@ -100,5 +103,7 @@ func (scs *SimpleChatCompletionStream) Recv() (*openai.ChatCompletionStreamRespo
return &response, nil
}

return &response, fmt.Errorf("unexpected data format: %s", data)
errData, _ := io.ReadAll(scs.reader)

return &response, fmt.Errorf("unexpected data format: %s", string(errData))
}

0 comments on commit e1796a1

Please sign in to comment.