Skip to content

Commit

Permalink
Add support for a default embedding model
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas101 committed Nov 13, 2024
1 parent c77a8c8 commit 963d716
Show file tree
Hide file tree
Showing 18 changed files with 119 additions and 84 deletions.
5 changes: 4 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
"name": "@aibrow/extension"
},

"defaultAiModel": "smollm2-1-7b-instruct-q4-k-m",
"defaultModels": {
"text": "smollm2-1-7b-instruct-q4-k-m",
"embedding": "nomic-embed-text-v1-5-q8-0"
},
"modelMinMachineScore": 0,
"permissionRequiredForDefaultModel": false,
"permissionAlwaysAllowedOrigins": ["https://aibrow.ai"],
Expand Down
10 changes: 5 additions & 5 deletions src/extension/background/APIHandler/AICoreModelHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
import APIHelper from './APIHelper'
import AILlmSession from '../AI/AILlmSession'
import { nanoid } from 'nanoid'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI'
import { kModelPromptAborted } from '#Shared/Errors'

class AICoreModelHandler {
Expand Down Expand Up @@ -46,15 +46,15 @@ class AICoreModelHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.CoreModel)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.CoreModel)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.CoreModel, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.CoreModel, async (
manifest,
payload,
props
Expand All @@ -75,7 +75,7 @@ class AICoreModelHandler {
/* **************************************************************************/

#handlePrompt = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand All @@ -102,7 +102,7 @@ class AICoreModelHandler {
/* **************************************************************************/

#handleCountTokens = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down
8 changes: 4 additions & 4 deletions src/extension/background/APIHandler/AIEmbeddingHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
import APIHelper from './APIHelper'
import AILlmSession from '../AI/AILlmSession'
import { nanoid } from 'nanoid'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI'

class AIEmbeddingHandler {
/* **************************************************************************/
Expand Down Expand Up @@ -44,15 +44,15 @@ class AIEmbeddingHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Embedding)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Embedding, AICapabilityPromptType.Embedding)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Embedding, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Embedding, AICapabilityPromptType.Embedding, async (
manifest,
payload,
props
Expand All @@ -73,7 +73,7 @@ class AIEmbeddingHandler {
/* **************************************************************************/

#handleGet = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Embedding, async (
manifest,
payload,
props
Expand Down
10 changes: 5 additions & 5 deletions src/extension/background/APIHandler/AILanguageModelHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import AILlmSession from '../AI/AILlmSession'
import { AIModelManifest } from '#Shared/AIModelManifest'
import { nanoid } from 'nanoid'
import { Template } from '@huggingface/jinja'
import { AICapabilityPromptType, AIRootModelProps } from '#Shared/API/AI'
import { AICapabilityPromptType, AIRootModelProps, AIModelType } from '#Shared/API/AI'

class AILanguageModelHandler {
/* **************************************************************************/
Expand Down Expand Up @@ -104,15 +104,15 @@ class AILanguageModelHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.LanguageModel)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.LanguageModel)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.LanguageModel, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.LanguageModel, async (
manifest,
payload,
props
Expand Down Expand Up @@ -151,7 +151,7 @@ class AILanguageModelHandler {
/* **************************************************************************/

#handleCountTokens = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down Expand Up @@ -186,7 +186,7 @@ class AILanguageModelHandler {
* @returns the stream response
*/
#handlePrompt = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down
8 changes: 4 additions & 4 deletions src/extension/background/APIHandler/AIRewriterHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession'
import { nanoid } from 'nanoid'
import { AIModelManifest } from '#Shared/AIModelManifest'
import { Template } from '@huggingface/jinja'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI'
import { kModelPromptTypeNotSupported } from '#Shared/Errors'

class AIRewriterHandler {
Expand Down Expand Up @@ -49,15 +49,15 @@ class AIRewriterHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Rewriter)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Rewriter)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Rewriter, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Rewriter, async (
manifest,
payload,
props
Expand All @@ -84,7 +84,7 @@ class AIRewriterHandler {
/* **************************************************************************/

#handleRewrite = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down
8 changes: 4 additions & 4 deletions src/extension/background/APIHandler/AISummarizerHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession'
import { nanoid } from 'nanoid'
import { AIModelManifest } from '#Shared/AIModelManifest'
import { Template } from '@huggingface/jinja'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI'
import { kModelPromptTypeNotSupported } from '#Shared/Errors'

class AISummarizerHandler {
Expand Down Expand Up @@ -49,15 +49,15 @@ class AISummarizerHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Summarizer)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Summarizer)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Summarizer, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Summarizer, async (
manifest,
payload,
props
Expand All @@ -84,7 +84,7 @@ class AISummarizerHandler {
/* **************************************************************************/

#handleSummarize = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down
8 changes: 4 additions & 4 deletions src/extension/background/APIHandler/AIWriterHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import AILlmSession from '../AI/AILlmSession'
import { nanoid } from 'nanoid'
import { AIModelManifest } from '#Shared/AIModelManifest'
import { Template } from '@huggingface/jinja'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIModelType } from '#Shared/API/AI'
import { kModelPromptTypeNotSupported } from '#Shared/Errors'

class AIWriterHandler {
Expand Down Expand Up @@ -49,15 +49,15 @@ class AIWriterHandler {
/* **************************************************************************/

#handleGetCapabilities = async (channel: IPCInflightChannel) => {
return APIHelper.handleGetStandardCapabilitiesData(channel, AICapabilityPromptType.Writer)
return APIHelper.handleGetStandardCapabilitiesData(channel, AIModelType.Text, AICapabilityPromptType.Writer)
}

/* **************************************************************************/
// MARK: Handlers: Lifecycle
/* **************************************************************************/

#handleCreate = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardCreatePreflight(channel, AICapabilityPromptType.Writer, async (
return await APIHelper.handleStandardCreatePreflight(channel, AIModelType.Text, AICapabilityPromptType.Writer, async (
manifest,
payload,
props
Expand All @@ -84,7 +84,7 @@ class AIWriterHandler {
/* **************************************************************************/

#handleWrite = async (channel: IPCInflightChannel) => {
return await APIHelper.handleStandardPromptPreflight(channel, async (
return await APIHelper.handleStandardPromptPreflight(channel, AIModelType.Text, async (
manifest,
payload,
props
Expand Down
25 changes: 14 additions & 11 deletions src/extension/background/APIHandler/APIHelper.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { clamp, getAIModelId, getEnum } from '#Shared/API/Untrusted/UntrustedParser'
import { clamp, getNonEmptyString, getEnum } from '#Shared/API/Untrusted/UntrustedParser'
import {
getDefaultModel,
getDefaultModelEngine,
Expand All @@ -24,7 +24,8 @@ import {
AICapabilityAvailability,
AIRootModelCapabilitiesData,
AICapabilityPromptType,
AIRootModelProps
AIRootModelProps,
AIModelType
} from '#Shared/API/AI'
import { IPCInflightChannel } from '#Shared/IPC/IPCServer'
import PermissionProvider from '../PermissionProvider'
Expand All @@ -44,8 +45,8 @@ class APIHelper {
* @param modelId: the id of the model
* @returns the model id or the default
*/
async getModelId (modelId: any): Promise<string> {
return getAIModelId(modelId, await getDefaultModel())
async getModelId (modelId: any, modelType: AIModelType): Promise<string> {
return getNonEmptyString(modelId, await getDefaultModel(modelType))
}

/**
Expand All @@ -57,10 +58,6 @@ class APIHelper {
return getEnum(gpuEngine, AICapabilityGpuEngine, await getDefaultModelEngine())
}

/* **************************************************************************/
// MARK: Models
/* **************************************************************************/

/**
* Looks to see if a model supports a given prompt type
* @param manifest: the manifest of the model
Expand Down Expand Up @@ -146,16 +143,18 @@ class APIHelper {
/**
* Gets the standard capabilities data
* @param channel: the incoming channel
* @param modelType: the type of model we're targeting
* @param promptType: the type of prompt to check is available
* @returns the response for the channel
*/
async handleGetStandardCapabilitiesData (
channel: IPCInflightChannel,
modelType: AIModelType,
promptType: AICapabilityPromptType,
configFn?: (manifest: AIModelManifest) => object
): Promise<AIRootModelCapabilitiesData> {
return await this.captureCommonErrorsForResponse(async () => {
const modelId = await this.getModelId(channel.payload?.model)
const modelId = await this.getModelId(channel.payload?.model, modelType)

// Permission checks & requests
await PermissionProvider.requestModelPermission(channel, modelId)
Expand Down Expand Up @@ -215,12 +214,14 @@ class APIHelper {
/**
* Handles a bunch of preflight tasks before a create call
* @param channel: the incoming IPC channel
* @param modelType: the type of model we're targeting
* @param promptType: the prompt type we should check support for
* @param postflightFn: a function that can execute a after the preflight calls have been executed
* @returns the reply from the postflight
*/
async handleStandardCreatePreflight (
channel: IPCInflightChannel,
modelType: AIModelType,
promptType: AICapabilityPromptType,
postflightFn: (
manifest: AIModelManifest,
Expand All @@ -239,7 +240,7 @@ class APIHelper {
const payload = new UntrustedParser(rawPayload)
return await this.captureCommonErrorsForResponse(async () => {
// Values with user-defined defaults
const modelId = await this.getModelId(rawPayload?.model)
const modelId = await this.getModelId(rawPayload?.model, modelType)
const gpuEngine = await this.getGpuEngine(rawPayload?.gpuEngine)

// Permission checks & requests
Expand Down Expand Up @@ -290,11 +291,13 @@ class APIHelper {
/**
* Handles a bunch of preflight tasks before a prompt call
* @param channel: the incoming IPC channel
* @param modelType: the type of model we're targeting
* @param postflightFn: a function that can execute a after the preflight calls have been executed
* @returns the reply from the postflight
*/
async handleStandardPromptPreflight (
channel: IPCInflightChannel,
modelType: AIModelType,
postflightFn: (
manifest: AIModelManifest,
payload: UntrustedParser,
Expand All @@ -305,7 +308,7 @@ class APIHelper {
const payload = new UntrustedParser(rawPayload)

// Values with user-defined defaults
const modelId = await this.getModelId(rawPayload?.props?.model)
const modelId = await this.getModelId(rawPayload?.props?.model, modelType)
const gpuEngine = await this.getGpuEngine(rawPayload?.props?.gpuEngine)

// Permission checks & requests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PermissionProvider {

// The default model might already have permission
if (config.permissionRequiredForDefaultModel === false) {
if (modelId === undefined || modelId === config.defaultAiModel) { return true }
if (modelId === undefined || Object.values(config.defaultModels).includes(modelId)) { return true }
}

// Check if we're a pre-allowed origin
Expand Down
12 changes: 10 additions & 2 deletions src/extension/ui-options/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@
<div class="card-body">
<h5 class="card-title">Settings</h5>
<div class="mb-3">
<label for="opt-model-default" class="form-label">Default model</label>
<select class="form-control" id="opt-model-default">
<label for="opt-model-text-default" class="form-label">Default text model</label>
<select class="form-control" id="opt-model-text-default">
</select>
<div class="form-text">
The default model to use when a site doesn't specify one.
</div>
</div>
<div class="mb-3">
<label for="opt-model-embedding-default" class="form-label">Default embedding model</label>
<select class="form-control" id="opt-model-embedding-default">
</select>
<div class="form-text">
The default embedding model to use when a site doesn't specify one.
</div>
</div>
<div class="mb-3">
<label for="opt-engine-default" class="form-label">Default engine</label>
<select class="form-control" id="opt-engine-default">
Expand Down
Loading

0 comments on commit 963d716

Please sign in to comment.