diff --git a/docs/src/content/docs/getting-started/configuration.mdx b/docs/src/content/docs/getting-started/configuration.mdx
index 7f5ec233e..d1584a3e1 100644
--- a/docs/src/content/docs/getting-started/configuration.mdx
+++ b/docs/src/content/docs/getting-started/configuration.mdx
@@ -40,9 +40,9 @@ script({
})
```
-### Large and small models
+### Large, small, and vision models
-You can also use the `small` and `large` aliases to use the default configured small and large models.
+You can also use the `small`, `large`, and `vision` aliases to select the default configured small, large, and vision-enabled models.
Large models are typically in the OpenAI gpt-4 reasoning range and can be used for more complex tasks.
Small models are in the OpenAI gpt-4o-mini range, and are useful for quick and simple tasks.
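+
+For example, this script runs with whichever model is currently configured for the `small` alias:
+
+```js
+script({ model: "small" })
+```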
@@ -60,11 +60,28 @@ The model can also be overridden from the [cli run command](/genaiscript/referen
genaiscript run ... --model largemodelid --small-model smallmodelid
```
-or by adding the `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` environment variables.
+or by adding the `GENAISCRIPT_MODEL_LARGE`, `GENAISCRIPT_MODEL_SMALL` and `GENAISCRIPT_MODEL_VISION` environment variables.
```txt title=".env"
-GENAISCRIPT_DEFAULT_MODEL="azure_serverless:..."
-GENAISCRIPT_DEFAULT_SMALL_MODEL="azure_serverless:..."
+GENAISCRIPT_MODEL_LARGE="azure_serverless:..."
+GENAISCRIPT_MODEL_SMALL="azure_serverless:..."
+GENAISCRIPT_MODEL_VISION="azure_serverless:..."
+```
+
+### Model aliases
+
+In fact, you can define any model alias (only alphanumeric characters are allowed)
+through an environment variable named `GENAISCRIPT_MODEL_ALIAS`,
+where `ALIAS` is the alias you want to use.
+
+```txt title=".env"
+GENAISCRIPT_MODEL_TINY=...
+```
+
+Model aliases are always lowercased when used in the script.
+
+```js
+script({ model: "tiny" })
```
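+
+You can also set an alias from the CLI as a `name=modelid` pair (the model identifier below is only illustrative):
+
+```sh
+npx genaiscript run ... --model-alias "tiny=ollama:phi3"
+```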
## `.env` file
@@ -93,8 +110,8 @@ Create a `.env` file in the root of your project.
-- .gitignore
-- **.env**
+- .gitignore
+- **.env**
@@ -127,19 +144,19 @@ the `.env` file will appear grayed out in Visual Studio Code.
You can specify a custom `.env` file location through the CLI or an environment variable.
-- by adding the `--env ` argument to the CLI.
+- by adding the `--env` argument to the CLI.
```sh "--env .env.local"
npx genaiscript ... --env .env.local
```
-- by setting the `GENAISCRIPT_ENV_FILE` environment variable.
+- by setting the `GENAISCRIPT_ENV_FILE` environment variable.
```sh
GENAISCRIPT_ENV_FILE=".env.local" npx genaiscript ...
```
-- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files).
+- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files).
```yaml title="~/genaiscript.config.yaml"
envFile: ~/.env.genaiscript
@@ -152,13 +169,13 @@ of the genaiscript process with the configuration values.
Here are some common examples:
-- Using bash syntax
+- Using bash syntax
```sh
OPENAI_API_KEY="value" npx --yes genaiscript run ...
```
-- GitHub Action configuration
+- GitHub Action configuration
```yaml title=".github/workflows/genaiscript.yml"
run: npx --yes genaiscript run ...
@@ -219,11 +236,11 @@ script({
:::tip[Default Model Configuration]
-Use `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` in your `.env` file to set the default model and small model.
+Use `GENAISCRIPT_MODEL_LARGE` and `GENAISCRIPT_MODEL_SMALL` in your `.env` file to set the default large and small models.
```txt
-GENAISCRIPT_DEFAULT_MODEL=openai:gpt-4o
-GENAISCRIPT_DEFAULT_SMALL_MODEL=openai:gpt-4o-mini
+GENAISCRIPT_MODEL_LARGE=openai:gpt-4o
+GENAISCRIPT_MODEL_SMALL=openai:gpt-4o-mini
```
:::
@@ -412,11 +429,11 @@ AZURE_OPENAI_API_CREDENTIALS=cli
The types are mapped directly to their [@azure/identity](https://www.npmjs.com/package/@azure/identity) credential types:
-- `cli` - `AzureCliCredential`
-- `env` - `EnvironmentCredential`
-- `powershell` - `AzurePowerShellCredential`
-- `devcli` - `AzureDeveloperCliCredential`
-- `managedidentity` - `ManagedIdentityCredential`
+- `cli` - `AzureCliCredential`
+- `env` - `EnvironmentCredential`
+- `powershell` - `AzurePowerShellCredential`
+- `devcli` - `AzureDeveloperCliCredential`
+- `managedidentity` - `ManagedIdentityCredential`
### Custom token scopes
@@ -1080,7 +1097,6 @@ script({
})
```
-
### Ollama with Docker
You can conveniently run Ollama in a Docker container.
@@ -1097,11 +1113,10 @@ docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
docker stop ollama && docker rm ollama
```
-
## LMStudio
The `lmstudio` provider connects to the [LMStudio](https://lmstudio.ai/) headless server.
-and allows to run local LLMs.
+and allows you to run local LLMs.
:::note
@@ -1259,8 +1274,8 @@ This `transformers` provider runs models on device using [Hugging Face Transform
The model syntax is `transformers:<repo>:<dtype>` where
-- `repo` is the model repository on Hugging Face,
-- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type.
+- `repo` is the model repository on Hugging Face,
+- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type.
```js "transformers:"
script({
diff --git a/docs/src/content/docs/reference/cli/commands.md b/docs/src/content/docs/reference/cli/commands.md
index 3460c5abe..fb36768aa 100644
--- a/docs/src/content/docs/reference/cli/commands.md
+++ b/docs/src/content/docs/reference/cli/commands.md
@@ -19,6 +19,7 @@ Options:
-m, --model 'large' model alias (default)
-sm, --small-model 'small' alias model
-vm, --vision-model 'vision' alias model
+ -ma, --model-alias model alias as name=modelid
-lp, --logprobs enable reporting token probabilities
-tlp, --top-logprobs number of top logprobs (1 to 5)
-ef, --excluded-files excluded files
diff --git a/packages/cli/src/cli.ts b/packages/cli/src/cli.ts
index f4c0fee85..de61e7427 100644
--- a/packages/cli/src/cli.ts
+++ b/packages/cli/src/cli.ts
@@ -99,6 +99,7 @@ export async function cli() {
.option("-m, --model ", "'large' model alias (default)")
.option("-sm, --small-model ", "'small' alias model")
.option("-vm, --vision-model ", "'vision' alias model")
+ .option("-ma, --model-alias ", "model alias as name=modelid")
.option("-lp, --logprobs", "enable reporting token probabilities")
.option(
"-tlp, --top-logprobs ",
diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts
index 78009a990..086ad13c0 100644
--- a/packages/cli/src/info.ts
+++ b/packages/cli/src/info.ts
@@ -84,20 +84,20 @@ export async function envInfo(
/**
* Resolves connection information for script templates by deduplicating model options.
- * @param templates - Array of model connection options to resolve.
+ * @param scripts - Array of model connection options to resolve.
* @param options - Configuration options, including whether to show tokens.
* @returns A promise that resolves to an array of model connection information.
*/
async function resolveScriptsConnectionInfo(
- templates: ModelConnectionOptions[],
+ scripts: ModelConnectionOptions[],
options?: { token?: boolean }
): Promise<ModelConnectionInfo[]> {
const models: Record<string, ModelConnectionOptions> = {}
// Deduplicate model connection options
- for (const template of templates) {
+ for (const script of scripts) {
const conn: ModelConnectionOptions = {
- model: template.model ?? host.defaultModelOptions.model,
+ model: script.model ?? runtimeHost.modelAliases.large.model,
}
const key = JSON.stringify(conn)
if (!models[key]) models[key] = conn
diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts
index c0a235cb9..aab145441 100644
--- a/packages/cli/src/nodehost.ts
+++ b/packages/cli/src/nodehost.ts
@@ -44,6 +44,7 @@ import {
setRuntimeHost,
ResponseStatus,
AzureTokenResolver,
+ ModelConfigurations,
} from "../../core/src/host"
import { AbortSignalOptions, TraceOptions } from "../../core/src/trace"
import { logError, logVerbose } from "../../core/src/util"
@@ -139,14 +140,11 @@ export class NodeHost implements RuntimeHost {
readonly workspace = createFileSystem()
readonly containers = new DockerManager()
readonly browsers = new BrowserManager()
- readonly defaultModelOptions = {
- model: DEFAULT_MODEL,
- smallModel: DEFAULT_SMALL_MODEL,
- visionModel: DEFAULT_VISION_MODEL,
- temperature: DEFAULT_TEMPERATURE,
- }
- readonly defaultEmbeddingsModelOptions = {
- embeddingsModel: DEFAULT_EMBEDDINGS_MODEL,
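+ // built-in model aliases; can be overridden through environment variables or CLI flags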
+ readonly modelAliases: ModelConfigurations = {
+ large: { model: DEFAULT_MODEL },
+ small: { model: DEFAULT_SMALL_MODEL },
+ vision: { model: DEFAULT_VISION_MODEL },
+ embeddings: { model: DEFAULT_EMBEDDINGS_MODEL },
}
readonly userInputQueue = new PLimitPromiseQueue(1)
readonly azureToken: AzureTokenResolver
diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts
index ca59392ed..69524253f 100644
--- a/packages/cli/src/run.ts
+++ b/packages/cli/src/run.ts
@@ -87,6 +87,7 @@ import {
stderr,
stdout,
} from "../../core/src/logging"
+import { setModelAlias } from "../../core/src/connection"
async function setupTraceWriting(trace: MarkdownTrace, filename: string) {
logVerbose(`trace: ${filename}`)
@@ -208,11 +209,16 @@ export async function runScript(
const topLogprobs = normalizeInt(options.topLogprobs)
if (options.json || options.yaml) overrideStdoutWithStdErr()
- if (options.model) host.defaultModelOptions.model = options.model
+ if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.smallModel)
- host.defaultModelOptions.smallModel = options.smallModel
+ runtimeHost.modelAliases.small.model = options.smallModel
if (options.visionModel)
- host.defaultModelOptions.visionModel = options.visionModel
+ runtimeHost.modelAliases.vision.model = options.visionModel
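+ // register user-defined aliases passed as --model-alias name=modelid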
+ for (const kv of options.modelAlias || []) {
+ const aliases = parseKeyValuePair(kv)
+ for (const [key, value] of Object.entries(aliases))
+ setModelAlias(key, value)
+ }
const fail = (msg: string, exitCode: number, url?: string) => {
logError(url ? `${msg} (see ${url})` : msg)
@@ -220,9 +226,9 @@ export async function runScript(
}
logInfo(`genaiscript: ${scriptId}`)
- logVerbose(` large : ${host.defaultModelOptions.model}`)
- logVerbose(` small : ${host.defaultModelOptions.smallModel}`)
- logVerbose(` vision: ${host.defaultModelOptions.visionModel}`)
+ Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
+ logVerbose(` ${key}: ${value.model}`)
+ )
if (out) {
if (removeOut) await emptyDir(out)
@@ -376,7 +382,7 @@ export async function runScript(
model: info.model,
embeddingsModel:
options.embeddingsModel ??
- host.defaultEmbeddingsModelOptions.embeddingsModel,
+ runtimeHost.modelAliases.embeddings.model,
retry,
retryDelay,
maxDelay,
diff --git a/packages/cli/src/test.ts b/packages/cli/src/test.ts
index 39d085087..1d93d8d49 100644
--- a/packages/cli/src/test.ts
+++ b/packages/cli/src/test.ts
@@ -20,7 +20,7 @@ import {
import { promptFooDriver } from "../../core/src/default_prompts"
import { serializeError } from "../../core/src/error"
import { parseKeyValuePairs } from "../../core/src/fence"
-import { host } from "../../core/src/host"
+import { host, runtimeHost } from "../../core/src/host"
import { JSON5TryParse } from "../../core/src/json5"
import { MarkdownTrace } from "../../core/src/trace"
import {
@@ -51,7 +51,7 @@ import { filterScripts } from "../../core/src/ast"
* @param m - The string representation of the model specification.
* @returns A ModelOptions object with model, temperature, and topP fields if applicable.
*/
-function parseModelSpec(m: string): ModelOptions {
+function parseModelSpec(m: string): ModelOptions & ModelAliasesOptions {
const values = m
.split(/&/g)
.map((kv) => kv.split("=", 2))
@@ -108,14 +108,13 @@ export async function runPromptScriptTests(
testDelay?: string
}
): Promise {
- if (options.model) host.defaultModelOptions.model = options.model
+ if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.smallModel)
- host.defaultModelOptions.smallModel = options.smallModel
+ runtimeHost.modelAliases.small.model = options.smallModel
if (options.visionModel)
- host.defaultModelOptions.visionModel = options.visionModel
-
- logVerbose(
- `model: ${host.defaultModelOptions.model}, small model: ${host.defaultModelOptions.smallModel}, vision model: ${host.defaultModelOptions.visionModel}`
+ runtimeHost.modelAliases.vision.model = options.visionModel
+ Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
+ logVerbose(` ${key}: ${value.model}`)
)
const scripts = await listTests({ ids, ...(options || {}) })
@@ -147,12 +146,12 @@ export async function runPromptScriptTests(
: script.filename.replace(GENAI_ANY_REGEX, ".promptfoo.yaml")
logInfo(` ${fn}`)
const { info: chatInfo } = await resolveModelConnectionInfo(script, {
- model: host.defaultModelOptions.model,
+ model: runtimeHost.modelAliases.large.model,
})
if (chatInfo.error) throw new Error(chatInfo.error)
let { info: embeddingsInfo } = await resolveModelConnectionInfo(
script,
- { model: host.defaultEmbeddingsModelOptions.embeddingsModel }
+ { model: runtimeHost.modelAliases.embeddings.model }
)
if (embeddingsInfo?.error) embeddingsInfo = undefined
const config = generatePromptFooConfiguration(script, {
diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index c842278d4..6e9e4aaf2 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -1,7 +1,7 @@
// cspell: disable
import { MarkdownTrace } from "./trace"
import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
-import { LanguageModelConfiguration, host } from "./host"
+import { LanguageModelConfiguration, host, runtimeHost } from "./host"
import { GenerationOptions } from "./generation"
import { dispose } from "./dispose"
import {
@@ -735,21 +735,13 @@ export function mergeGenerationOptions(
model:
runOptions?.model ??
options?.model ??
- host.defaultModelOptions.model,
- smallModel:
- runOptions?.smallModel ??
- options?.smallModel ??
- host.defaultModelOptions.smallModel,
- visionModel:
- runOptions?.visionModel ??
- options?.visionModel ??
- host.defaultModelOptions.visionModel,
+ runtimeHost.modelAliases.large.model,
temperature:
- runOptions?.temperature ?? host.defaultModelOptions.temperature,
+ runOptions?.temperature ?? runtimeHost.modelAliases.large.temperature,
embeddingsModel:
runOptions?.embeddingsModel ??
options?.embeddingsModel ??
- host.defaultEmbeddingsModelOptions.embeddingsModel,
+ runtimeHost.modelAliases.embeddings.model,
} satisfies GenerationOptions
return res
}
@@ -803,8 +795,8 @@ export async function executeChatSession(
): Promise {
const {
trace,
- model = host.defaultModelOptions.model,
- temperature = host.defaultModelOptions.temperature,
+ model,
+ temperature,
topP,
maxTokens,
seed,
@@ -815,6 +807,7 @@ export async function executeChatSession(
choices,
topLogprobs,
} = genOptions
+ assert(!!model, "model is required")
const top_logprobs = genOptions.topLogprobs > 0 ? topLogprobs : undefined
const logprobs = genOptions.logprobs || top_logprobs > 0 ? true : undefined
traceLanguageModelConnection(trace, genOptions, connectionToken)
@@ -863,7 +856,7 @@ export async function executeChatSession(
model,
choices
)
- req = {
+ req = deleteUndefinedValues({
model,
temperature: temperature,
top_p: topP,
@@ -894,7 +887,7 @@ export async function executeChatSession(
},
}
: undefined,
- }
+ })
if (/^o1/i.test(model)) {
req.max_completion_tokens = maxTokens
delete req.max_tokens
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index 1cd7873d6..447f712f4 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -39,6 +39,7 @@ import {
host,
LanguageModelConfiguration,
AzureCredentialsType,
+ runtimeHost,
} from "./host"
import { parseModelIdentifier } from "./models"
import { normalizeFloat, trimTrailingSlash } from "./util"
@@ -71,20 +72,29 @@ export function findEnvVar(
return undefined
}
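+
+/**
+ * Registers or updates a model alias: stores `modelid` as the model for the
+ * lowercased alias `id` in `runtimeHost.modelAliases`.
+ */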
+export function setModelAlias(id: string, modelid: string) {
+ id = id.toLowerCase()
+ const c =
+ runtimeHost.modelAliases[id] || (runtimeHost.modelAliases[id] = {})
+ c.model = modelid
+}
+
export async function parseDefaultsFromEnv(env: Record<string, string>) {
+ // legacy
if (env.GENAISCRIPT_DEFAULT_MODEL)
- host.defaultModelOptions.model = env.GENAISCRIPT_DEFAULT_MODEL
- if (env.GENAISCRIPT_DEFAULT_SMALL_MODEL)
- host.defaultModelOptions.smallModel =
- env.GENAISCRIPT_DEFAULT_SMALL_MODEL
- if (env.GENAISCRIPT_DEFAULT_VISION_MODEL)
- host.defaultModelOptions.visionModel =
- env.GENAISCRIPT_DEFAULT_VISION_MODEL
+ runtimeHost.modelAliases.large.model = env.GENAISCRIPT_DEFAULT_MODEL
+
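+ // accept both GENAISCRIPT_<ALIAS>_MODEL and GENAISCRIPT_MODEL_<ALIAS> forms (case-insensitive)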
+ const rx =
+ /^GENAISCRIPT(_DEFAULT)?_((?<id>[A-Z0-9]+)_MODEL|MODEL_(?<id2>[A-Z0-9]+))$/i
+ for (const kv of Object.entries(env)) {
+ const [k, v] = kv
+ const m = rx.exec(k)
+ if (!m) continue
+ const id = m.groups.id || m.groups.id2
+ setModelAlias(id, v)
+ }
const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE)
- if (!isNaN(t)) host.defaultModelOptions.temperature = t
- if (env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL)
- host.defaultEmbeddingsModelOptions.embeddingsModel =
- env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL
+ if (!isNaN(t)) runtimeHost.modelAliases.large.temperature = t
}
export async function parseTokenFromEnv(
@@ -92,7 +102,7 @@ export async function parseTokenFromEnv(
modelId: string
): Promise {
const { provider, model, tag } = parseModelIdentifier(
- modelId ?? host.defaultModelOptions.model
+ modelId ?? runtimeHost.modelAliases.large.model
)
const TOKEN_SUFFIX = ["_API_KEY", "_API_TOKEN", "_TOKEN", "_KEY"]
const BASE_SUFFIX = ["_API_BASE", "_API_ENDPOINT", "_BASE", "_ENDPOINT"]
diff --git a/packages/core/src/encoders.ts b/packages/core/src/encoders.ts
index 5c87fb053..0cebe8f94 100644
--- a/packages/core/src/encoders.ts
+++ b/packages/core/src/encoders.ts
@@ -18,7 +18,7 @@ export async function resolveTokenEncoder(
): Promise {
const { disableFallback } = options || {}
// Parse the model identifier to extract the model information
- if (!modelId) modelId = runtimeHost.defaultModelOptions.model
+ if (!modelId) modelId = runtimeHost.modelAliases.large.model
const { model } = parseModelIdentifier(modelId)
const module = model.toLowerCase() // Assign model to module for dynamic import path
diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts
index c1a60bd09..8b4cfee91 100644
--- a/packages/core/src/expander.ts
+++ b/packages/core/src/expander.ts
@@ -25,7 +25,7 @@ import {
import { importPrompt } from "./importprompt"
import { parseModelIdentifier } from "./models"
import { JSONSchemaStringifyToTypeScript, toStrictJSONSchema } from "./schema"
-import { host } from "./host"
+import { host, runtimeHost } from "./host"
import { resolveSystems } from "./systems"
import { GenerationOptions, GenerationStatus } from "./generation"
import { AICIRequest, ChatCompletionMessageParam } from "./chattypes"
@@ -194,7 +194,7 @@ export async function expandTemplate(
options.temperature ??
normalizeFloat(env.vars["temperature"]) ??
template.temperature ??
- host.defaultModelOptions.temperature
+ runtimeHost.modelAliases.large.temperature
const topP =
options.topP ?? normalizeFloat(env.vars["top_p"]) ?? template.topP
const maxTokens =
diff --git a/packages/core/src/git.ts b/packages/core/src/git.ts
index 389bf35b3..5fd062837 100644
--- a/packages/core/src/git.ts
+++ b/packages/core/src/git.ts
@@ -316,7 +316,7 @@ export class GitClient implements Git {
if (!nameOnly && llmify) {
res = llmifyDiff(res)
const { encode: encoder } = await resolveTokenEncoder(
- runtimeHost.defaultModelOptions.model || DEFAULT_MODEL
+ runtimeHost.modelAliases.large.model
)
const tokens = estimateTokens(res, encoder)
if (tokens > maxTokensFullDiff)
@@ -329,7 +329,7 @@ ${truncateTextToTokens(res, maxTokensFullDiff, encoder)}
## Files
${await this.diff({ ...options, nameOnly: true })}
`
- }
+ }
return res
}
diff --git a/packages/core/src/globals.ts b/packages/core/src/globals.ts
index 025287579..4b93bbaad 100644
--- a/packages/core/src/globals.ts
+++ b/packages/core/src/globals.ts
@@ -68,7 +68,7 @@ export function installGlobals() {
// Freeze XML utilities
glb.XML = Object.freeze({
- parse: XMLParse, // Parse XML string to objects
+ parse: XMLParse, // Parse XML string to objects
})
// Freeze Markdown utilities with frontmatter operations
@@ -124,14 +124,14 @@ export function installGlobals() {
resolve: resolveTokenEncoder,
count: async (text, options) => {
const { encode: encoder } = await resolveTokenEncoder(
- options?.model || runtimeHost.defaultModelOptions.model
+ options?.model || runtimeHost.modelAliases.large.model
)
const c = await estimateTokens(text, encoder)
return c
},
truncate: async (text, maxTokens, options) => {
const { encode: encoder } = await resolveTokenEncoder(
- options?.model || runtimeHost.defaultModelOptions.model
+ options?.model || runtimeHost.modelAliases.large.model
)
return await truncateTextToTokens(text, maxTokens, encoder, options)
},
diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts
index 67d1e9b0c..8f2ba6fd6 100644
--- a/packages/core/src/host.ts
+++ b/packages/core/src/host.ts
@@ -113,6 +113,15 @@ export interface AzureTokenResolver {
): Promise
}
+export type ModelConfiguration = Pick<ModelOptions, "model" | "temperature">
+
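+/** Built-in model aliases (`large`, `small`, `vision`, `embeddings`) plus any user-defined aliases. */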
+export type ModelConfigurations = {
+ large: ModelConfiguration
+ small: ModelConfiguration
+ vision: ModelConfiguration
+ embeddings: ModelConfiguration
+} & Record<string, ModelConfiguration>
+
export interface Host {
userState: any
server: ServerManager
@@ -124,13 +133,6 @@ export interface Host {
installFolder(): string
resolvePath(...segments: string[]): string
- // read a secret from the environment or a .env file
- defaultModelOptions: Required<
- Pick<ModelOptions, "model" | "smallModel" | "visionModel" | "temperature">
- >
- defaultEmbeddingsModelOptions: Required<
- Pick<EmbeddingsModelOptions, "embeddingsModel">
- >
getLanguageModelConfiguration(
modelId: string,
options?: { token?: boolean } & AbortSignalOptions & TraceOptions
@@ -164,6 +166,7 @@ export interface RuntimeHost extends Host {
models: ModelService
workspace: Omit
azureToken: AzureTokenResolver
+ modelAliases: ModelConfigurations
readConfig(): Promise
readSecret(name: string): Promise
diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts
index e67c9cd57..f2babb2a8 100644
--- a/packages/core/src/models.ts
+++ b/packages/core/src/models.ts
@@ -11,7 +11,7 @@ import {
VISION_MODEL_ID,
} from "./constants"
import { errorMessage } from "./error"
-import { LanguageModelConfiguration, host } from "./host"
+import { LanguageModelConfiguration, host, runtimeHost } from "./host"
import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace"
import { assert } from "./util"
@@ -106,23 +106,29 @@ export async function resolveModelConnectionInfo(
if (m === SMALL_MODEL_ID) {
m = undefined
candidates ??= [
- host.defaultModelOptions.smallModel,
+ runtimeHost.modelAliases.small.model,
...DEFAULT_SMALL_MODEL_CANDIDATES,
]
} else if (m === VISION_MODEL_ID) {
m = undefined
candidates ??= [
- host.defaultModelOptions.visionModel,
+ runtimeHost.modelAliases.vision.model,
...DEFAULT_VISION_MODEL_CANDIDATES,
]
} else if (m === LARGE_MODEL_ID) {
m = undefined
candidates ??= [
- host.defaultModelOptions.model,
+ runtimeHost.modelAliases.large.model,
...DEFAULT_MODEL_CANDIDATES,
]
}
- candidates ??= [host.defaultModelOptions.model, ...DEFAULT_MODEL_CANDIDATES]
+ candidates ??= [
+ runtimeHost.modelAliases.large.model,
+ ...DEFAULT_MODEL_CANDIDATES,
+ ]
+
+ // apply model alias
+ m = runtimeHost.modelAliases[m]?.model || m
const resolveModel = async (
model: string,
diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts
index 63e1a9b64..96c532a7c 100644
--- a/packages/core/src/promptcontext.ts
+++ b/packages/core/src/promptcontext.ts
@@ -190,7 +190,7 @@ export async function createPromptContext(
searchOptions.embeddingsModel =
searchOptions?.embeddingsModel ??
options?.embeddingsModel ??
- host.defaultEmbeddingsModelOptions.embeddingsModel
+ runtimeHost.modelAliases.embeddings.model
const key = await hash({ files, searchOptions }, { length: 12 })
const folderPath = dotGenaiscriptPath("vectors", key)
const res = await vectorSearch(q, files, {
@@ -213,8 +213,9 @@ export async function createPromptContext(
// Define the host for executing commands, browsing, and other operations
const promptHost: PromptHost = Object.freeze({
- fetch: (url, options) => fetch(url, {...(options || {}), trace }),
- fetchText: (url, options) => fetchText(url, {...(options || {}), trace }),
+ fetch: (url, options) => fetch(url, { ...(options || {}), trace }),
+ fetchText: (url, options) =>
+ fetchText(url, { ...(options || {}), trace }),
resolveLanguageModel: async (modelId) => {
const { configuration } = await resolveModelConnectionInfo(
{ model: modelId },
diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts
index 035aa4c2d..386ce623b 100644
--- a/packages/core/src/server/messages.ts
+++ b/packages/core/src/server/messages.ts
@@ -88,6 +88,7 @@ export interface PromptScriptRunOptions {
smallModel: string
visionModel: string
embeddingsModel: string
+ modelAlias: string[]
csvSeparator: string
cache: boolean | string
cacheName: string
diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts
index 227cb1900..7e6535a89 100644
--- a/packages/core/src/test.ts
+++ b/packages/core/src/test.ts
@@ -9,7 +9,7 @@ import {
VISION_MODEL_ID,
} from "./constants"
import { arrayify, deleteUndefinedValues } from "./util"
-import { host } from "./host"
+import { host, runtimeHost } from "./host"
import { ModelConnectionInfo, parseModelIdentifier } from "./models"
/**
@@ -64,12 +64,12 @@ function resolveTestProvider(
export function generatePromptFooConfiguration(
script: PromptScript,
options: {
- chatInfo: ModelConnectionInfo
+ chatInfo: ModelConnectionInfo & ModelAliasesOptions
embeddingsInfo?: ModelConnectionInfo
provider?: string
out?: string
cli?: string
- models?: ModelOptions[]
+ models?: (ModelOptions & ModelAliasesOptions)[]
}
) {
// Destructure options with default values
@@ -94,14 +94,7 @@ export function generatePromptFooConfiguration(
const cli = options?.cli
const transform = "output.text"
- const resolveModel = (m: string) =>
- m === SMALL_MODEL_ID
- ? host.defaultModelOptions.smallModel
- : m === VISION_MODEL_ID
- ? host.defaultModelOptions.visionModel
- : m === LARGE_MODEL_ID
- ? host.defaultModelOptions.model
- : m
+ const resolveModel = (m: string) => runtimeHost.modelAliases[m]?.model ?? m
const testProvider = deleteUndefinedValues({
text: resolveTestProvider(chatInfo, "chat"),
@@ -119,16 +112,17 @@ export function generatePromptFooConfiguration(
// Map model options to providers
providers: models
.map(({ model, smallModel, visionModel, temperature, topP }) => ({
- model: resolveModel(model) ?? host.defaultModelOptions.model,
+ model:
+ resolveModel(model) ?? runtimeHost.modelAliases.large.model,
smallModel:
resolveModel(smallModel) ??
- host.defaultModelOptions.smallModel,
+ runtimeHost.modelAliases.small.model,
visionModel:
resolveModel(visionModel) ??
- host.defaultModelOptions.visionModel,
+ runtimeHost.modelAliases.vision.model,
temperature: !isNaN(temperature)
? temperature
- : host.defaultModelOptions.temperature,
+ : runtimeHost.modelAliases.large.temperature,
top_p: topP,
}))
.map(({ model, smallModel, visionModel, temperature, top_p }) => ({
diff --git a/packages/core/src/testhost.ts b/packages/core/src/testhost.ts
index e09a601a2..e81a468d9 100644
--- a/packages/core/src/testhost.ts
+++ b/packages/core/src/testhost.ts
@@ -15,6 +15,7 @@ import {
setRuntimeHost,
RuntimeHost,
AzureTokenResolver,
+ ModelConfigurations,
} from "./host"
import { TraceOptions } from "./trace"
import {
@@ -70,15 +71,11 @@ export class TestHost implements RuntimeHost {
azureToken: AzureTokenResolver = undefined
// Default options for language models
- readonly defaultModelOptions = {
- model: DEFAULT_MODEL,
- smallModel: DEFAULT_SMALL_MODEL,
- visionModel: DEFAULT_VISION_MODEL,
- temperature: DEFAULT_TEMPERATURE,
- }
- // Default options for embeddings models
- readonly defaultEmbeddingsModelOptions = {
- embeddingsModel: DEFAULT_EMBEDDINGS_MODEL,
+ readonly modelAliases: ModelConfigurations = {
+ large: { model: DEFAULT_MODEL },
+ small: { model: DEFAULT_SMALL_MODEL },
+ vision: { model: DEFAULT_VISION_MODEL },
+ embeddings: { model: DEFAULT_EMBEDDINGS_MODEL },
}
// Static method to set this class as the runtime host
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index e8b17d139..0068002da 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -186,20 +186,19 @@ type ModelVisionType = OptionsOrString<
interface ModelConnectionOptions {
/**
- * Which LLM model to use. Use `large` for the default set of model candidates, `small` for the set of small models like gpt-4o-mini.
+ * Which LLM model to use by default or for the `large` alias.
*/
model?: ModelType
+}
+interface ModelAliasesOptions {
/**
- * Which LLM model to use for the "small" model.
- *
- * @default gpt-4
- * @example gpt-4
+ * Configure the `small` model alias.
*/
smallModel?: ModelSmallType
/**
- * Which LLM to use for the "vision" model.
+ * Configure the `vision` model alias.
*/
visionModel?: ModelVisionType
}
@@ -446,6 +445,7 @@ interface ContentSafetyOptions {
interface PromptScript
extends PromptLike,
ModelOptions,
+ ModelAliasesOptions,
PromptSystemOptions,
EmbeddingsModelOptions,
ContentSafetyOptions,
diff --git a/packages/core/src/vectorsearch.ts b/packages/core/src/vectorsearch.ts
index e06efe612..466a5e9e0 100644
--- a/packages/core/src/vectorsearch.ts
+++ b/packages/core/src/vectorsearch.ts
@@ -175,8 +175,7 @@ export async function vectorSearch(
const {
topK,
folderPath,
- embeddingsModel = runtimeHost.defaultEmbeddingsModelOptions
- .embeddingsModel,
+ embeddingsModel = runtimeHost.modelAliases.embeddings.model,
minScore = 0,
trace,
} = options
@@ -199,7 +198,7 @@ export async function vectorSearch(
{
token: true,
candidates: [
- host.defaultEmbeddingsModelOptions.embeddingsModel,
+ runtimeHost.modelAliases.embeddings.model,
...DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
],
}
diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts
index 91d1bbab3..5920f792a 100644
--- a/packages/vscode/src/vshost.ts
+++ b/packages/vscode/src/vshost.ts
@@ -6,15 +6,6 @@ import { ExtensionState } from "./state"
import { Utils } from "vscode-uri"
import { checkFileExists, readFileText } from "./fs"
import { filterGitIgnore } from "../../core/src/gitignore"
-import { parseDefaultsFromEnv } from "../../core/src/connection"
-import {
- DEFAULT_EMBEDDINGS_MODEL,
- DEFAULT_MODEL,
- DEFAULT_SMALL_MODEL,
- DEFAULT_TEMPERATURE,
- DEFAULT_VISION_MODEL,
-} from "../../core/src/constants"
-import { dotEnvTryParse } from "../../core/src/dotenv"
import {
setHost,
LanguageModelConfiguration,
@@ -30,16 +21,6 @@ export class VSCodeHost extends EventTarget implements Host {
userState: any = {}
readonly path = createVSPath()
readonly server: TerminalServerManager
- readonly defaultModelOptions = {
- model: DEFAULT_MODEL,
- smallModel: DEFAULT_SMALL_MODEL,
- visionModel: DEFAULT_VISION_MODEL,
- temperature: DEFAULT_TEMPERATURE,
- }
- readonly defaultEmbeddingsModelOptions = {
- embeddingsModel: DEFAULT_EMBEDDINGS_MODEL,
- }
-
constructor(readonly state: ExtensionState) {
super()
setHost(this)