diff --git a/docs/src/content/docs/getting-started/configuration.mdx b/docs/src/content/docs/getting-started/configuration.mdx index 5c6155956..4fd543a57 100644 --- a/docs/src/content/docs/getting-started/configuration.mdx +++ b/docs/src/content/docs/getting-started/configuration.mdx @@ -64,11 +64,15 @@ GENAISCRIPT_MODEL_SMALL="azure_serverless:..." GENAISCRIPT_MODEL_VISION="azure_serverless:..." ``` -You can also configure the default aliases for a given LLM provider by using the `--provider` argument in the CLI. +You can also configure the default aliases for a given LLM provider by using the `provider` argument. The defaults are documented on this page and printed to the console output. +```js +script({ provider: "openai" }) +``` + ```sh -genaiscript run ... --provider anthropic +genaiscript run ... --provider openai ``` ### Model aliases diff --git a/packages/cli/src/convert.ts b/packages/cli/src/convert.ts index 6442c5649..cd7ddf477 100644 --- a/packages/cli/src/convert.ts +++ b/packages/cli/src/convert.ts @@ -57,7 +57,7 @@ export async function convertFiles( const canceller = createCancellationController() const cancellationToken = canceller.token const signal = toSignal(cancellationToken) - applyModelOptions(options) + applyModelOptions(options, "cli") const outTrace = dotGenaiscriptPath( CONVERTS_DIR_NAME, host.path.basename(scriptId).replace(GENAI_ANYTS_REGEX, ""), diff --git a/packages/cli/src/modelalias.ts b/packages/cli/src/modelalias.ts index ddff1d3b1..4f3c0785a 100644 --- a/packages/cli/src/modelalias.ts +++ b/packages/cli/src/modelalias.ts @@ -4,32 +4,40 @@ import { runtimeHost } from "../../core/src/host" import { logVerbose } from "../../core/src/util" import { PromptScriptRunOptions } from "./main" +export function applyModelProviderAliases( + id: string, + source: "cli" | "env" | "config" | "script" +) { + if (!id) return + const provider = MODEL_PROVIDERS.find((p) => p.id === id) + if (!provider) throw new Error(`Model provider not found: 
${id}`) + for (const [key, value] of Object.entries(provider.aliases || {})) + runtimeHost.setModelAlias(source, key, provider.id + ":" + value) +} + export function applyModelOptions( options: Partial< Pick< PromptScriptRunOptions, "model" | "smallModel" | "visionModel" | "modelAlias" | "provider" > - > + >, + source: "cli" | "env" | "config" | "script" ) { - if (options.provider) { - const provider = MODEL_PROVIDERS.find((p) => p.id === options.provider) - if (!provider) - throw new Error(`Model provider not found: ${options.provider}`) - for (const [key, value] of Object.entries(provider.aliases || {})) - runtimeHost.setModelAlias("cli", key, provider.id + ":" + value) - } - if (options.model) runtimeHost.setModelAlias("cli", "large", options.model) + if (options.provider) applyModelProviderAliases(options.provider, source) + if (options.model) runtimeHost.setModelAlias(source, "large", options.model) if (options.smallModel) - runtimeHost.setModelAlias("cli", "small", options.smallModel) + runtimeHost.setModelAlias(source, "small", options.smallModel) if (options.visionModel) - runtimeHost.setModelAlias("cli", "vision", options.visionModel) + runtimeHost.setModelAlias(source, "vision", options.visionModel) for (const kv of options.modelAlias || []) { const aliases = parseKeyValuePair(kv) for (const [key, value] of Object.entries(aliases)) - runtimeHost.setModelAlias("cli", key, value) + runtimeHost.setModelAlias(source, key, value) } +} +export function logModelAliases() { const modelAlias = runtimeHost.modelAliases if (Object.values(modelAlias).some((m) => m.source !== "default")) Object.entries(runtimeHost.modelAliases).forEach(([key, value]) => diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts index 0a83d3c77..7f11cb05e 100644 --- a/packages/cli/src/nodehost.ts +++ b/packages/cli/src/nodehost.ts @@ -84,12 +84,13 @@ export class NodeHost implements RuntimeHost { readonly containers = new DockerManager() readonly browsers = new 
BrowserManager() private readonly _modelAliases: Record< - "default" | "cli" | "env" | "config", + "default" | "cli" | "env" | "config" | "script", Omit > = { default: defaultModelConfigurations(), cli: {}, env: {}, + script: {}, config: {}, } readonly userInputQueue = new PLimitPromiseQueue(1) @@ -114,6 +115,7 @@ export class NodeHost implements RuntimeHost { const res = { ...this._modelAliases.default, ...this._modelAliases.config, + ...this._modelAliases.script, ...this._modelAliases.env, ...this._modelAliases.cli, } as ModelConfigurations @@ -121,7 +123,7 @@ export class NodeHost implements RuntimeHost { } setModelAlias( - source: "cli" | "env" | "config", + source: "cli" | "env" | "config" | "script", id: string, value: string | ModelConfiguration ): void { diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index 1ce33ee9d..71bf690c6 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -83,7 +83,7 @@ import { stdout, } from "../../core/src/logging" import { ensureDotGenaiscriptPath, setupTraceWriting } from "./trace" -import { applyModelOptions } from "./modelalias" +import { applyModelOptions, logModelAliases } from "./modelalias" import { createCancellationController } from "./cancel" import { parsePromptScriptMeta } from "../../core/src/template" @@ -181,7 +181,7 @@ export async function runScriptInternal( const fenceFormat = options.fenceFormat if (options.json || options.yaml) overrideStdoutWithStdErr() - applyModelOptions(options) + applyModelOptions(options, "cli") const fail = (msg: string, exitCode: number, url?: string) => { logError(url ? 
`${msg} (see ${url})` : msg) @@ -264,6 +264,8 @@ export async function runScriptInternal( const stats = new GenerationStats("") try { if (options.label) trace.heading(2, options.label) + applyModelOptions(script, "script") + logModelAliases() const { info } = await resolveModelConnectionInfo(script, { trace, model: options.model, diff --git a/packages/cli/src/test.ts b/packages/cli/src/test.ts index c5276dcfe..b9e5c36fe 100644 --- a/packages/cli/src/test.ts +++ b/packages/cli/src/test.ts @@ -110,7 +110,7 @@ export async function runPromptScriptTests( testDelay?: string } ): Promise { - applyModelOptions(options) + applyModelOptions(options, "cli") const scripts = await listTests({ ids, ...(options || {}) }) if (!scripts.length) diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts index 023f6c159..5e20b5223 100644 --- a/packages/core/src/host.ts +++ b/packages/core/src/host.ts @@ -111,7 +111,7 @@ export interface AzureTokenResolver { export type ModelConfiguration = Readonly< Pick & { - source: "cli" | "env" | "config" | "default" + source: "cli" | "env" | "script" | "config" | "default" candidates?: string[] } > @@ -171,7 +171,7 @@ export interface RuntimeHost extends Host { pullModel(model: string, options?: TraceOptions): Promise setModelAlias( - source: "env" | "cli" | "config", + source: "env" | "cli" | "config" | "script", id: string, value: string | Omit ): void diff --git a/packages/core/src/llms.json b/packages/core/src/llms.json index 5e144827d..537459a40 100644 --- a/packages/core/src/llms.json +++ b/packages/core/src/llms.json @@ -74,7 +74,9 @@ "large": "gemini-2.0-flash-exp", "small": "gemini-1.5-flash-latest", "vision": "gemini-2.0-flash-exp", + "long": "gemini-2.0-flash-exp", "reasoning": "gemini-2.0-flash-thinking-exp-1219", + "reasoning_small": "gemini-2.0-flash-thinking-exp-1219", "embeddings": "text-embedding-004" } }, diff --git a/packages/core/src/template.ts b/packages/core/src/template.ts index b6393c54b..168c58e2d 100644 --- 
a/packages/core/src/template.ts +++ b/packages/core/src/template.ts @@ -383,6 +383,10 @@ async function parsePromptTemplateCore( checker.checkString("title") checker.checkString("description") checker.checkString("model") + checker.checkString("smallModel") + checker.checkString("visionModel") + checker.checkString("embeddingsModel") + checker.checkString("provider") checker.checkString("responseType") checker.checkJSONSchema("responseSchema") diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts index 4ad900c21..c4dfbbb17 100644 --- a/packages/core/src/types/prompt_template.d.ts +++ b/packages/core/src/types/prompt_template.d.ts @@ -191,6 +191,25 @@ type ModelVisionType = OptionsOrString< "openai:gpt-4o" | "github:gpt-4o" | "azure:gpt-4o" | "azure:gpt-4o-mini" > +type ModelProviderType = + | "openai" + | "azure" + | "azure_serverless" + | "azure_serverless_models" + | "anthropic" + | "anthropic_bedrock" + | "google" + | "huggingface" + | "mistral" + | "alibaba" + | "github" + | "transformers" + | "ollama" + | "lmstudio" + | "jan" + | "llamafile" + | "litellm" + interface ModelConnectionOptions { /** * Which LLM model by default or for the `large` alias. @@ -473,6 +492,11 @@ interface PromptScript EmbeddingsModelOptions, ContentSafetyOptions, ScriptRuntimeOptions { + /** + * Which provider to prefer when picking a model. + */ + provider?: ModelProviderType + /** * Additional template parameters that will populate `env.vars` */ diff --git a/packages/sample/genaisrc/provider.genai.mjs b/packages/sample/genaisrc/provider.genai.mjs new file mode 100644 index 000000000..c660d6d92 --- /dev/null +++ b/packages/sample/genaisrc/provider.genai.mjs @@ -0,0 +1,5 @@ +script({ + provider: "google", + smallModel: "azure:gpt-4o-mini", +}) +$`Write a poem.`