Skip to content

Commit

Permalink
feat: custom local ai support
Browse files Browse the repository at this point in the history
  • Loading branch information
BigJk committed Feb 3, 2024
1 parent 247aa76 commit 39bf024
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 10 deletions.
2 changes: 2 additions & 0 deletions frontend/src/js/types/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Settings = {
aiModel: string;
aiContextWindow: number;
aiMaxTokens: number;
aiUrl: string;
};

export function createEmptySettings(): Settings {
Expand Down Expand Up @@ -53,6 +54,7 @@ export function createEmptySettings(): Settings {
aiModel: '',
aiContextWindow: 0,
aiMaxTokens: 0,
aiUrl: '',
};
}

Expand Down
38 changes: 37 additions & 1 deletion frontend/src/js/ui/views/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import IconButton from 'js/ui/spectre/icon-button';
import Select from 'js/ui/spectre/select';
import HorizontalProperty from 'js/ui/components/horizontal-property';
import Button from 'js/ui/spectre/button';
import Input from 'js/ui/spectre/input';

import { error, neutral, success } from 'js/ui/toast';
import { clearPortal, setPortal } from 'js/ui/portal';
Expand Down Expand Up @@ -331,7 +332,7 @@ export default (): m.Component => {
},
}),
),
aiModels.length === 0
aiModels.length === 0 || settingsCopy.aiProvider.startsWith('Custom')
? null
: m(
HorizontalProperty,
Expand All @@ -349,6 +350,41 @@ export default (): m.Component => {
},
}),
),
!settingsCopy.aiProvider.startsWith('Custom')
? null
: [
m(
HorizontalProperty,
{
label: 'Custom Model',
description:
'The AI model to use for your custom AI. Depending on the provider this can be a model name, file location or a model ID. Can be left blank if not applicable.',
centered: true,
bottomBorder: true,
},
m(Input, {
value: settingsCopy.aiModel,
onChange: (val) => {
settingsCopy = { ...settingsCopy, aiModel: val };
},
}),
),
m(
HorizontalProperty,
{
label: 'Custom URL',
description: 'The URL of your custom OpenAI compatible API (e.g. http://localhost:1234)',
centered: true,
bottomBorder: true,
},
m(Input, {
value: settingsCopy.aiUrl,
onChange: (val) => {
settingsCopy = { ...settingsCopy, aiUrl: val };
},
}),
),
],
]),
),
),
Expand Down
30 changes: 21 additions & 9 deletions rpc/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,23 @@ func shortHash(text string) string {
return sha[:8]
}

var supportedProviders = []string{"OpenRouter.ai", "OpenAI"}
var supportedProviders = []string{"OpenRouter.ai", "OpenAI", "Custom (e.g. Local)"}

func providerToEndpoint(provider string) (string, error) {
func providerToEndpoint(db database.Database, provider string) (string, error) {
switch provider {
case "OpenRouter.ai":
return "https://openrouter.ai/api", nil
case "OpenAI":
return "https://api.openai.com", nil
case "Custom (e.g. Local)":
settings, err := db.GetSettings()
if err != nil {
return "", err
}
if len(settings.AIURL) == 0 {
return "", errors.New("custom AI provider URL is not set")
}
return settings.AIURL, nil
default:
return "", errors.New("unknown provider")
}
Expand Down Expand Up @@ -85,7 +94,7 @@ func RegisterAI(route *echo.Group, db database.Database) {
return "", errors.New("AI is not enabled")
}

endpoint, err := providerToEndpoint(settings.AIProvider)
endpoint, err := providerToEndpoint(db, settings.AIProvider)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -191,14 +200,17 @@ func RegisterAI(route *echo.Group, db database.Database) {
})))

route.POST("/aiModels", echo.WrapHandler(nra.MustBind(func(provider string) ([]string, error) {
endpoint, err := providerToEndpoint(provider)
if err != nil {
return nil, err
}

// TODO: dynamically fetch models
if provider == "OpenAI" {
switch provider {
case "OpenAI":
return []string{"gpt-3.5-turbo", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-4-1106-preview", "gpt-4", "gpt-4-32k"}, nil
case "Custom (e.g. Local)":
return []string{"Custom"}, nil
}

endpoint, err := providerToEndpoint(db, provider)
if err != nil {
return nil, err
}

resp, err := http.Get(endpoint + "/v1/models")
Expand Down
1 change: 1 addition & 0 deletions settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ type Settings struct {
AIProvider string `json:"aiProvider"`
AIContextWindow int `json:"aiContextWindow"`
AIMaxTokens int `json:"aiMaxTokens"`
AIURL string `json:"aiUrl"`
}

0 comments on commit 39bf024

Please sign in to comment.