feat: ✨ add script-level provider and model options
pelikhan committed Dec 21, 2024
1 parent 11a12e7 commit 0b0e503
Showing 11 changed files with 73 additions and 22 deletions.
8 changes: 6 additions & 2 deletions docs/src/content/docs/getting-started/configuration.mdx
@@ -64,11 +64,15 @@ GENAISCRIPT_MODEL_SMALL="azure_serverless:..."
 GENAISCRIPT_MODEL_VISION="azure_serverless:..."
 ```
 
-You can also configure the default aliases for a given LLM provider by using the `--provider` argument in the CLI.
+You can also configure the default aliases for a given LLM provider by using the `provider` argument.
 The defaults are documented on this page and printed to the console output.
 
+```js
+script({ provider: "openai" })
+```
+
 ```sh
-genaiscript run ... --provider anthropic
+genaiscript run ... --provider openai
 ```
 
 ### Model aliases
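For context on precedence: a script can now combine a provider choice with explicit per-alias models. A minimal sketch (model ids are illustrative, and the override behavior is inferred from `applyModelOptions` below):

```ts
script({
    // fills the large/small/vision aliases from the provider's defaults
    provider: "anthropic",
    // applied after the provider aliases within the same source, so this
    // explicit choice wins for "small"
    smallModel: "openai:gpt-4o-mini",
})
```

A `--provider` flag on the command line is registered under the `cli` source and still overrides the script-level choice (see the alias merge in `nodehost.ts` below).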
2 changes: 1 addition & 1 deletion packages/cli/src/convert.ts
@@ -57,7 +57,7 @@ export async function convertFiles(
     const canceller = createCancellationController()
     const cancellationToken = canceller.token
     const signal = toSignal(cancellationToken)
-    applyModelOptions(options)
+    applyModelOptions(options, "cli")
     const outTrace = dotGenaiscriptPath(
         CONVERTS_DIR_NAME,
         host.path.basename(scriptId).replace(GENAI_ANYTS_REGEX, ""),
32 changes: 20 additions & 12 deletions packages/cli/src/modelalias.ts
@@ -4,32 +4,40 @@ import { runtimeHost } from "../../core/src/host"
 import { logVerbose } from "../../core/src/util"
 import { PromptScriptRunOptions } from "./main"
 
+export function applyModelProviderAliases(
+    id: string,
+    source: "cli" | "env" | "config" | "script"
+) {
+    if (!id) return
+    const provider = MODEL_PROVIDERS.find((p) => p.id === id)
+    if (!provider) throw new Error(`Model provider not found: ${id}`)
+    for (const [key, value] of Object.entries(provider.aliases || {}))
+        runtimeHost.setModelAlias(source, key, provider.id + ":" + value)
+}
+
 export function applyModelOptions(
     options: Partial<
         Pick<
             PromptScriptRunOptions,
             "model" | "smallModel" | "visionModel" | "modelAlias" | "provider"
         >
-    >
+    >,
+    source: "cli" | "env" | "config" | "script"
 ) {
-    if (options.provider) {
-        const provider = MODEL_PROVIDERS.find((p) => p.id === options.provider)
-        if (!provider)
-            throw new Error(`Model provider not found: ${options.provider}`)
-        for (const [key, value] of Object.entries(provider.aliases || {}))
-            runtimeHost.setModelAlias("cli", key, provider.id + ":" + value)
-    }
-    if (options.model) runtimeHost.setModelAlias("cli", "large", options.model)
+    if (options.provider) applyModelProviderAliases(options.provider, source)
+    if (options.model) runtimeHost.setModelAlias(source, "large", options.model)
     if (options.smallModel)
-        runtimeHost.setModelAlias("cli", "small", options.smallModel)
+        runtimeHost.setModelAlias(source, "small", options.smallModel)
     if (options.visionModel)
-        runtimeHost.setModelAlias("cli", "vision", options.visionModel)
+        runtimeHost.setModelAlias(source, "vision", options.visionModel)
     for (const kv of options.modelAlias || []) {
         const aliases = parseKeyValuePair(kv)
         for (const [key, value] of Object.entries(aliases))
-            runtimeHost.setModelAlias("cli", key, value)
+            runtimeHost.setModelAlias(source, key, value)
     }
 }
 
 export function logModelAliases() {
     const modelAlias = runtimeHost.modelAliases
     if (Object.values(modelAlias).some((m) => m.source !== "default"))
         Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
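Given the `google` entry in `llms.json` (also updated in this commit), the new helper amounts to one `setModelAlias` call per alias. A sketch of the equivalence, assuming `MODEL_PROVIDERS` is hydrated from `llms.json`:

```ts
applyModelProviderAliases("google", "script")
// behaves roughly like:
runtimeHost.setModelAlias("script", "large", "google:gemini-2.0-flash-exp")
runtimeHost.setModelAlias("script", "small", "google:gemini-1.5-flash-latest")
runtimeHost.setModelAlias("script", "vision", "google:gemini-2.0-flash-exp")
// ...and likewise for the long, reasoning, reasoning_small, and embeddings aliases
```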
6 changes: 4 additions & 2 deletions packages/cli/src/nodehost.ts
@@ -84,12 +84,13 @@ export class NodeHost implements RuntimeHost {
     readonly containers = new DockerManager()
     readonly browsers = new BrowserManager()
     private readonly _modelAliases: Record<
-        "default" | "cli" | "env" | "config",
+        "default" | "cli" | "env" | "config" | "script",
         Omit<ModelConfigurations, "large" | "small" | "vision" | "embeddings">
     > = {
         default: defaultModelConfigurations(),
         cli: {},
         env: {},
+        script: {},
         config: {},
     }
     readonly userInputQueue = new PLimitPromiseQueue(1)
@@ -114,14 +115,15 @@
         const res = {
             ...this._modelAliases.default,
             ...this._modelAliases.config,
+            ...this._modelAliases.script,
             ...this._modelAliases.env,
             ...this._modelAliases.cli,
         } as ModelConfigurations
         return Object.freeze(res)
     }
 
     setModelAlias(
-        source: "cli" | "env" | "config",
+        source: "cli" | "env" | "config" | "script",
         id: string,
         value: string | ModelConfiguration
     ): void {
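Since later object spreads overwrite earlier keys, the getter fixes the precedence of alias sources as `default < config < script < env < cli`. A toy sketch with hypothetical values:

```ts
const merged = {
    ...{ large: "openai:gpt-4o" },                      // default
    ...{ large: "google:gemini-2.0-flash-exp" },        // script({ provider: "google" })
    ...{ large: "anthropic:claude-3-5-sonnet-latest" }, // --provider anthropic
}
// merged.large === "anthropic:claude-3-5-sonnet-latest": the CLI flag wins
```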
6 changes: 4 additions & 2 deletions packages/cli/src/run.ts
@@ -83,7 +83,7 @@ import {
     stdout,
 } from "../../core/src/logging"
 import { ensureDotGenaiscriptPath, setupTraceWriting } from "./trace"
-import { applyModelOptions } from "./modelalias"
+import { applyModelOptions, logModelAliases } from "./modelalias"
 import { createCancellationController } from "./cancel"
 import { parsePromptScriptMeta } from "../../core/src/template"
 
@@ -181,7 +181,7 @@ export async function runScriptInternal(
     const fenceFormat = options.fenceFormat
 
     if (options.json || options.yaml) overrideStdoutWithStdErr()
-    applyModelOptions(options)
+    applyModelOptions(options, "cli")
 
     const fail = (msg: string, exitCode: number, url?: string) => {
         logError(url ? `${msg} (see ${url})` : msg)
@@ -264,6 +264,8 @@
     const stats = new GenerationStats("")
     try {
         if (options.label) trace.heading(2, options.label)
+        applyModelOptions(script, "script")
+        logModelAliases()
         const { info } = await resolveModelConnectionInfo(script, {
             trace,
             model: options.model,
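Note that `applyModelOptions(script, "script")` runs after the CLI call, yet cannot shadow command-line flags: each write is tagged with its source, and the merge order in `NodeHost.modelAliases` (not call order) decides the winner. Simplified flow:

```ts
applyModelOptions(options, "cli")   // 1. command-line flags, tagged "cli"
// ...the script file is loaded and its metadata parsed...
applyModelOptions(script, "script") // 2. script({ ... }) options, tagged "script"
// "script" entries merge below "env" and "cli", so CLI flags still win
```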
2 changes: 1 addition & 1 deletion packages/cli/src/test.ts
@@ -110,7 +110,7 @@ export async function runPromptScriptTests(
         testDelay?: string
     }
 ): Promise<PromptScriptTestRunResponse> {
-    applyModelOptions(options)
+    applyModelOptions(options, "cli")
 
     const scripts = await listTests({ ids, ...(options || {}) })
     if (!scripts.length)
4 changes: 2 additions & 2 deletions packages/core/src/host.ts
@@ -111,7 +111,7 @@ export interface AzureTokenResolver {
 
 export type ModelConfiguration = Readonly<
     Pick<ModelOptions, "model" | "temperature"> & {
-        source: "cli" | "env" | "config" | "default"
+        source: "cli" | "env" | "script" | "config" | "default"
         candidates?: string[]
     }
 >
@@ -171,7 +171,7 @@ export interface RuntimeHost extends Host {
     pullModel(model: string, options?: TraceOptions): Promise<ResponseStatus>
 
     setModelAlias(
-        source: "env" | "cli" | "config",
+        source: "env" | "cli" | "config" | "script",
         id: string,
         value: string | Omit<ModelConfiguration, "source">
     ): void
2 changes: 2 additions & 0 deletions packages/core/src/llms.json
@@ -74,7 +74,9 @@
             "large": "gemini-2.0-flash-exp",
             "small": "gemini-1.5-flash-latest",
             "vision": "gemini-2.0-flash-exp",
+            "long": "gemini-2.0-flash-exp",
             "reasoning": "gemini-2.0-flash-thinking-exp-1219",
+            "reasoning_small": "gemini-2.0-flash-thinking-exp-1219",
             "embeddings": "text-embedding-004"
         }
     },
4 changes: 4 additions & 0 deletions packages/core/src/template.ts
@@ -383,6 +383,10 @@ async function parsePromptTemplateCore(
     checker.checkString("title")
     checker.checkString("description")
     checker.checkString("model")
+    checker.checkString("smallModel")
+    checker.checkString("visionModel")
+    checker.checkString("embeddingsModel")
+    checker.checkString("provider")
     checker.checkString("responseType")
     checker.checkJSONSchema("responseSchema")
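With these checks in place, every new metadata field can be declared directly in the script header. A sketch combining them (model ids are placeholders):

```ts
script({
    provider: "openai",
    smallModel: "openai:gpt-4o-mini",
    visionModel: "openai:gpt-4o",
    embeddingsModel: "openai:text-embedding-3-small",
})
```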
24 changes: 24 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
@@ -191,6 +191,25 @@ type ModelVisionType = OptionsOrString<
     "openai:gpt-4o" | "github:gpt-4o" | "azure:gpt-4o" | "azure:gpt-4o-mini"
 >
 
+type ModelProviderType =
+    | "openai"
+    | "azure"
+    | "azure_serverless"
+    | "azure_serverless_models"
+    | "anthropic"
+    | "anthropic_bedrock"
+    | "google"
+    | "huggingface"
+    | "mistral"
+    | "alibaba"
+    | "github"
+    | "transformers"
+    | "ollama"
+    | "lmstudio"
+    | "jan"
+    | "llamafile"
+    | "litellm"
+
 interface ModelConnectionOptions {
     /**
      * Which LLM model by default or for the `large` alias.
@@ -473,6 +492,11 @@ interface PromptScript
         EmbeddingsModelOptions,
         ContentSafetyOptions,
         ScriptRuntimeOptions {
+    /**
+     * Which provider to prefer when picking a model.
+     */
+    provider?: ModelProviderType
+
     /**
      * Additional template parameters that will populate `env.vars`
      */
5 changes: 5 additions & 0 deletions packages/sample/genaisrc/provider.genai.mjs
@@ -0,0 +1,5 @@
+script({
+    provider: "google",
+    smallModel: "azure:gpt-4o-mini",
+})
+$`Write a poem.`
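Running this sample (e.g. `genaiscript run provider`) should resolve the aliases as follows; this is inferred from the code above rather than captured output, and the alias record shape is assumed:

```ts
const aliases = runtimeHost.modelAliases
// provider: "google" registers its aliases first...
console.assert(aliases.large.model === "google:gemini-2.0-flash-exp")
console.assert(aliases.vision.model === "google:gemini-2.0-flash-exp")
// ...then the explicit smallModel overwrites "small" within the same source
console.assert(aliases.small.model === "azure:gpt-4o-mini")
```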
