refactor model to support aliases (#914)
* refactor model to support aliases

* refactor: update model options and remove smallModel 🎨

* feat: πŸš€ add vision model support and update env vars

* apply alias when resolving models

* refactor: ♻️ update model alias references to runtimeHost
pelikhan authored Dec 5, 2024
1 parent 666f1b2 commit 19d5037
Showing 22 changed files with 162 additions and 157 deletions.
65 changes: 40 additions & 25 deletions docs/src/content/docs/getting-started/configuration.mdx
@@ -40,9 +40,9 @@ script({
})
```

### Large and small models
### Large, small, vision models

You can also use the `small` and `large` aliases to use the default configured small and large models.
You can also use the `small`, `large`, and `vision` aliases to use the default configured small, large, and vision-enabled models.
Large models are typically in the OpenAI gpt-4 reasoning range and can be used for more complex tasks.
Small models are in the OpenAI gpt-4o-mini range, and are useful for quick and simple tasks.
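
For instance, a quick task can target the `small` alias directly; a minimal sketch (the script title and prompt below are illustrative):

```js
script({
    title: "summarize quickly", // illustrative title
    model: "small", // resolves to the configured 'small' model alias
})
$`Summarize ${env.files} in one sentence each.`
```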

@@ -60,11 +60,28 @@ The model can also be overridden from the [cli run command](/genaiscript/referen
genaiscript run ... --model largemodelid --small-model smallmodelid
```

or by adding the `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` environment variables.
or by adding the `GENAISCRIPT_MODEL_LARGE` and `GENAISCRIPT_MODEL_SMALL` environment variables.

```txt title=".env"
GENAISCRIPT_DEFAULT_MODEL="azure_serverless:..."
GENAISCRIPT_DEFAULT_SMALL_MODEL="azure_serverless:..."
GENAISCRIPT_MODEL_LARGE="azure_serverless:..."
GENAISCRIPT_MODEL_SMALL="azure_serverless:..."
GENAISCRIPT_MODEL_VISION="azure_serverless:..."
```

### Model aliases

You can define any alias for your model (only alphanumeric characters are allowed)
through environment variables named `GENAISCRIPT_MODEL_ALIAS`,
where `ALIAS` is the alias you want to use.

```txt title=".env"
GENAISCRIPT_MODEL_TINY=...
```

Model aliases are always lowercased when used in the script.

```js
script({ model: "tiny" })
```
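
An alias can also be supplied from the CLI with the `--model-alias` flag; for example (the script name and model id here are illustrative):

```sh
# hypothetical script name and model id, using the name=modelid syntax
npx genaiscript run summarize --model-alias tiny=ollama:llama3.2:1b
```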

## `.env` file
@@ -93,8 +110,8 @@ Create a `.env` file in the root of your project.

<FileTree>

- .gitignore
- **.env**
- .gitignore
- **.env**

</FileTree>

@@ -127,19 +144,19 @@ the `.env` file will appear grayed out in Visual Studio Code.

You can specify a custom `.env` file location through the CLI or an environment variable.

- by adding the `--env <file>` argument to the CLI.
- by adding the `--env <file>` argument to the CLI.

```sh "--env .env.local"
npx genaiscript ... --env .env.local
```

- by setting the `GENAISCRIPT_ENV_FILE` environment variable.
- by setting the `GENAISCRIPT_ENV_FILE` environment variable.

```sh
GENAISCRIPT_ENV_FILE=".env.local" npx genaiscript ...
```

- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files).
- by specifying the `.env` file location in a [configuration file](/genaiscript/reference/configuration-files).

```yaml title="~/genaiscript.config.yaml"
envFile: ~/.env.genaiscript
@@ -152,13 +169,13 @@ of the genaiscript process with the configuration values.

Here are some common examples:

- Using bash syntax
- Using bash syntax

```sh
OPENAI_API_KEY="value" npx --yes genaiscript run ...
```

- GitHub Action configuration
- GitHub Action configuration

```yaml title=".github/workflows/genaiscript.yml"
run: npx --yes genaiscript run ...
@@ -219,11 +236,11 @@ script({
:::tip[Default Model Configuration]
Use `GENAISCRIPT_DEFAULT_MODEL` and `GENAISCRIPT_DEFAULT_SMALL_MODEL` in your `.env` file to set the default model and small model.
Use `GENAISCRIPT_MODEL_LARGE` and `GENAISCRIPT_MODEL_SMALL` in your `.env` file to set the default model and small model.
```txt
GENAISCRIPT_DEFAULT_MODEL=openai:gpt-4o
GENAISCRIPT_DEFAULT_SMALL_MODEL=openai:gpt-4o-mini
GENAISCRIPT_MODEL_LARGE=openai:gpt-4o
GENAISCRIPT_MODEL_SMALL=openai:gpt-4o-mini
```
:::
@@ -412,11 +429,11 @@ AZURE_OPENAI_API_CREDENTIALS=cli

The types are mapped directly to their [@azure/identity](https://www.npmjs.com/package/@azure/identity) credential types:

- `cli` - `AzureCliCredential`
- `env` - `EnvironmentCredential`
- `powershell` - `AzurePowerShellCredential`
- `devcli` - `AzureDeveloperCliCredential`
- `managedidentity` - `ManagedIdentityCredential`
- `cli` - `AzureCliCredential`
- `env` - `EnvironmentCredential`
- `powershell` - `AzurePowerShellCredential`
- `devcli` - `AzureDeveloperCliCredential`
- `managedidentity` - `ManagedIdentityCredential`

### Custom token scopes

@@ -1080,7 +1097,6 @@ script({
})
```


### Ollama with Docker

You can conveniently run Ollama in a Docker container.
@@ -1097,11 +1113,10 @@ docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
docker stop ollama && docker rm ollama
```
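
With the container running, a script can reference a local Ollama model as usual; a minimal sketch (the model id is illustrative and must already be pulled):

```js
script({
    // illustrative model id; any model available in the local Ollama instance works
    model: "ollama:phi3.5",
})
```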


## LMStudio

The `lmstudio` provider connects to the [LMStudio](https://lmstudio.ai/) headless server
and allows you to run local LLMs.

:::note

@@ -1259,8 +1274,8 @@ This `transformers` provider runs models on device using [Hugging Face Transform

The model syntax is `transformers:<repo>:<dtype>` where

- `repo` is the model repository on Hugging Face,
- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type.
- `repo` is the model repository on Hugging Face,
- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type.

```js "transformers:"
script({
1 change: 1 addition & 0 deletions docs/src/content/docs/reference/cli/commands.md
@@ -19,6 +19,7 @@ Options:
-m, --model <string> 'large' model alias (default)
-sm, --small-model <string> 'small' alias model
-vm, --vision-model <string> 'vision' alias model
-ma, --model-alias <nameid...> model alias as name=modelid
-lp, --logprobs enable reporting token probabilities
-tlp, --top-logprobs <number> number of top logprobs (1 to 5)
-ef, --excluded-files <string...> excluded files
1 change: 1 addition & 0 deletions packages/cli/src/cli.ts
@@ -99,6 +99,7 @@ export async function cli() {
.option("-m, --model <string>", "'large' model alias (default)")
.option("-sm, --small-model <string>", "'small' alias model")
.option("-vm, --vision-model <string>", "'vision' alias model")
.option("-ma, --model-alias <nameid...>", "model alias as name=modelid")
.option("-lp, --logprobs", "enable reporting token probabilities")
.option(
"-tlp, --top-logprobs <number>",
8 changes: 4 additions & 4 deletions packages/cli/src/info.ts
@@ -84,20 +84,20 @@ export async function envInfo(

/**
* Resolves connection information for script templates by deduplicating model options.
* @param templates - Array of model connection options to resolve.
* @param scripts - Array of model connection options to resolve.
* @param options - Configuration options, including whether to show tokens.
* @returns A promise that resolves to an array of model connection information.
*/
async function resolveScriptsConnectionInfo(
templates: ModelConnectionOptions[],
scripts: ModelConnectionOptions[],
options?: { token?: boolean }
): Promise<ModelConnectionInfo[]> {
const models: Record<string, ModelConnectionOptions> = {}

// Deduplicate model connection options
for (const template of templates) {
for (const script of scripts) {
const conn: ModelConnectionOptions = {
model: template.model ?? host.defaultModelOptions.model,
model: script.model ?? runtimeHost.modelAliases.large.model,
}
const key = JSON.stringify(conn)
if (!models[key]) models[key] = conn
14 changes: 6 additions & 8 deletions packages/cli/src/nodehost.ts
@@ -44,6 +44,7 @@ import {
setRuntimeHost,
ResponseStatus,
AzureTokenResolver,
ModelConfigurations,
} from "../../core/src/host"
import { AbortSignalOptions, TraceOptions } from "../../core/src/trace"
import { logError, logVerbose } from "../../core/src/util"
@@ -139,14 +140,11 @@ export class NodeHost implements RuntimeHost {
readonly workspace = createFileSystem()
readonly containers = new DockerManager()
readonly browsers = new BrowserManager()
readonly defaultModelOptions = {
model: DEFAULT_MODEL,
smallModel: DEFAULT_SMALL_MODEL,
visionModel: DEFAULT_VISION_MODEL,
temperature: DEFAULT_TEMPERATURE,
}
readonly defaultEmbeddingsModelOptions = {
embeddingsModel: DEFAULT_EMBEDDINGS_MODEL,
readonly modelAliases: ModelConfigurations = {
large: { model: DEFAULT_MODEL },
small: { model: DEFAULT_SMALL_MODEL },
vision: { model: DEFAULT_VISION_MODEL },
embeddings: { model: DEFAULT_EMBEDDINGS_MODEL },
}
readonly userInputQueue = new PLimitPromiseQueue(1)
readonly azureToken: AzureTokenResolver
20 changes: 13 additions & 7 deletions packages/cli/src/run.ts
@@ -87,6 +87,7 @@ import {
stderr,
stdout,
} from "../../core/src/logging"
import { setModelAlias } from "../../core/src/connection"

async function setupTraceWriting(trace: MarkdownTrace, filename: string) {
logVerbose(`trace: ${filename}`)
@@ -208,21 +209,26 @@ export async function runScript(
const topLogprobs = normalizeInt(options.topLogprobs)

if (options.json || options.yaml) overrideStdoutWithStdErr()
if (options.model) host.defaultModelOptions.model = options.model
if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.smallModel)
host.defaultModelOptions.smallModel = options.smallModel
runtimeHost.modelAliases.small.model = options.smallModel
if (options.visionModel)
host.defaultModelOptions.visionModel = options.visionModel
runtimeHost.modelAliases.vision.model = options.visionModel
for (const kv of options.modelAlias || []) {
const aliases = parseKeyValuePair(kv)
for (const [key, value] of Object.entries(aliases))
setModelAlias(key, value)
}

const fail = (msg: string, exitCode: number, url?: string) => {
logError(url ? `${msg} (see ${url})` : msg)
return { exitCode, result }
}

logInfo(`genaiscript: ${scriptId}`)
logVerbose(` large : ${host.defaultModelOptions.model}`)
logVerbose(` small : ${host.defaultModelOptions.smallModel}`)
logVerbose(` vision: ${host.defaultModelOptions.visionModel}`)
Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
logVerbose(` ${key}: ${value.model}`)
)

if (out) {
if (removeOut) await emptyDir(out)
@@ -376,7 +382,7 @@ export async function runScript(
model: info.model,
embeddingsModel:
options.embeddingsModel ??
host.defaultEmbeddingsModelOptions.embeddingsModel,
runtimeHost.modelAliases.embeddings.model,
retry,
retryDelay,
maxDelay,
19 changes: 9 additions & 10 deletions packages/cli/src/test.ts
@@ -20,7 +20,7 @@ import {
import { promptFooDriver } from "../../core/src/default_prompts"
import { serializeError } from "../../core/src/error"
import { parseKeyValuePairs } from "../../core/src/fence"
import { host } from "../../core/src/host"
import { host, runtimeHost } from "../../core/src/host"
import { JSON5TryParse } from "../../core/src/json5"
import { MarkdownTrace } from "../../core/src/trace"
import {
@@ -51,7 +51,7 @@ import { filterScripts } from "../../core/src/ast"
* @param m - The string representation of the model specification.
* @returns A ModelOptions object with model, temperature, and topP fields if applicable.
*/
function parseModelSpec(m: string): ModelOptions {
function parseModelSpec(m: string): ModelOptions & ModelAliasesOptions {
const values = m
.split(/&/g)
.map((kv) => kv.split("=", 2))
@@ -108,14 +108,13 @@ export async function runPromptScriptTests(
testDelay?: string
}
): Promise<PromptScriptTestRunResponse> {
if (options.model) host.defaultModelOptions.model = options.model
if (options.model) runtimeHost.modelAliases.large.model = options.model
if (options.smallModel)
host.defaultModelOptions.smallModel = options.smallModel
runtimeHost.modelAliases.small.model = options.smallModel
if (options.visionModel)
host.defaultModelOptions.visionModel = options.visionModel

logVerbose(
`model: ${host.defaultModelOptions.model}, small model: ${host.defaultModelOptions.smallModel}, vision model: ${host.defaultModelOptions.visionModel}`
runtimeHost.modelAliases.vision.model = options.visionModel
Object.entries(runtimeHost.modelAliases).forEach(([key, value]) =>
logVerbose(` ${key}: ${value.model}`)
)

const scripts = await listTests({ ids, ...(options || {}) })
@@ -147,12 +146,12 @@ export async function runPromptScriptTests(
: script.filename.replace(GENAI_ANY_REGEX, ".promptfoo.yaml")
logInfo(` ${fn}`)
const { info: chatInfo } = await resolveModelConnectionInfo(script, {
model: host.defaultModelOptions.model,
model: runtimeHost.modelAliases.large.model,
})
if (chatInfo.error) throw new Error(chatInfo.error)
let { info: embeddingsInfo } = await resolveModelConnectionInfo(
script,
{ model: host.defaultEmbeddingsModelOptions.embeddingsModel }
{ model: runtimeHost.modelAliases.embeddings.model }
)
if (embeddingsInfo?.error) embeddingsInfo = undefined
const config = generatePromptFooConfiguration(script, {