diff --git a/config.json b/config.json index 4979c7f..fe02ea1 100644 --- a/config.json +++ b/config.json @@ -22,7 +22,10 @@ "name": "@aibrow/extension" }, - "defaultAiModel": "smollm2-1-7b-instruct-q4-k-m", + "defaultModels": { + "text": "smollm2-1-7b-instruct-q4-k-m", + "embedding": "nomic-embed-text-v1-5-q8-0" + }, "modelMinMachineScore": 0, "permissionRequiredForDefaultModel": false, "permissionAlwaysAllowedOrigins": ["https://aibrow.ai"], diff --git a/src/extension/background/APIHandler/AICoreModelHandler.ts b/src/extension/background/APIHandler/AICoreModelHandler.ts index 840f404..501ad7c 100644 --- a/src/extension/background/APIHandler/AICoreModelHandler.ts +++ b/src/extension/background/APIHandler/AICoreModelHandler.ts @@ -16,7 +16,7 @@ import { import APIHelper from './APIHelper' import AILlmSession from '../AI/AILlmSession' import { nanoid } from 'nanoid' -import { AICapabilityPromptType } from '#Shared/API/AI' +import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI' import { kModelPromptAborted } from '#Shared/Errors' class AICoreModelHandler { @@ -46,7 +46,7 @@ class AICoreModelHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.CoreModel) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.CoreModel) } /* **************************************************************************/ @@ -54,7 +54,7 @@ class AICoreModelHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.CoreModel, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.CoreModel, async ( manifest, payload, props @@ -75,7 +75,7 @@ class AICoreModelHandler { /* **************************************************************************/ #handlePrompt = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props @@ -102,7 +102,7 @@ class AICoreModelHandler { /* **************************************************************************/ #handleCountTokens = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/AIEmbeddingHandler.ts b/src/extension/background/APIHandler/AIEmbeddingHandler.ts index af2327d..3cbc8fe 100644 --- a/src/extension/background/APIHandler/AIEmbeddingHandler.ts +++ b/src/extension/background/APIHandler/AIEmbeddingHandler.ts @@ -16,7 +16,7 @@ import { import APIHelper from './APIHelper' import AILlmSession from '../AI/AILlmSession' import { nanoid } from 'nanoid' -import { AICapabilityPromptType } from '#Shared/API/AI' +import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI' class AIEmbeddingHandler { /* **************************************************************************/ @@ -44,7 +44,7 @@ class AIEmbeddingHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Embedding) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Embedding, AICapabilityPromptType.Embedding) } /* **************************************************************************/ @@ -52,7 +52,7 @@ class AIEmbeddingHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Embedding, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Embedding, AICapabilityPromptType.Embedding, async ( manifest, payload, props @@ -73,7 +73,7 @@ class AIEmbeddingHandler { /* **************************************************************************/ #handleGet = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Embedding, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/AILanguageModelHandler.ts b/src/extension/background/APIHandler/AILanguageModelHandler.ts index e1ec3cd..a14392a 100644 --- a/src/extension/background/APIHandler/AILanguageModelHandler.ts +++ b/src/extension/background/APIHandler/AILanguageModelHandler.ts @@ -22,7 +22,7 @@ import AILlmSession from '../AI/AILlmSession' import { AIModelManifest } from '#Shared/AIModelManifest' import { nanoid } from 'nanoid' import { Template } from '@huggingface/jinja' -import { AICapabilityPromptType, AIRootModelProps } from '#Shared/API/AI' +import { AICapabilityPromptType, AIRootModelProps, AIModelType } from '#Shared/API/AI' class AILanguageModelHandler { /* **************************************************************************/ @@ -104,7 +104,7 @@ class AILanguageModelHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.LanguageModel) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.LanguageModel) } /* **************************************************************************/ @@ -112,7 +112,7 @@ class AILanguageModelHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.LanguageModel, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.LanguageModel, async ( manifest, payload, props @@ -151,7 +151,7 @@ class AILanguageModelHandler { /* **************************************************************************/ #handleCountTokens = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props @@ -186,7 +186,7 @@ class AILanguageModelHandler { * @returns the stream response */ #handlePrompt = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/AIRewriterHandler.ts b/src/extension/background/APIHandler/AIRewriterHandler.ts index 9681b36..d2e8c9a 100644 --- a/src/extension/background/APIHandler/AIRewriterHandler.ts +++ b/src/extension/background/APIHandler/AIRewriterHandler.ts @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession' import { nanoid } from 'nanoid' import { AIModelManifest } from '#Shared/AIModelManifest' import { Template } from '@huggingface/jinja' -import { AICapabilityPromptType } from '#Shared/API/AI' +import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI' import { kModelPromptTypeNotSupported } from '#Shared/Errors' class AIRewriterHandler { @@ -49,7 +49,7 @@ class AIRewriterHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Rewriter) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Rewriter) } /* **************************************************************************/ @@ -57,7 +57,7 @@ class AIRewriterHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Rewriter, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Rewriter, async ( manifest, payload, props @@ -84,7 +84,7 @@ class AIRewriterHandler { /* **************************************************************************/ #handleRewrite = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/AISummarizerHandler.ts b/src/extension/background/APIHandler/AISummarizerHandler.ts index 7366499..c593e84 100644 --- a/src/extension/background/APIHandler/AISummarizerHandler.ts +++ b/src/extension/background/APIHandler/AISummarizerHandler.ts @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession' import { nanoid } from 'nanoid' import { AIModelManifest } from '#Shared/AIModelManifest' import { Template } from '@huggingface/jinja' -import { AICapabilityPromptType } from '#Shared/API/AI' +import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI' import { kModelPromptTypeNotSupported } from '#Shared/Errors' class AISummarizerHandler { @@ -49,7 +49,7 @@ class AISummarizerHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Summarizer) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Summarizer) } /* **************************************************************************/ @@ -57,7 +57,7 @@ class AISummarizerHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Summarizer, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Summarizer, async ( manifest, payload, props @@ -84,7 +84,7 @@ class AISummarizerHandler { /* **************************************************************************/ #handleSummarize = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/AIWriterHandler.ts b/src/extension/background/APIHandler/AIWriterHandler.ts index 7cd3cbe..3af2451 100644 --- a/src/extension/background/APIHandler/AIWriterHandler.ts +++ b/src/extension/background/APIHandler/AIWriterHandler.ts @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession' import { nanoid } from 'nanoid' import { AIModelManifest } from '#Shared/AIModelManifest' import { Template } from '@huggingface/jinja' -import { AICapabilityPromptType } from '#Shared/API/AI' +import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI' import { kModelPromptTypeNotSupported } from '#Shared/Errors' class AIWriterHandler { @@ -49,7 +49,7 @@ class AIWriterHandler { /* **************************************************************************/ #handleGetCapabilities = async (channel: IPCInflightChannel) => { - return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Writer) + return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Writer) } /* **************************************************************************/ @@ -57,7 +57,7 @@ class AIWriterHandler { /* **************************************************************************/ #handleCreate = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Writer, async ( + return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Writer, async ( manifest, payload, props @@ -84,7 +84,7 @@ class AIWriterHandler { /* **************************************************************************/ #handleWrite = async (channel: IPCInflightChannel) => { - return await APIHelper.handleStandardPromptPreflight(channel, async ( + return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async ( manifest, payload, props diff --git a/src/extension/background/APIHandler/APIHelper.ts b/src/extension/background/APIHandler/APIHelper.ts index a5743a6..87cc5e6 100644 --- a/src/extension/background/APIHandler/APIHelper.ts +++ b/src/extension/background/APIHandler/APIHelper.ts @@ -1,4 +1,4 @@ -import { clamp, getAIModelId, getEnum } from '#Shared/API/Untrusted/UntrustedParser' +import { clamp, getNonEmptyString, getEnum } from '#Shared/API/Untrusted/UntrustedParser' import { getDefaultModel, getDefaultModelEngine, @@ -24,7 +24,8 @@ import { AICapabilityAvailability, AIRootModelCapabilitiesData, AICapabilityPromptType, - AIRootModelProps + AIRootModelProps, + AIModelType } from '#Shared/API/AI' import { IPCInflightChannel } from '#Shared/IPC/IPCServer' import PermissionProvider from '../PermissionProvider' @@ -44,8 +45,8 @@ class APIHelper { * @param modelId: the id of the model * @returns the model id or the default */ - async getModelId (modelId: any): Promise { - return getAIModelId(modelId, await getDefaultModel()) + async getModelId (modelId: any, modelType: AIModelType): Promise { + return getNonEmptyString(modelId, await getDefaultModel(modelType)) } /** @@ -57,10 +58,6 @@ class APIHelper { return getEnum(gpuEngine, AICapabilityGpuEngine, await getDefaultModelEngine()) } - /* **************************************************************************/ - // MARK: Models - /* **************************************************************************/ - /** * Looks to see if a model supports a given prompt type * @param manifest: the manifest of the model @@ -146,16 +143,18 @@ class APIHelper { /** * Gets the standard capabilities data * @param channel: the incoming channel + * @param modelType: the type of model we're targeting * @param promptType: the type of prompt to check is available * @returns the response for the channel */ async handleGetStandardCapabilitiesData ( channel: IPCInflightChannel, + modelType: AIModelType, promptType: AICapabilityPromptType, configFn?: (manifest: AIModelManifest) => object ): Promise { return await this.captureCommonErrorsForResponse(async () => { - const modelId = await this.getModelId(channel.payload?.model) + const modelId = await this.getModelId(channel.payload?.model, modelType) // Permission checks & requests await PermissionProvider.requestModelPermission(channel, modelId) @@ -215,12 +214,14 @@ class APIHelper { /** * Handles a bunch of preflight tasks before a create call * @param channel: the incoming IPC channel + * @param modelType: the type of model we're targeting * @param promptType: the prompt type we should check support for * @param postflightFn: a function that can execute a after the preflight calls have been executed * @returns the reply from the postflight */ async handleStandardCreatePreflight ( channel: IPCInflightChannel, + modelType: AIModelType, promptType: AICapabilityPromptType, postflightFn: ( manifest: AIModelManifest, @@ -239,7 +240,7 @@ class APIHelper { const payload = new UntrustedParser(rawPayload) return await this.captureCommonErrorsForResponse(async () => { // Values with user-defined defaults - const modelId = await this.getModelId(rawPayload?.model) + const modelId = await this.getModelId(rawPayload?.model, modelType) const gpuEngine = await this.getGpuEngine(rawPayload?.gpuEngine) // Permission checks & requests @@ -290,11 +291,13 @@ class APIHelper { /** * Handles a bunch of preflight tasks before a prompt call * @param channel: the incoming IPC channel + * @param modelType: the type of model we're targeting * @param postflightFn: a function that can execute a after the preflight calls have been executed * @returns the reply from the postflight */ async handleStandardPromptPreflight ( channel: IPCInflightChannel, + modelType: AIModelType, postflightFn: ( manifest: AIModelManifest, payload: UntrustedParser, @@ -305,7 +308,7 @@ class APIHelper { const payload = new UntrustedParser(rawPayload) // Values with user-defined defaults - const modelId = await this.getModelId(rawPayload?.props?.model) + const modelId = await this.getModelId(rawPayload?.props?.model, modelType) const gpuEngine = await this.getGpuEngine(rawPayload?.props?.gpuEngine) // Permission checks & requests diff --git a/src/extension/background/PermissionProvider/PermissionProvider.ts b/src/extension/background/PermissionProvider/PermissionProvider.ts index da5e2cc..c0275a6 100644 --- a/src/extension/background/PermissionProvider/PermissionProvider.ts +++ b/src/extension/background/PermissionProvider/PermissionProvider.ts @@ -62,7 +62,7 @@ class PermissionProvider { // The default model might already have permission if (config.permissionRequiredForDefaultModel === false) { - if (modelId === undefined || modelId === config.defaultAiModel) { return true } + if (modelId === undefined || Object.values(config.defaultModels).includes(modelId)) { return true } } // Check if we're a pre-allowed origin diff --git a/src/extension/ui-options/index.html b/src/extension/ui-options/index.html index e0d932a..86a58f4 100644 --- a/src/extension/ui-options/index.html +++ b/src/extension/ui-options/index.html @@ -16,13 +16,21 @@
Settings
- -
The default model to use when a site doesn't specify one.
+
+ + +
+ The default embedding model to use when a site doesn't specify one. +
+