Skip to content

Commit

Permalink
Merge branch 'songquanpeng' into sync_upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Dec 25, 2023
2 parents be61388 + f44fbe3 commit 47b72b8
Show file tree
Hide file tree
Showing 24 changed files with 251 additions and 49 deletions.
35 changes: 35 additions & 0 deletions common/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,22 @@ import (
_ "golang.org/x/image/webp"
)

func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}

func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
Expand All @@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil
}

func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}

var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
Expand Down
17 changes: 17 additions & 0 deletions common/image/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,20 @@ func TestGetImageSize(t *testing.T) {
})
}
}

func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}
4 changes: 4 additions & 0 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
Expand Down Expand Up @@ -115,6 +116,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}

func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
Expand Down
9 changes: 9 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,15 @@ func init() {
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
Expand Down Expand Up @@ -59,7 +59,7 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
Expand All @@ -166,8 +166,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
Expand Down
2 changes: 2 additions & 0 deletions middleware/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"runtime/debug"
)

func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
Expand Down
34 changes: 26 additions & 8 deletions providers/ali/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: aliResponse.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: aliResponse.Model,
Choices: []types.ChatCompletionChoice{choice},
Usage: &types.Usage{
PromptTokens: aliResponse.Usage.InputTokens,
Expand All @@ -50,6 +51,8 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
return
}

const AliEnableSearchModelSuffix = "-internet"

// 获取聊天请求体
func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
Expand All @@ -60,11 +63,23 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
Role: strings.ToLower(message.Role),
})
}

enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}

return &AliChatRequest{
Model: request.Model,
Model: aliModel,
Input: AliInput{
Messages: messages,
},
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
}

Expand All @@ -86,7 +101,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
}

if request.Stream {
usage, errWithCode = p.sendStreamRequest(req)
usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
Expand All @@ -100,7 +115,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
}

} else {
aliResponse := &AliChatResponse{}
aliResponse := &AliChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, aliResponse, false)
if errWithCode != nil {
return
Expand Down Expand Up @@ -128,14 +145,14 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty
ID: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Model: aliResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}

// 发送流请求
func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()

usage = &types.Usage{}
Expand Down Expand Up @@ -181,7 +198,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
lastResponseText := ""
// lastResponseText := ""
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
Expand All @@ -196,9 +213,10 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
aliResponse.Model = model
response := p.streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
// response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
// lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
Expand Down
10 changes: 6 additions & 4 deletions providers/ali/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ type AliInput struct {
}

type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}

type AliChatRequest struct {
Expand All @@ -43,6 +44,7 @@ type AliOutput struct {
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
Model string `json:"model,omitempty"`
AliError
}

Expand Down
11 changes: 7 additions & 4 deletions providers/baidu/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
}

if request.Stream {
usage, errWithCode = p.sendStreamRequest(req)
usage, errWithCode = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}

} else {
baiduChatRequest := &BaiduChatResponse{}
baiduChatRequest := &BaiduChatResponse{
Model: request.Model,
}
errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil {
return
Expand All @@ -117,13 +119,13 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
ID: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Model: baiduResponse.Model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
return &response
}

func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()

usage = &types.Usage{}
Expand Down Expand Up @@ -180,6 +182,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
baiduResponse.Model = model
response := p.streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions providers/baidu/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type BaiduChatResponse struct {
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"`
Model string `json:"model,omitempty"`
BaiduError
}

Expand Down
1 change: 1 addition & 0 deletions providers/claude/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (Open
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Model: claudeResponse.Model,
}

completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
Expand Down
3 changes: 2 additions & 1 deletion providers/gemini/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
version = p.Context.GetString("api_version")
}

return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key"))
return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)

}

// 获取请求头
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["x-goog-api-key"] = p.Context.GetString("api_key")

return headers
}
Loading

0 comments on commit 47b72b8

Please sign in to comment.