From ac677e4d5504eb12c9820494233a19325b398628 Mon Sep 17 00:00:00 2001
From: pelikhan
Date: Thu, 5 Dec 2024 08:32:57 -0800
Subject: [PATCH 1/5] refactor model to support aliases

---
 packages/cli/src/info.ts                     |  8 +++---
 packages/cli/src/nodehost.ts                 | 14 +++++------
 packages/cli/src/run.ts                      | 16 ++++++------
 packages/cli/src/test.ts                     | 16 ++++++------
 packages/core/src/chat.ts                    | 19 +++++++-------
 packages/core/src/connection.ts              | 26 +++++++++++---------
 packages/core/src/encoders.ts                |  2 +-
 packages/core/src/expander.ts                |  2 +-
 packages/core/src/git.ts                     |  4 +--
 packages/core/src/globals.ts                 |  6 ++---
 packages/core/src/host.ts                    | 17 +++++++------
 packages/core/src/models.ts                  |  8 +++---
 packages/core/src/promptcontext.ts           |  7 +++---
 packages/core/src/test.ts                    | 19 ++++----------
 packages/core/src/testhost.ts                | 15 +++++------
 packages/core/src/types/prompt_template.d.ts |  7 ++----
 packages/core/src/vectorsearch.ts            |  5 ++--
 packages/vscode/src/vshost.ts                | 15 +++++------
 18 files changed, 94 insertions(+), 112 deletions(-)

diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts
index 78009a9908..962e7d3339 100644
--- a/packages/cli/src/info.ts
+++ b/packages/cli/src/info.ts
@@ -84,20 +84,20 @@ export async function envInfo(
 /**
  * Resolves connection information for script templates by deduplicating model options.
- * @param templates - Array of model connection options to resolve.
+ * @param scripts - Array of model connection options to resolve.
  * @param options - Configuration options, including whether to show tokens.
  * @returns A promise that resolves to an array of model connection information.
  */
 async function resolveScriptsConnectionInfo(
-    templates: ModelConnectionOptions[],
+    scripts: ModelConnectionOptions[],
     options?: { token?: boolean }
 ): Promise<ModelConnectionInfo[]> {
     const models: Record<string, ModelConnectionOptions> = {}
     // Deduplicate model connection options
-    for (const template of templates) {
+    for (const script of scripts) {
         const conn: ModelConnectionOptions = {
-            model: template.model ?? host.defaultModelOptions.model,
+            model: script.model ??
host.modelAliases.large.model, } const key = JSON.stringify(conn) if (!models[key]) models[key] = conn diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts index c0a235cb99..aab1454417 100644 --- a/packages/cli/src/nodehost.ts +++ b/packages/cli/src/nodehost.ts @@ -44,6 +44,7 @@ import { setRuntimeHost, ResponseStatus, AzureTokenResolver, + ModelConfigurations, } from "../../core/src/host" import { AbortSignalOptions, TraceOptions } from "../../core/src/trace" import { logError, logVerbose } from "../../core/src/util" @@ -139,14 +140,11 @@ export class NodeHost implements RuntimeHost { readonly workspace = createFileSystem() readonly containers = new DockerManager() readonly browsers = new BrowserManager() - readonly defaultModelOptions = { - model: DEFAULT_MODEL, - smallModel: DEFAULT_SMALL_MODEL, - visionModel: DEFAULT_VISION_MODEL, - temperature: DEFAULT_TEMPERATURE, - } - readonly defaultEmbeddingsModelOptions = { - embeddingsModel: DEFAULT_EMBEDDINGS_MODEL, + readonly modelAliases: ModelConfigurations = { + large: { model: DEFAULT_MODEL }, + small: { model: DEFAULT_SMALL_MODEL }, + vision: { model: DEFAULT_VISION_MODEL }, + embeddings: { model: DEFAULT_EMBEDDINGS_MODEL }, } readonly userInputQueue = new PLimitPromiseQueue(1) readonly azureToken: AzureTokenResolver diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index ca59392edb..45ace4ce93 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -208,11 +208,10 @@ export async function runScript( const topLogprobs = normalizeInt(options.topLogprobs) if (options.json || options.yaml) overrideStdoutWithStdErr() - if (options.model) host.defaultModelOptions.model = options.model - if (options.smallModel) - host.defaultModelOptions.smallModel = options.smallModel + if (options.model) host.modelAliases.large.model = options.model + if (options.smallModel) host.modelAliases.small.model = options.smallModel if (options.visionModel) - host.defaultModelOptions.visionModel = options.visionModel + host.modelAliases.vision.model = options.visionModel const fail = (msg: string, exitCode: number, url?: string) => { logError(url ? `${msg} (see ${url})` : msg) @@ -220,9 +219,9 @@ export async function runScript( } logInfo(`genaiscript: ${scriptId}`) - logVerbose(` large : ${host.defaultModelOptions.model}`) - logVerbose(` small : ${host.defaultModelOptions.smallModel}`) - logVerbose(` vision: ${host.defaultModelOptions.visionModel}`) + Object.entries(host.modelAliases).forEach(([key, host]) => + logVerbose(` ${key}: ${host.model}`) + ) if (out) { if (removeOut) await emptyDir(out) @@ -375,8 +374,7 @@ export async function runScript( maxDataRepairs, model: info.model, embeddingsModel: - options.embeddingsModel ?? - host.defaultEmbeddingsModelOptions.embeddingsModel, + options.embeddingsModel ?? 
host.modelAliases.embeddings.model, retry, retryDelay, maxDelay, diff --git a/packages/cli/src/test.ts b/packages/cli/src/test.ts index 39d0850876..ab86d492f4 100644 --- a/packages/cli/src/test.ts +++ b/packages/cli/src/test.ts @@ -108,14 +108,12 @@ export async function runPromptScriptTests( testDelay?: string } ): Promise { - if (options.model) host.defaultModelOptions.model = options.model - if (options.smallModel) - host.defaultModelOptions.smallModel = options.smallModel + if (options.model) host.modelAliases.large.model = options.model + if (options.smallModel) host.modelAliases.small.model = options.smallModel if (options.visionModel) - host.defaultModelOptions.visionModel = options.visionModel - - logVerbose( - `model: ${host.defaultModelOptions.model}, small model: ${host.defaultModelOptions.smallModel}, vision model: ${host.defaultModelOptions.visionModel}` + host.modelAliases.vision.model = options.visionModel + Object.entries(host.modelAliases).forEach(([key, host]) => + logVerbose(` ${key}: ${host.model}`) ) const scripts = await listTests({ ids, ...(options || {}) }) @@ -147,12 +145,12 @@ export async function runPromptScriptTests( : script.filename.replace(GENAI_ANY_REGEX, ".promptfoo.yaml") logInfo(` ${fn}`) const { info: chatInfo } = await resolveModelConnectionInfo(script, { - model: host.defaultModelOptions.model, + model: host.modelAliases.large.model, }) if (chatInfo.error) throw new Error(chatInfo.error) let { info: embeddingsInfo } = await resolveModelConnectionInfo( script, - { model: host.defaultEmbeddingsModelOptions.embeddingsModel } + { model: host.modelAliases.embeddings.model } ) if (embeddingsInfo?.error) embeddingsInfo = undefined const config = generatePromptFooConfiguration(script, { diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index c842278d4e..60fbafbab8 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -735,21 +735,21 @@ export function mergeGenerationOptions( model: runOptions?.model ?? options?.model ?? - host.defaultModelOptions.model, + host.modelAliases.large.model, smallModel: runOptions?.smallModel ?? options?.smallModel ?? - host.defaultModelOptions.smallModel, + host.modelAliases.small.model, visionModel: runOptions?.visionModel ?? options?.visionModel ?? - host.defaultModelOptions.visionModel, + host.modelAliases.vision.model, temperature: - runOptions?.temperature ?? host.defaultModelOptions.temperature, + runOptions?.temperature ?? host.modelAliases.large.temperature, embeddingsModel: runOptions?.embeddingsModel ?? options?.embeddingsModel ?? - host.defaultEmbeddingsModelOptions.embeddingsModel, + host.modelAliases.embeddings.model, } satisfies GenerationOptions return res } @@ -803,8 +803,8 @@ export async function executeChatSession( ): Promise { const { trace, - model = host.defaultModelOptions.model, - temperature = host.defaultModelOptions.temperature, + model, + temperature, topP, maxTokens, seed, @@ -815,6 +815,7 @@ export async function executeChatSession( choices, topLogprobs, } = genOptions + assert(!!model, "model is required") const top_logprobs = genOptions.topLogprobs > 0 ? topLogprobs : undefined const logprobs = genOptions.logprobs || top_logprobs > 0 ? 
true : undefined
    traceLanguageModelConnection(trace, genOptions, connectionToken)
@@ -863,7 +864,7 @@ export async function executeChatSession(
             model,
             choices
         )
-        req = {
+        req = deleteUndefinedValues({
             model,
             temperature: temperature,
             top_p: topP,
@@ -894,7 +895,7 @@ export async function executeChatSession(
                       },
                   }
                 : undefined,
-        }
+        })
         if (/^o1/i.test(model)) {
            req.max_completion_tokens = maxTokens
            delete req.max_tokens
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index 1cd7873d61..72955c4140 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -72,19 +72,21 @@ export function findEnvVar(
 }
 export async function parseDefaultsFromEnv(env: Record<string, string>) {
+    // legacy
     if (env.GENAISCRIPT_DEFAULT_MODEL)
-        host.defaultModelOptions.model = env.GENAISCRIPT_DEFAULT_MODEL
-    if (env.GENAISCRIPT_DEFAULT_SMALL_MODEL)
-        host.defaultModelOptions.smallModel =
-            env.GENAISCRIPT_DEFAULT_SMALL_MODEL
-    if (env.GENAISCRIPT_DEFAULT_VISION_MODEL)
-        host.defaultModelOptions.visionModel =
-            env.GENAISCRIPT_DEFAULT_VISION_MODEL
+        host.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL
+
+    const m = /^GENAISCRIPT_DEFAULT_(?<id>[A-Z0-9]+)_MODEL$/i
+    Object.keys(env)
+        .map((k) => m.exec(k)?.groups.id)
+        .filter((id) => !!id)
+        .forEach((id: string) => {
+            id = id.toLocaleLowerCase()
+            const c = host.modelAliases[id] || (host.modelAliases[id] = {})
+            c.model = env[`GENAISCRIPT_DEFAULT_${id}_MODEL`]
+        })
     const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE)
-    if (!isNaN(t)) host.defaultModelOptions.temperature = t
-    if (env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL)
-        host.defaultEmbeddingsModelOptions.embeddingsModel =
-            env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL
+    if (!isNaN(t)) host.modelAliases.large.temperature = t
 }
 export async function parseTokenFromEnv(
@@ -92,7 +94,7 @@ export async function parseTokenFromEnv(
     modelId: string
 ): Promise {
     const { provider, model, tag } = parseModelIdentifier(
-        modelId ?? host.defaultModelOptions.model
+        modelId ?? host.modelAliases.large.model
     )
     const TOKEN_SUFFIX = ["_API_KEY", "_API_TOKEN", "_TOKEN", "_KEY"]
     const BASE_SUFFIX = ["_API_BASE", "_API_ENDPOINT", "_BASE", "_ENDPOINT"]
diff --git a/packages/core/src/encoders.ts b/packages/core/src/encoders.ts
index 5c87fb0535..0cebe8f94c 100644
--- a/packages/core/src/encoders.ts
+++ b/packages/core/src/encoders.ts
@@ -18,7 +18,7 @@ export async function resolveTokenEncoder(
 ): Promise {
     const { disableFallback } = options || {}
     // Parse the model identifier to extract the model information
-    if (!modelId) modelId = runtimeHost.defaultModelOptions.model
+    if (!modelId) modelId = runtimeHost.modelAliases.large.model
     const { model } = parseModelIdentifier(modelId)
     const module = model.toLowerCase() // Assign model to module for dynamic import path
diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts
index c1a60bd099..f9b0558b0a 100644
--- a/packages/core/src/expander.ts
+++ b/packages/core/src/expander.ts
@@ -194,7 +194,7 @@ export async function expandTemplate(
         options.temperature ??
         normalizeFloat(env.vars["temperature"]) ??
         template.temperature ??
-        host.defaultModelOptions.temperature
+        host.modelAliases.large.temperature
     const topP =
         options.topP ?? normalizeFloat(env.vars["top_p"]) ?? template.topP
    const maxTokens =
diff --git a/packages/core/src/git.ts b/packages/core/src/git.ts
index 389bf35b31..5fd0628379 100644
--- a/packages/core/src/git.ts
+++ b/packages/core/src/git.ts
@@ -316,7 +316,7 @@ export class GitClient implements Git {
         if (!nameOnly && llmify) {
             res = llmifyDiff(res)
             const { encode: encoder } = await resolveTokenEncoder(
-                runtimeHost.defaultModelOptions.model || DEFAULT_MODEL
+                runtimeHost.modelAliases.large.model
             )
             const tokens = estimateTokens(res, encoder)
             if (tokens > maxTokensFullDiff)
@@ -329,7 +329,7 @@ ${truncateTextToTokens(res, maxTokensFullDiff, encoder)}
## Files
${await this.diff({ ...options, nameOnly: true })}
`
-        }
+        }
         return res
     }
diff --git a/packages/core/src/globals.ts b/packages/core/src/globals.ts
index 0252875795..4b93bbaad9 100644
--- a/packages/core/src/globals.ts
+++ b/packages/core/src/globals.ts
@@ -68,7 +68,7 @@ export function installGlobals() {
     // Freeze XML utilities
     glb.XML = Object.freeze({
-        parse: XMLParse, // Parse XML string to objects
+        parse: XMLParse, // Parse XML string to objects
     })
     // Freeze Markdown utilities with frontmatter operations
@@ -124,14 +124,14 @@ export function installGlobals() {
         resolve: resolveTokenEncoder,
         count: async (text, options) => {
             const { encode: encoder } = await resolveTokenEncoder(
-                options?.model || runtimeHost.defaultModelOptions.model
+                options?.model || runtimeHost.modelAliases.large.model
             )
             const c = await estimateTokens(text, encoder)
             return c
         },
         truncate: async (text, maxTokens, options) => {
             const { encode: encoder } = await resolveTokenEncoder(
-                options?.model || runtimeHost.defaultModelOptions.model
+                options?.model || runtimeHost.modelAliases.large.model
             )
             return await truncateTextToTokens(text, maxTokens, encoder, options)
         },
diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts
index 67d1e9b0cd..701e5c4c58 100644
--- a/packages/core/src/host.ts
+++ b/packages/core/src/host.ts
@@ -113,6 +113,15 @@ export interface AzureTokenResolver {
     ): Promise
 }
+export type ModelConfiguration = Pick<ModelOptions, "model" | "temperature">
+
+export type ModelConfigurations = {
+    large: ModelConfiguration
+    small: ModelConfiguration
+    vision: ModelConfiguration
+    embeddings: ModelConfiguration
+} & Record<string, ModelConfiguration>
+
 export interface Host {
     userState: any
     server: ServerManager
@@ -124,13 +133,7 @@ export interface Host {
     installFolder(): string
     resolvePath(...segments: string[]): string
-    // read a secret from the environment or a .env file
-    defaultModelOptions: Required<
-        Pick<ModelOptions, "model" | "smallModel" | "visionModel" | "temperature">
-    >
-    defaultEmbeddingsModelOptions: Required<
-        Pick<EmbeddingsModelOptions, "embeddingsModel">
-    >
+    modelAliases: ModelConfigurations
     getLanguageModelConfiguration(
         modelId: string,
         options?: { token?: boolean } & AbortSignalOptions & TraceOptions
diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts
index e67c9cd57c..1b10f0c52e 100644
--- a/packages/core/src/models.ts
+++ b/packages/core/src/models.ts
@@ -106,23 +106,23 @@ export async function resolveModelConnectionInfo(
     if (m === SMALL_MODEL_ID) {
         m = undefined
         candidates ??= [
-            host.defaultModelOptions.smallModel,
+            host.modelAliases.small.model,
             ...DEFAULT_SMALL_MODEL_CANDIDATES,
         ]
     } else if (m === VISION_MODEL_ID) {
         m = undefined
         candidates ??= [
-            host.defaultModelOptions.visionModel,
+            host.modelAliases.vision.model,
             ...DEFAULT_VISION_MODEL_CANDIDATES,
         ]
     } else if (m === LARGE_MODEL_ID) {
         m = undefined
         candidates ??= [
-            host.defaultModelOptions.model,
+            host.modelAliases.large.model,
             ...DEFAULT_MODEL_CANDIDATES,
         ]
     }
-    candidates ??= [host.defaultModelOptions.model, ...DEFAULT_MODEL_CANDIDATES]
+    candidates
??= [host.modelAliases.large.model, ...DEFAULT_MODEL_CANDIDATES] const resolveModel = async ( model: string, diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts index 63e1a9b649..c17ebf403d 100644 --- a/packages/core/src/promptcontext.ts +++ b/packages/core/src/promptcontext.ts @@ -190,7 +190,7 @@ export async function createPromptContext( searchOptions.embeddingsModel = searchOptions?.embeddingsModel ?? options?.embeddingsModel ?? - host.defaultEmbeddingsModelOptions.embeddingsModel + host.modelAliases.embeddings.model const key = await hash({ files, searchOptions }, { length: 12 }) const folderPath = dotGenaiscriptPath("vectors", key) const res = await vectorSearch(q, files, { @@ -213,8 +213,9 @@ export async function createPromptContext( // Define the host for executing commands, browsing, and other operations const promptHost: PromptHost = Object.freeze({ - fetch: (url, options) => fetch(url, {...(options || {}), trace }), - fetchText: (url, options) => fetchText(url, {...(options || {}), trace }), + fetch: (url, options) => fetch(url, { ...(options || {}), trace }), + fetchText: (url, options) => + fetchText(url, { ...(options || {}), trace }), resolveLanguageModel: async (modelId) => { const { configuration } = await resolveModelConnectionInfo( { model: modelId }, diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts index 227cb19006..694ab7f1fc 100644 --- a/packages/core/src/test.ts +++ b/packages/core/src/test.ts @@ -94,14 +94,7 @@ export function generatePromptFooConfiguration( const cli = options?.cli const transform = "output.text" - const resolveModel = (m: string) => - m === SMALL_MODEL_ID - ? host.defaultModelOptions.smallModel - : m === VISION_MODEL_ID - ? host.defaultModelOptions.visionModel - : m === LARGE_MODEL_ID - ? host.defaultModelOptions.model - : m + const resolveModel = (m: string) => host.modelAliases[m]?.model ?? m const testProvider = deleteUndefinedValues({ text: resolveTestProvider(chatInfo, "chat"), @@ -119,16 +112,14 @@ export function generatePromptFooConfiguration( // Map model options to providers providers: models .map(({ model, smallModel, visionModel, temperature, topP }) => ({ - model: resolveModel(model) ?? host.defaultModelOptions.model, + model: resolveModel(model) ?? host.modelAliases.large.model, smallModel: - resolveModel(smallModel) ?? - host.defaultModelOptions.smallModel, + resolveModel(smallModel) ?? host.modelAliases.small.model, visionModel: - resolveModel(visionModel) ?? - host.defaultModelOptions.visionModel, + resolveModel(visionModel) ?? host.modelAliases.vision.model, temperature: !isNaN(temperature) ? 
temperature
-                    : host.defaultModelOptions.temperature,
+                    : host.modelAliases.large.temperature,
                top_p: topP,
            }))
            .map(({ model, smallModel, visionModel, temperature, top_p }) => ({
diff --git a/packages/core/src/testhost.ts b/packages/core/src/testhost.ts
index e09a601a2a..e81a468d92 100644
--- a/packages/core/src/testhost.ts
+++ b/packages/core/src/testhost.ts
@@ -15,6 +15,7 @@ import {
     setRuntimeHost,
     RuntimeHost,
     AzureTokenResolver,
+    ModelConfigurations,
 } from "./host"
 import { TraceOptions } from "./trace"
 import {
@@ -70,15 +71,11 @@ export class TestHost implements RuntimeHost {
     azureToken: AzureTokenResolver = undefined
     // Default options for language models
-    readonly defaultModelOptions = {
-        model: DEFAULT_MODEL,
-        smallModel: DEFAULT_SMALL_MODEL,
-        visionModel: DEFAULT_VISION_MODEL,
-        temperature: DEFAULT_TEMPERATURE,
-    }
-    // Default options for embeddings models
-    readonly defaultEmbeddingsModelOptions = {
-        embeddingsModel: DEFAULT_EMBEDDINGS_MODEL,
+    readonly modelAliases: ModelConfigurations = {
+        large: { model: DEFAULT_MODEL },
+        small: { model: DEFAULT_SMALL_MODEL },
+        vision: { model: DEFAULT_VISION_MODEL },
+        embeddings: { model: DEFAULT_EMBEDDINGS_MODEL },
     }
     // Static method to set this class as the runtime host
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index e8b17d139b..acf5193d29 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -191,15 +191,12 @@ interface ModelConnectionOptions {
     model?: ModelType
     /**
-     * Which LLM model to use for the "small" model.
-     *
-     * @default gpt-4
-     * @example gpt-4
+     * @deprecated use model aliases
      */
     smallModel?: ModelSmallType
     /**
-     * Which LLM to use for the "vision" model.
+     * @deprecated use model aliases
      */
     visionModel?: ModelVisionType
 }
diff --git a/packages/core/src/vectorsearch.ts b/packages/core/src/vectorsearch.ts
index e06efe6120..77184baac2 100644
--- a/packages/core/src/vectorsearch.ts
+++ b/packages/core/src/vectorsearch.ts
@@ -175,8 +175,7 @@ export async function vectorSearch(
     const {
         topK,
         folderPath,
-        embeddingsModel = runtimeHost.defaultEmbeddingsModelOptions
-            .embeddingsModel,
+        embeddingsModel = runtimeHost.modelAliases.embeddings.model,
         minScore = 0,
         trace,
     } = options
@@ -199,7 +198,7 @@ export async function vectorSearch(
         {
             token: true,
             candidates: [
-                host.defaultEmbeddingsModelOptions.embeddingsModel,
+                host.modelAliases.embeddings.model,
                 ...DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
             ],
         }
diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts
index 91d1bbab3e..c154c35bbe 100644
--- a/packages/vscode/src/vshost.ts
+++ b/packages/vscode/src/vshost.ts
@@ -20,6 +20,7 @@ import {
     LanguageModelConfiguration,
     LogLevel,
     Host,
+    ModelConfigurations,
 } from "../../core/src/host"
 import { TraceOptions, AbortSignalOptions } from "../../core/src/trace"
 import { arrayify } from "../../core/src/util"
@@ -30,16 +31,12 @@ export class VSCodeHost extends EventTarget implements Host {
     userState: any = {}
     readonly path = createVSPath()
     readonly server: TerminalServerManager
-    readonly defaultModelOptions = {
-        model: DEFAULT_MODEL,
-        smallModel: DEFAULT_SMALL_MODEL,
-        visionModel: DEFAULT_VISION_MODEL,
-        temperature: DEFAULT_TEMPERATURE,
+    readonly modelAliases: ModelConfigurations = {
+        large: { model: DEFAULT_MODEL },
+        small: { model: DEFAULT_SMALL_MODEL },
+        vision: { model: DEFAULT_VISION_MODEL },
+        embeddings: { model: DEFAULT_EMBEDDINGS_MODEL },
     }
-    readonly defaultEmbeddingsModelOptions =
{ - embeddingsModel: DEFAULT_EMBEDDINGS_MODEL, - } - constructor(readonly state: ExtensionState) { super() setHost(this) From 71ca0226506ab6858d312fd49b91f516c2b02dc6 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Thu, 5 Dec 2024 08:40:27 -0800 Subject: [PATCH 2/5] =?UTF-8?q?refactor:=20update=20model=20options=20and?= =?UTF-8?q?=20remove=20smallModel=20=F0=9F=8E=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/cli/src/test.ts | 2 +- packages/core/src/chat.ts | 8 -------- packages/core/src/test.ts | 6 +++--- packages/core/src/types/prompt_template.d.ts | 3 +++ 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/packages/cli/src/test.ts b/packages/cli/src/test.ts index ab86d492f4..c443693a95 100644 --- a/packages/cli/src/test.ts +++ b/packages/cli/src/test.ts @@ -51,7 +51,7 @@ import { filterScripts } from "../../core/src/ast" * @param m - The string representation of the model specification. * @returns A ModelOptions object with model, temperature, and topP fields if applicable. */ -function parseModelSpec(m: string): ModelOptions { +function parseModelSpec(m: string): ModelOptions & ModelAliasesOptions { const values = m .split(/&/g) .map((kv) => kv.split("=", 2)) diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index 60fbafbab8..d1e72b7613 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -736,14 +736,6 @@ export function mergeGenerationOptions( runOptions?.model ?? options?.model ?? host.modelAliases.large.model, - smallModel: - runOptions?.smallModel ?? - options?.smallModel ?? - host.modelAliases.small.model, - visionModel: - runOptions?.visionModel ?? - options?.visionModel ?? - host.modelAliases.vision.model, temperature: runOptions?.temperature ?? host.modelAliases.large.temperature, embeddingsModel: diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts index 694ab7f1fc..29cc258431 100644 --- a/packages/core/src/test.ts +++ b/packages/core/src/test.ts @@ -69,7 +69,7 @@ export function generatePromptFooConfiguration( provider?: string out?: string cli?: string - models?: ModelOptions[] + models?: (ModelOptions & ModelAliasesOptions)[] } ) { // Destructure options with default values @@ -86,8 +86,8 @@ export function generatePromptFooConfiguration( models.push({ ...script, model: chatInfo.model, - smallModel: chatInfo.smallModel, - visionModel: chatInfo.visionModel, + smallModel: chatInfo.model, + visionModel: chatInfo.model, }) } diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts index acf5193d29..f707bd120c 100644 --- a/packages/core/src/types/prompt_template.d.ts +++ b/packages/core/src/types/prompt_template.d.ts @@ -189,7 +189,9 @@ interface ModelConnectionOptions { * Which LLM model to use. Use `large` for the default set of model candidates, `small` for the set of small models like gpt-4o-mini. 
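     * @example large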
*/ model?: ModelType +} +interface ModelAliasesOptions { /** * @deprecated use model aliases */ @@ -443,6 +445,7 @@ interface ContentSafetyOptions { interface PromptScript extends PromptLike, ModelOptions, + ModelAliasesOptions, PromptSystemOptions, EmbeddingsModelOptions, ContentSafetyOptions, From 956346620b078581be48e7a505b7a740b30cee11 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Thu, 5 Dec 2024 08:58:29 -0800 Subject: [PATCH 3/5] =?UTF-8?q?feat:=20=F0=9F=9A=80=20add=20vision=20model?= =?UTF-8?q?=20support=20and=20update=20env=20vars?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../docs/getting-started/configuration.mdx | 65 ++++++++++++------- packages/core/src/connection.ts | 19 +++--- packages/core/src/test.ts | 6 +- packages/core/src/types/prompt_template.d.ts | 6 +- 4 files changed, 56 insertions(+), 40 deletions(-) diff --git a/docs/src/content/docs/getting-started/configuration.mdx b/docs/src/content/docs/getting-started/configuration.mdx index 7f5ec233e8..d1584a3e1b 100644 --- a/docs/src/content/docs/getting-started/configuration.mdx +++ b/docs/src/content/docs/getting-started/configuration.mdx @@ -40,9 +40,9 @@ script({ }) ``` -### Large and small models +### Large, small, vision models -You can also use the `small` and `large` aliases to use the default configured small and large models. +You can also use the `small`, `large`, `vision` aliases to use the default configured small, large and vision-enabled models. Large models are typically in the OpenAI gpt-4 reasoning range and can be used for more complex tasks. Small models are in the OpenAI gpt-4o-mini range, and are useful for quick and simple tasks. @@ -60,11 +60,28 @@ The model can also be overridden from the [cli run command](/genaiscript/referen genaiscript run ... --model largemodelid --small-model smallmodelid ``` -or by adding the `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` environment variables. +or by adding the `GENAISCRIPT_LARGE_MODEL` and `GENAISCRIPT_SMALL_MODEL` environment variables. ```txt title=".env" -GENAISCRIPT_DEFAULT_MODEL="azure_serverless:..." -GENAISCRIPT_DEFAULT_SMALL_MODEL="azure_serverless:..." +GENAISCRIPT_MODEL_LARGE="azure_serverless:..." +GENAISCRIPT_MODEL_SMALL="azure_serverless:..." +GENAISCRIPT_MODEL_VISION="azure_serverless:..." +``` + +### Model aliases + +In fact, you can define any alias for your model (only alphanumeric characters are allowed) +through environment variables of the name `GENAISCRIPT_MODEL_ALIAS` +where `ALIAS` is the alias you want to use. + +```txt title=".env" +GENAISCRIPT_MODEL_TINY=... +``` + +Model aliases are always lowercased when used in the script. + +```js +script({ model: "tiny" }) ``` ## `.env` file @@ -93,8 +110,8 @@ Create a `.env` file in the root of your project. -- .gitignore -- **.env** +- .gitignore +- **.env** @@ -127,19 +144,19 @@ the `.env` file will appear grayed out in Visual Studio Code. You can specify a custom `.env` file location through the CLI or an environment variable. -- by adding the `--env ` argument to the CLI. +- by adding the `--env ` argument to the CLI. ```sh "--env .env.local" npx genaiscript ... --env .env.local ``` -- by setting the `GENAISCRIPT_ENV_FILE` environment variable. +- by setting the `GENAISCRIPT_ENV_FILE` environment variable. ```sh GENAISCRIPT_ENV_FILE=".env.local" npx genaiscript ... ``` -- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files). 
+- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files). ```yaml title="~/genaiscript.config.yaml" envFile: ~/.env.genaiscript @@ -152,13 +169,13 @@ of the genaiscript process with the configuration values. Here are some common examples: -- Using bash syntax +- Using bash syntax ```sh OPENAI_API_KEY="value" npx --yes genaiscript run ... ``` -- GitHub Action configuration +- GitHub Action configuration ```yaml title=".github/workflows/genaiscript.yml" run: npx --yes genaiscript run ... @@ -219,11 +236,11 @@ script({ :::tip[Default Model Configuration] -Use `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` in your `.env` file to set the default model and small model. +Use `GENAISCRIPT_MODEL_LARGE` and `GENAISCRIPT_MODEL_SMALL` in your `.env` file to set the default model and small model. ```txt -GENAISCRIPT_DEFAULT_MODEL=openai:gpt-4o -GENAISCRIPT_DEFAULT_SMALL_MODEL=openai:gpt-4o-mini +GENAISCRIPT_MODEL_LARGE=openai:gpt-4o +GENAISCRIPT_MODEL_SMALL=openai:gpt-4o-mini ``` ::: @@ -412,11 +429,11 @@ AZURE_OPENAI_API_CREDENTIALS=cli The types are mapped directly to their [@azure/identity](https://www.npmjs.com/package/@azure/identity) credential types: -- `cli` - `AzureCliCredential` -- `env` - `EnvironmentCredential` -- `powershell` - `AzurePowerShellCredential` -- `devcli` - `AzureDeveloperCliCredential` -- `managedidentity` - `ManagedIdentityCredential` +- `cli` - `AzureCliCredential` +- `env` - `EnvironmentCredential` +- `powershell` - `AzurePowerShellCredential` +- `devcli` - `AzureDeveloperCliCredential` +- `managedidentity` - `ManagedIdentityCredential` ### Custom token scopes @@ -1080,7 +1097,6 @@ script({ }) ``` - ### Ollama with Docker You can conviniately run Ollama in a Docker container. @@ -1097,11 +1113,10 @@ docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama docker stop ollama && docker rm ollama ``` - ## LMStudio The `lmstudio` provider connects to the [LMStudio](https://lmstudio.ai/) headless server. -and allows to run local LLMs. +and allows to run local LLMs. :::note @@ -1259,8 +1274,8 @@ This `transformers` provider runs models on device using [Hugging Face Transform The model syntax is `transformers::` where -- `repo` is the model repository on Hugging Face, -- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type. +- `repo` is the model repository on Hugging Face, +- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type. 
```js "transformers:"
script({
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index 72955c4140..5b8ec2c2a8 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -76,15 +76,16 @@ export async function parseDefaultsFromEnv(env: Record<string, string>) {
     if (env.GENAISCRIPT_DEFAULT_MODEL)
         host.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL
-    const m = /^GENAISCRIPT_DEFAULT_(?<id>[A-Z0-9]+)_MODEL$/i
-    Object.keys(env)
-        .map((k) => m.exec(k)?.groups.id)
-        .filter((id) => !!id)
-        .forEach((id: string) => {
-            id = id.toLocaleLowerCase()
-            const c = host.modelAliases[id] || (host.modelAliases[id] = {})
-            c.model = env[`GENAISCRIPT_DEFAULT_${id}_MODEL`]
-        })
+    const rx =
+        /^GENAISCRIPT(_DEFAULT)?_((?<id>[A-Z0-9]+)_MODEL|MODEL_(?<id2>[A-Z0-9]+))$/i
+    for (const kv of Object.entries(env)) {
+        const [k, v] = kv
+        const m = rx.exec(k)
+        if (!m) continue
+        const id = (m.groups.id || m.groups.id2).toLocaleLowerCase()
+        const c = host.modelAliases[id] || (host.modelAliases[id] = {})
+        c.model = v
+    }
     const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE)
     if (!isNaN(t)) host.modelAliases.large.temperature = t
 }
diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts
index 29cc258431..2ef24533b6 100644
--- a/packages/core/src/test.ts
+++ b/packages/core/src/test.ts
@@ -64,7 +64,7 @@ export function generatePromptFooConfiguration(
     script: PromptScript,
     options: {
-        chatInfo: ModelConnectionInfo
+        chatInfo: ModelConnectionInfo & ModelAliasesOptions
         embeddingsInfo?: ModelConnectionInfo
         provider?: string
         out?: string
         cli?: string
-        models?: ModelOptions[]
+        models?: (ModelOptions & ModelAliasesOptions)[]
     }
@@ -86,8 +86,8 @@
         models.push({
             ...script,
             model: chatInfo.model,
-            smallModel: chatInfo.model,
-            visionModel: chatInfo.model,
+            smallModel: chatInfo.smallModel,
+            visionModel: chatInfo.visionModel,
         })
     }
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index f707bd120c..0068002da9 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -186,19 +186,19 @@ type ModelVisionType = OptionsOrString<
 interface ModelConnectionOptions {
     /**
-     * Which LLM model to use. Use `large` for the default set of model candidates, `small` for the set of small models like gpt-4o-mini.
+     * Which LLM model by default or for the `large` alias.
      */
     model?: ModelType
 }

 interface ModelAliasesOptions {
     /**
-     * @deprecated use model aliases
+     * Configure the `small` model alias.
      */
     smallModel?: ModelSmallType

     /**
-     * @deprecated use model aliases
+     * Configure the `vision` model alias.
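+     * @example openai:gpt-4o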
     */
    visionModel?: ModelVisionType
}

From 33fb94ffdb1751c8be4cc660baefb7c78c4d98ba Mon Sep 17 00:00:00 2001
From: pelikhan
Date: Thu, 5 Dec 2024 09:12:06 -0800
Subject: [PATCH 4/5] apply alias when resolving models

---
 docs/src/content/docs/reference/cli/commands.md |  1 +
 packages/cli/src/cli.ts                         |  1 +
 packages/cli/src/run.ts                         |  6 ++++++
 packages/core/src/connection.ts                 | 11 ++++++++---
 packages/core/src/models.ts                     |  5 ++++-
 packages/core/src/server/messages.ts            |  1 +
 6 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/docs/src/content/docs/reference/cli/commands.md b/docs/src/content/docs/reference/cli/commands.md
index 3460c5abe9..fb36768aac 100644
--- a/docs/src/content/docs/reference/cli/commands.md
+++ b/docs/src/content/docs/reference/cli/commands.md
@@ -19,6 +19,7 @@ Options:
   -m, --model <string>              'large' model alias (default)
   -sm, --small-model <string>       'small' alias model
   -vm, --vision-model <string>      'vision' alias model
+  -ma, --model-alias <string...>    model alias as name=modelid
   -lp, --logprobs                   enable reporting token probabilities
   -tlp, --top-logprobs <number>     number of top logprobs (1 to 5)
   -ef, --excluded-files <string...> excluded files
diff --git a/packages/cli/src/cli.ts b/packages/cli/src/cli.ts
index f4c0fee853..de61e7427b 100644
--- a/packages/cli/src/cli.ts
+++ b/packages/cli/src/cli.ts
@@ -99,6 +99,7 @@ export async function cli() {
         .option("-m, --model <string>", "'large' model alias (default)")
         .option("-sm, --small-model <string>", "'small' alias model")
         .option("-vm, --vision-model <string>", "'vision' alias model")
+        .option("-ma, --model-alias <string...>", "model alias as name=modelid")
         .option("-lp, --logprobs", "enable reporting token probabilities")
         .option(
             "-tlp, --top-logprobs <number>",
diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts
index 45ace4ce93..db1ec3815f 100644
--- a/packages/cli/src/run.ts
+++ b/packages/cli/src/run.ts
@@ -87,6 +87,7 @@ import {
     stderr,
     stdout,
 } from "../../core/src/logging"
+import { setModelAlias } from "../../core/src/connection"
 async function setupTraceWriting(trace: MarkdownTrace, filename: string) {
     logVerbose(`trace: ${filename}`)
@@ -212,6 +213,11 @@ export async function runScript(
     if (options.smallModel) host.modelAliases.small.model = options.smallModel
     if (options.visionModel)
         host.modelAliases.vision.model = options.visionModel
+    for (const kv of options.modelAlias || []) {
+        const aliases = parseKeyValuePair(kv)
+        for (const [key, value] of Object.entries(aliases))
+            setModelAlias(key, value)
+    }
     const fail = (msg: string, exitCode: number, url?: string) => {
         logError(url ?
`${msg} (see ${url})` : msg) diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index 5b8ec2c2a8..d54d2625c7 100644 --- a/packages/core/src/connection.ts +++ b/packages/core/src/connection.ts @@ -71,6 +71,12 @@ export function findEnvVar( return undefined } +export function setModelAlias(id: string, modelid: string) { + id = id.toLowerCase() + const c = host.modelAliases[id] || (host.modelAliases[id] = {}) + c.model = modelid +} + export async function parseDefaultsFromEnv(env: Record) { // legacy if (env.GENAISCRIPT_DEFAULT_MODEL) @@ -82,9 +88,8 @@ export async function parseDefaultsFromEnv(env: Record) { const [k, v] = kv const m = rx.exec(k) if (!m) continue - const id = (m.groups.id || m.groups.id2).toLocaleLowerCase() - const c = host.modelAliases[id] || (host.modelAliases[id] = {}) - c.model = v + const id = m.groups.id || m.groups.id2 + setModelAlias(id, v) } const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE) if (!isNaN(t)) host.modelAliases.large.temperature = t diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index 1b10f0c52e..1c5f8884d2 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -121,9 +121,12 @@ export async function resolveModelConnectionInfo( host.modelAliases.large.model, ...DEFAULT_MODEL_CANDIDATES, ] - } + } candidates ??= [host.modelAliases.large.model, ...DEFAULT_MODEL_CANDIDATES] + // apply model alias + m = host.modelAliases[m]?.model || m + const resolveModel = async ( model: string, resolveOptions: { withToken: boolean; reportError: boolean } diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index 035aa4c2d0..386ce623b1 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -88,6 +88,7 @@ export interface PromptScriptRunOptions { smallModel: string visionModel: string embeddingsModel: string + modelAlias: string[] csvSeparator: string cache: boolean | string cacheName: string From 771d3c4acff967bec7beb9a1b0964d7e2a5965d2 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Thu, 5 Dec 2024 09:20:16 -0800 Subject: [PATCH 5/5] =?UTF-8?q?refactor:=20=E2=99=BB=EF=B8=8F=20update=20m?= =?UTF-8?q?odel=20alias=20references=20to=20runtimeHost?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/cli/src/info.ts | 2 +- packages/cli/src/run.ts | 14 ++++++++------ packages/cli/src/test.ts | 17 +++++++++-------- packages/core/src/chat.ts | 8 ++++---- packages/core/src/connection.ts | 10 ++++++---- packages/core/src/expander.ts | 4 ++-- packages/core/src/host.ts | 2 +- packages/core/src/models.ts | 17 ++++++++++------- packages/core/src/promptcontext.ts | 2 +- packages/core/src/test.ts | 15 +++++++++------ packages/core/src/vectorsearch.ts | 2 +- packages/vscode/src/vshost.ts | 16 ---------------- 12 files changed, 52 insertions(+), 57 deletions(-) diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts index 962e7d3339..086ad13c09 100644 --- a/packages/cli/src/info.ts +++ b/packages/cli/src/info.ts @@ -97,7 +97,7 @@ async function resolveScriptsConnectionInfo( // Deduplicate model connection options for (const script of scripts) { const conn: ModelConnectionOptions = { - model: script.model ?? host.modelAliases.large.model, + model: script.model ?? 
runtimeHost.modelAliases.large.model, } const key = JSON.stringify(conn) if (!models[key]) models[key] = conn diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index db1ec3815f..69524253fb 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -209,10 +209,11 @@ export async function runScript( const topLogprobs = normalizeInt(options.topLogprobs) if (options.json || options.yaml) overrideStdoutWithStdErr() - if (options.model) host.modelAliases.large.model = options.model - if (options.smallModel) host.modelAliases.small.model = options.smallModel + if (options.model) runtimeHost.modelAliases.large.model = options.model + if (options.smallModel) + runtimeHost.modelAliases.small.model = options.smallModel if (options.visionModel) - host.modelAliases.vision.model = options.visionModel + runtimeHost.modelAliases.vision.model = options.visionModel for (const kv of options.modelAlias || []) { const aliases = parseKeyValuePair(kv) for (const [key, value] of Object.entries(aliases)) @@ -225,8 +226,8 @@ export async function runScript( } logInfo(`genaiscript: ${scriptId}`) - Object.entries(host.modelAliases).forEach(([key, host]) => - logVerbose(` ${key}: ${host.model}`) + Object.entries(runtimeHost.modelAliases).forEach(([key, value]) => + logVerbose(` ${key}: ${value.model}`) ) if (out) { @@ -380,7 +381,8 @@ export async function runScript( maxDataRepairs, model: info.model, embeddingsModel: - options.embeddingsModel ?? host.modelAliases.embeddings.model, + options.embeddingsModel ?? + runtimeHost.modelAliases.embeddings.model, retry, retryDelay, maxDelay, diff --git a/packages/cli/src/test.ts b/packages/cli/src/test.ts index c443693a95..1d93d8d497 100644 --- a/packages/cli/src/test.ts +++ b/packages/cli/src/test.ts @@ -20,7 +20,7 @@ import { import { promptFooDriver } from "../../core/src/default_prompts" import { serializeError } from "../../core/src/error" import { parseKeyValuePairs } from "../../core/src/fence" -import { host } from "../../core/src/host" +import { host, runtimeHost } from "../../core/src/host" import { JSON5TryParse } from "../../core/src/json5" import { MarkdownTrace } from "../../core/src/trace" import { @@ -108,12 +108,13 @@ export async function runPromptScriptTests( testDelay?: string } ): Promise { - if (options.model) host.modelAliases.large.model = options.model - if (options.smallModel) host.modelAliases.small.model = options.smallModel + if (options.model) runtimeHost.modelAliases.large.model = options.model + if (options.smallModel) + runtimeHost.modelAliases.small.model = options.smallModel if (options.visionModel) - host.modelAliases.vision.model = options.visionModel - Object.entries(host.modelAliases).forEach(([key, host]) => - logVerbose(` ${key}: ${host.model}`) + runtimeHost.modelAliases.vision.model = options.visionModel + Object.entries(runtimeHost.modelAliases).forEach(([key, value]) => + logVerbose(` ${key}: ${value.model}`) ) const scripts = await listTests({ ids, ...(options || {}) }) @@ -145,12 +146,12 @@ export async function runPromptScriptTests( : script.filename.replace(GENAI_ANY_REGEX, ".promptfoo.yaml") logInfo(` ${fn}`) const { info: chatInfo } = await resolveModelConnectionInfo(script, { - model: host.modelAliases.large.model, + model: runtimeHost.modelAliases.large.model, }) if (chatInfo.error) throw new Error(chatInfo.error) let { info: embeddingsInfo } = await resolveModelConnectionInfo( script, - { model: host.modelAliases.embeddings.model } + { model: runtimeHost.modelAliases.embeddings.model } ) if 
(embeddingsInfo?.error) embeddingsInfo = undefined const config = generatePromptFooConfiguration(script, { diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index d1e72b7613..6e9e4aaf24 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -1,7 +1,7 @@ // cspell: disable import { MarkdownTrace } from "./trace" import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom" -import { LanguageModelConfiguration, host } from "./host" +import { LanguageModelConfiguration, host, runtimeHost } from "./host" import { GenerationOptions } from "./generation" import { dispose } from "./dispose" import { @@ -735,13 +735,13 @@ export function mergeGenerationOptions( model: runOptions?.model ?? options?.model ?? - host.modelAliases.large.model, + runtimeHost.modelAliases.large.model, temperature: - runOptions?.temperature ?? host.modelAliases.large.temperature, + runOptions?.temperature ?? runtimeHost.modelAliases.large.temperature, embeddingsModel: runOptions?.embeddingsModel ?? options?.embeddingsModel ?? - host.modelAliases.embeddings.model, + runtimeHost.modelAliases.embeddings.model, } satisfies GenerationOptions return res } diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index d54d2625c7..447f712f4f 100644 --- a/packages/core/src/connection.ts +++ b/packages/core/src/connection.ts @@ -39,6 +39,7 @@ import { host, LanguageModelConfiguration, AzureCredentialsType, + runtimeHost, } from "./host" import { parseModelIdentifier } from "./models" import { normalizeFloat, trimTrailingSlash } from "./util" @@ -73,14 +74,15 @@ export function findEnvVar( export function setModelAlias(id: string, modelid: string) { id = id.toLowerCase() - const c = host.modelAliases[id] || (host.modelAliases[id] = {}) + const c = + runtimeHost.modelAliases[id] || (runtimeHost.modelAliases[id] = {}) c.model = modelid } export async function parseDefaultsFromEnv(env: Record) { // legacy if (env.GENAISCRIPT_DEFAULT_MODEL) - host.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL + runtimeHost.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL const rx = /^GENAISCRIPT(_DEFAULT)?_((?[A-Z0-9]+)_MODEL|MODEL_(?[A-Z0-9]+))$/i @@ -92,7 +94,7 @@ export async function parseDefaultsFromEnv(env: Record) { setModelAlias(id, v) } const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE) - if (!isNaN(t)) host.modelAliases.large.temperature = t + if (!isNaN(t)) runtimeHost.modelAliases.large.temperature = t } export async function parseTokenFromEnv( @@ -100,7 +102,7 @@ export async function parseTokenFromEnv( modelId: string ): Promise { const { provider, model, tag } = parseModelIdentifier( - modelId ?? host.modelAliases.large.model + modelId ?? 
runtimeHost.modelAliases.large.model ) const TOKEN_SUFFIX = ["_API_KEY", "_API_TOKEN", "_TOKEN", "_KEY"] const BASE_SUFFIX = ["_API_BASE", "_API_ENDPOINT", "_BASE", "_ENDPOINT"] diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts index f9b0558b0a..8b4cfee918 100644 --- a/packages/core/src/expander.ts +++ b/packages/core/src/expander.ts @@ -25,7 +25,7 @@ import { import { importPrompt } from "./importprompt" import { parseModelIdentifier } from "./models" import { JSONSchemaStringifyToTypeScript, toStrictJSONSchema } from "./schema" -import { host } from "./host" +import { host, runtimeHost } from "./host" import { resolveSystems } from "./systems" import { GenerationOptions, GenerationStatus } from "./generation" import { AICIRequest, ChatCompletionMessageParam } from "./chattypes" @@ -194,7 +194,7 @@ export async function expandTemplate( options.temperature ?? normalizeFloat(env.vars["temperature"]) ?? template.temperature ?? - host.modelAliases.large.temperature + runtimeHost.modelAliases.large.temperature const topP = options.topP ?? normalizeFloat(env.vars["top_p"]) ?? template.topP const maxTokens = diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts index 701e5c4c58..8f2ba6fd66 100644 --- a/packages/core/src/host.ts +++ b/packages/core/src/host.ts @@ -133,7 +133,6 @@ export interface Host { installFolder(): string resolvePath(...segments: string[]): string - modelAliases: ModelConfigurations getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions @@ -167,6 +166,7 @@ export interface RuntimeHost extends Host { models: ModelService workspace: Omit azureToken: AzureTokenResolver + modelAliases: ModelConfigurations readConfig(): Promise readSecret(name: string): Promise diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index 1c5f8884d2..f2babb2a85 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -11,7 +11,7 @@ import { VISION_MODEL_ID, } from "./constants" import { errorMessage } from "./error" -import { LanguageModelConfiguration, host } from "./host" +import { LanguageModelConfiguration, host, runtimeHost } from "./host" import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace" import { assert } from "./util" @@ -106,26 +106,29 @@ export async function resolveModelConnectionInfo( if (m === SMALL_MODEL_ID) { m = undefined candidates ??= [ - host.modelAliases.small.model, + runtimeHost.modelAliases.small.model, ...DEFAULT_SMALL_MODEL_CANDIDATES, ] } else if (m === VISION_MODEL_ID) { m = undefined candidates ??= [ - host.modelAliases.vision.model, + runtimeHost.modelAliases.vision.model, ...DEFAULT_VISION_MODEL_CANDIDATES, ] } else if (m === LARGE_MODEL_ID) { m = undefined candidates ??= [ - host.modelAliases.large.model, + runtimeHost.modelAliases.large.model, ...DEFAULT_MODEL_CANDIDATES, ] - } - candidates ??= [host.modelAliases.large.model, ...DEFAULT_MODEL_CANDIDATES] + } + candidates ??= [ + runtimeHost.modelAliases.large.model, + ...DEFAULT_MODEL_CANDIDATES, + ] // apply model alias - m = host.modelAliases[m]?.model || m + m = runtimeHost.modelAliases[m]?.model || m const resolveModel = async ( model: string, diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts index c17ebf403d..96c532a7c3 100644 --- a/packages/core/src/promptcontext.ts +++ b/packages/core/src/promptcontext.ts @@ -190,7 +190,7 @@ export async function createPromptContext( searchOptions.embeddingsModel = 
searchOptions?.embeddingsModel ??
            options?.embeddingsModel ??
-            host.modelAliases.embeddings.model
+            runtimeHost.modelAliases.embeddings.model
         const key = await hash({ files, searchOptions }, { length: 12 })
         const folderPath = dotGenaiscriptPath("vectors", key)
         const res = await vectorSearch(q, files, {
diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts
index 2ef24533b6..7e6535a89f 100644
--- a/packages/core/src/test.ts
+++ b/packages/core/src/test.ts
@@ -9,7 +9,7 @@ import {
     VISION_MODEL_ID,
 } from "./constants"
 import { arrayify, deleteUndefinedValues } from "./util"
-import { host } from "./host"
+import { host, runtimeHost } from "./host"
 import { ModelConnectionInfo, parseModelIdentifier } from "./models"
 /**
@@ -94,7 +94,7 @@ export function generatePromptFooConfiguration(
     const cli = options?.cli
     const transform = "output.text"
-    const resolveModel = (m: string) => host.modelAliases[m]?.model ?? m
+    const resolveModel = (m: string) => runtimeHost.modelAliases[m]?.model ?? m
     const testProvider = deleteUndefinedValues({
         text: resolveTestProvider(chatInfo, "chat"),
@@ -112,14 +112,17 @@ export function generatePromptFooConfiguration(
         // Map model options to providers
         providers: models
             .map(({ model, smallModel, visionModel, temperature, topP }) => ({
-                model: resolveModel(model) ?? host.modelAliases.large.model,
+                model:
+                    resolveModel(model) ?? runtimeHost.modelAliases.large.model,
                 smallModel:
-                    resolveModel(smallModel) ?? host.modelAliases.small.model,
+                    resolveModel(smallModel) ??
+                    runtimeHost.modelAliases.small.model,
                 visionModel:
-                    resolveModel(visionModel) ?? host.modelAliases.vision.model,
+                    resolveModel(visionModel) ??
+                    runtimeHost.modelAliases.vision.model,
                 temperature: !isNaN(temperature)
                     ? temperature
-                    : host.modelAliases.large.temperature,
+                    : runtimeHost.modelAliases.large.temperature,
                 top_p: topP,
             }))
            .map(({ model, smallModel, visionModel, temperature, top_p }) => ({
diff --git a/packages/core/src/vectorsearch.ts b/packages/core/src/vectorsearch.ts
index 77184baac2..466a5e9e0d 100644
--- a/packages/core/src/vectorsearch.ts
+++ b/packages/core/src/vectorsearch.ts
@@ -198,7 +198,7 @@ export async function vectorSearch(
         {
             token: true,
             candidates: [
-                host.modelAliases.embeddings.model,
+                runtimeHost.modelAliases.embeddings.model,
                 ...DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
             ],
         }
diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts
index c154c35bbe..5920f792a5 100644
--- a/packages/vscode/src/vshost.ts
+++ b/packages/vscode/src/vshost.ts
@@ -6,21 +6,11 @@ import { ExtensionState } from "./state"
 import { Utils } from "vscode-uri"
 import { checkFileExists, readFileText } from "./fs"
 import { filterGitIgnore } from "../../core/src/gitignore"
-import { parseDefaultsFromEnv } from "../../core/src/connection"
-import {
-    DEFAULT_EMBEDDINGS_MODEL,
-    DEFAULT_MODEL,
-    DEFAULT_SMALL_MODEL,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_VISION_MODEL,
-} from "../../core/src/constants"
-import { dotEnvTryParse } from "../../core/src/dotenv"
 import {
     setHost,
     LanguageModelConfiguration,
     LogLevel,
     Host,
 } from "../../core/src/host"
 import { TraceOptions, AbortSignalOptions } from "../../core/src/trace"
 import { arrayify } from "../../core/src/util"
@@ -30,16 +21,12 @@ export class VSCodeHost extends EventTarget implements Host {
     userState: any = {}
     readonly path = createVSPath()
     readonly server: TerminalServerManager
-    readonly modelAliases: ModelConfigurations = {
-        large: { model: DEFAULT_MODEL },
-        small: { model: DEFAULT_SMALL_MODEL },
-        vision: {
model: DEFAULT_VISION_MODEL }, - embeddings: { model: DEFAULT_EMBEDDINGS_MODEL }, - } constructor(readonly state: ExtensionState) { super() setHost(this)
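Taken together, these patches route every default-model lookup through `runtimeHost.modelAliases`. A minimal sketch of the resulting configuration surface follows; the model identifiers, the script id, and the custom `tiny` alias are illustrative assumptions, not values taken from this diff:

```sh
# Environment forms accepted by parseDefaultsFromEnv: both GENAISCRIPT_<ID>_MODEL
# and GENAISCRIPT_MODEL_<ID> match the alias regex, and the legacy
# GENAISCRIPT_DEFAULT_MODEL variable still populates the `large` alias.
export GENAISCRIPT_MODEL_LARGE="openai:gpt-4o"      # sets modelAliases.large.model
export GENAISCRIPT_SMALL_MODEL="openai:gpt-4o-mini" # sets modelAliases.small.model
export GENAISCRIPT_MODEL_TINY="ollama:phi3"         # defines a custom `tiny` alias

# The same alias can be set per run from the CLI; resolveModelConnectionInfo then
# maps script({ model: "tiny" }) to the aliased model id before trying candidates.
npx genaiscript run my-script --model-alias tiny=ollama:phi3
```

Alias names are lowercased by `setModelAlias`, so `GENAISCRIPT_MODEL_TINY` and `script({ model: "tiny" })` refer to the same entry.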