Skip to content

Commit

Permalink
feat: ✨ add support for LLM provider aliases and options
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Dec 21, 2024
1 parent b1c78f0 commit 500f3fb
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 42 deletions.
14 changes: 12 additions & 2 deletions docs/src/content/docs/reference/cli/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Usage: genaiscript run [options] <script> [files...]
Runs a GenAIScript against files.
Options:
-p, --provider <string> Preferred LLM provider aliases (choices: "openai", "azure", "azure_serverless", "azure_serverless_models", "anthropic", "anthropic_bedrock", "google", "huggingface", "mistral", "alibaba", "github", "transformers", "ollama", "lmstudio", "jan", "llamafile", "litellm")
-m, --model <string> 'large' model alias (default)
-sm, --small-model <string> 'small' alias model
-vm, --vision-model <string> 'vision' alias model
Expand Down Expand Up @@ -86,9 +87,18 @@ Arguments:
are tested
Options:
-m, --model <string> model for the run
-sm, --small-model <string> small model for the run
-p, --provider <string> Preferred LLM provider aliases (choices:
"openai", "azure", "azure_serverless",
"azure_serverless_models", "anthropic",
"anthropic_bedrock", "google",
"huggingface", "mistral", "alibaba",
"github", "transformers", "ollama",
"lmstudio", "jan", "llamafile",
"litellm")
-m, --model <string> 'large' model alias (default)
-sm, --small-model <string> 'small' alias model
-vm, --vision-model <string> 'vision' alias model
-ma, --model-alias <nameid...> model alias as name=modelid
--models <models...> models to test where mode is the key
value pair list of m (model), s (small
model), t (temperature), p (top-p)
Expand Down
40 changes: 28 additions & 12 deletions packages/cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/

import { NodeHost } from "./nodehost" // Handles node environment setup
import { Option, program } from "commander" // Command-line argument parsing library
import { Command, Option, program } from "commander" // Command-line argument parsing library
import { error, isQuiet, setConsoleColors, setQuiet } from "./log" // Logging utilities
import { startServer } from "./server" // Function to start server
import { NODE_MIN_VERSION, PROMPTFOO_VERSION } from "./version" // Version constants
Expand All @@ -24,7 +24,7 @@ import {
} from "./parse" // Parsing functions
import { compileScript, createScript, fixScripts, listScripts } from "./scripts" // Script utilities
import { codeQuery } from "./codequery" // Code parsing and query execution
import { envInfo, modelsInfo as modelAliasesInfo, scriptModelInfo, systemInfo } from "./info" // Information utilities
import { envInfo, modelAliasesInfo, scriptModelInfo, systemInfo } from "./info" // Information utilities
import { scriptTestList, scriptTestsView, scriptsTest } from "./test" // Test functions
import { cacheClear } from "./cache" // Cache management
import "node:console" // Importing console for side effects
Expand All @@ -37,7 +37,7 @@ import {
OPENAI_MAX_RETRY_DELAY,
OPENAI_RETRY_DEFAULT_DEFAULT,
OPENAI_MAX_RETRY_COUNT,
GENAI_MD_EXT,
MODEL_PROVIDERS,
} from "../../core/src/constants" // Core constants
import {
errorMessage,
Expand Down Expand Up @@ -94,14 +94,11 @@ export async function cli() {
program.on("option:quiet", () => setQuiet(true))

// Define 'run' command for executing scripts
program
const run = program
.command("run")
.description("Runs a GenAIScript against files.")
.arguments("<script> [files...]")
.option("-m, --model <string>", "'large' model alias (default)")
.option("-sm, --small-model <string>", "'small' alias model")
.option("-vm, --vision-model <string>", "'vision' alias model")
.option("-ma, --model-alias <nameid...>", "model alias as name=modelid")
addModelOptions(run) // Add model options to the command
.option("-lp, --logprobs", "enable reporting token probabilities")
.option(
"-tlp, --top-logprobs <number>",
Expand Down Expand Up @@ -203,15 +200,13 @@ export async function cli() {
// Define 'test' command group for running tests
const test = program.command("test")

test.command("run", { isDefault: true })
const testRun = test.command("run", { isDefault: true })
.description("Runs the tests for scripts")
.argument(
"[script...]",
"Script ids. If not provided, all scripts are tested"
)
.option("-m, --model <string>", "model for the run")
.option("-sm, --small-model <string>", "small model for the run")
.option("-vm, --vision-model <string>", "'vision' alias model")
addModelOptions(testRun) // Add model options to the command
.option(
"--models <models...>",
"models to test where mode is the key value pair list of m (model), s (small model), t (temperature), p (top-p)"
Expand Down Expand Up @@ -456,4 +451,25 @@ export async function cli() {
.action(modelAliasesInfo)

program.parse() // Parse command-line arguments

function addModelOptions(command: Command) {
return command
.addOption(
new Option(
"-p, --provider <string>",
"Preferred LLM provider aliases"
).choices(
MODEL_PROVIDERS.filter(
({ id, aliases }) => id !== "client"
).map(({ id }) => id)
)
)
.option("-m, --model <string>", "'large' model alias (default)")
.option("-sm, --small-model <string>", "'small' alias model")
.option("-vm, --vision-model <string>", "'vision' alias model")
.option(
"-ma, --model-alias <nameid...>",
"model alias as name=modelid"
)
}
}
2 changes: 1 addition & 1 deletion packages/cli/src/convert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { PLimitPromiseQueue } from "../../core/src/concurrency"
import { createPatch } from "diff"
import { unfence } from "../../core/src/fence"
import { JSONLLMTryParse, JSONTryParse } from "../../core/src/json5"
import { applyModelOptions } from "./modealias"
import { applyModelOptions } from "./modelalias"
import { ensureDotGenaiscriptPath, setupTraceWriting } from "./trace"
import { tracePromptResult } from "../../core/src/chat"
import { dirname, join } from "node:path"
Expand Down
2 changes: 1 addition & 1 deletion packages/cli/src/info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,6 @@ export async function scriptModelInfo(
console.log(YAMLStringify(info))
}

export async function modelsInfo() {
// Prints the runtime host's current model alias table to stdout as YAML.
export async function modelAliasesInfo() {
    console.log(YAML.stringify(runtimeHost.modelAliases))
}
16 changes: 0 additions & 16 deletions packages/cli/src/modealias.ts

This file was deleted.

38 changes: 38 additions & 0 deletions packages/cli/src/modelalias.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { MODEL_PROVIDERS } from "../../core/src/constants"
import { parseKeyValuePair } from "../../core/src/fence"
import { runtimeHost } from "../../core/src/host"
import { logVerbose } from "../../core/src/util"
import { PromptScriptRunOptions } from "./main"

/**
 * Applies CLI model-selection options to the runtime host's model aliases.
 *
 * Order matters: the provider's bundled aliases are applied first so that the
 * explicit `--model` / `--small-model` / `--vision-model` flags and the
 * free-form `--model-alias name=modelid` pairs can override them.
 *
 * @param options parsed CLI options (provider, model, smallModel, visionModel,
 *   modelAlias); all fields optional
 * @throws Error when `options.provider` names an unknown provider id
 */
export function applyModelOptions(
    options: Partial<
        Pick<
            PromptScriptRunOptions,
            "model" | "smallModel" | "visionModel" | "modelAlias" | "provider"
        >
    >
) {
    if (options.provider) {
        const provider = MODEL_PROVIDERS.find((p) => p.id === options.provider)
        if (!provider)
            throw new Error(`Model provider not found: ${options.provider}`)
        // register the provider's default aliases, prefixed with the provider id
        for (const [key, value] of Object.entries(provider.aliases || {}))
            runtimeHost.setModelAlias("cli", key, provider.id + ":" + value)
    }
    // explicit per-alias flags override any provider defaults set above
    if (options.model) runtimeHost.setModelAlias("cli", "large", options.model)
    if (options.smallModel)
        runtimeHost.setModelAlias("cli", "small", options.smallModel)
    if (options.visionModel)
        runtimeHost.setModelAlias("cli", "vision", options.visionModel)
    // free-form name=modelid pairs, applied last
    for (const kv of options.modelAlias || []) {
        const aliases = parseKeyValuePair(kv)
        for (const [key, value] of Object.entries(aliases))
            runtimeHost.setModelAlias("cli", key, value)
    }

    // log the effective alias table, but only when something was customized
    const modelAlias = runtimeHost.modelAliases
    if (Object.values(modelAlias).some((m) => m.source !== "default"))
        Object.entries(modelAlias).forEach(([key, value]) =>
            logVerbose(`  ${key}: ${value.model} (${value.source})`)
        )
}
2 changes: 1 addition & 1 deletion packages/cli/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ import {
stdout,
} from "../../core/src/logging"
import { ensureDotGenaiscriptPath, setupTraceWriting } from "./trace"
import { applyModelOptions } from "./modealias"
import { applyModelOptions } from "./modelalias"
import { createCancellationController } from "./cancel"
import { parsePromptScriptMeta } from "../../core/src/template"

Expand Down
10 changes: 2 additions & 8 deletions packages/cli/src/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import { delay } from "es-toolkit"
import { resolveModelConnectionInfo } from "../../core/src/models"
import { filterScripts } from "../../core/src/ast"
import { link } from "../../core/src/markdown"
import { applyModelOptions } from "./modelalias"

/**
* Parses model specifications from a string and returns a ModelOptions object.
Expand Down Expand Up @@ -109,14 +110,7 @@ export async function runPromptScriptTests(
testDelay?: string
}
): Promise<PromptScriptTestRunResponse> {
if (options.model) runtimeHost.setModelAlias("cli", "large", options.model)
if (options.smallModel)
runtimeHost.setModelAlias("cli", "small", options.smallModel)
if (options.visionModel)
runtimeHost.setModelAlias("cli", "vision", options.visionModel)
Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
logVerbose(` ${key}: ${value.model}`)
)
applyModelOptions(options)

const scripts = await listTests({ ids, ...(options || {}) })
if (!scripts.length)
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ export const MODEL_PROVIDERS = Object.freeze<
topP?: boolean
prediction?: boolean
bearerToken?: boolean
aliases?: Record<string, string>
}[]
>(CONFIGURATION_DATA.providers)
export const MODEL_PRICINGS = Object.freeze<
Expand Down
9 changes: 8 additions & 1 deletion packages/core/src/llms.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@
{
"id": "huggingface",
"detail": "Hugging Face models",
"prediction": false
"prediction": false,
"aliases": {
"large": "Qwen/Qwen2.5-72B-Instruct",
"small": "Qwen/Qwen2.5-Coder-32B-Instruct",
"vision": "Qwen/Qwen2-VL-7B-Instruct",
"reasoning": "Qwen/QwQ-32B-Preview",
"reasoning_small": "Qwen/QwQ-32B-Preview"
}
},
{
"id": "mistral",
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/server/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export interface PromptScriptRunOptions {
visionModel: string
embeddingsModel: string
modelAlias: string[]
provider: string
csvSeparator: string
cache: boolean | string
cacheName: string
Expand Down

0 comments on commit 500f3fb

Please sign in to comment.