Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate support for Transformers models #887

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,063 changes: 1,051 additions & 12 deletions THIRD_PARTY_LICENSES.md

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions docs/src/content/docs/getting-started/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,27 @@ OPENROUTER_SITE_URL=... # populates HTTP-Referer header
OPENROUTER_SITE_NAME=... # populates X-Title header
```

## Hugging Face Transformers.js (experimental)

The `transformers` provider runs models on-device using [Hugging Face Transformers.js](https://huggingface.co/docs/transformers.js/index).

The model syntax is `transformers:<repo>:<dtype>` where

- `repo` is the model repository on Hugging Face,
- [`dtype`](https://huggingface.co/docs/transformers.js/guides/dtypes) is the quantization type.

```js "transformers:"
script({
model: "transformers:onnx-community/Qwen2.5-Coder-0.5B-Instruct:q4",
})
```

:::note

This provider is experimental and may not work with all models.

:::

## Model specific environment variables

You can provide different environment variables
Expand Down
6 changes: 3 additions & 3 deletions docs/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2260,9 +2260,9 @@ dset@^3.1.3, dset@^3.1.4:
integrity sha512-2QF/g9/zTaPDc3BjNcVTGoBbXBgYfMTTceLaYcFJ/W9kggFUkhxD/hMEeuLKbugyef9SqAx8cpgwlIP/jinUTA==

electron-to-chromium@^1.5.41:
version "1.5.63"
resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.5.63.tgz#69444d592fbbe628d129866c2355691ea93eda3e"
integrity sha512-ddeXKuY9BHo/mw145axlyWjlJ1UBt4WK3AlvkT7W2AbqfRQoacVoRUCF6wL3uIx/8wT9oLKXzI+rFqHHscByaA==
version "1.5.64"
resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.5.64.tgz#ac8c4c89075d35a1514b620f47dfe48a71ec3697"
integrity sha512-IXEuxU+5ClW2IGEYFC2T7szbyVgehupCWQe5GNh+H065CD6U6IFN0s4KeAMFGNmQolRU4IV7zGBWSYMmZ8uuqQ==

emmet@^2.4.3:
version "2.4.11"
Expand Down
5 changes: 3 additions & 2 deletions packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.2/xlsx-0.20.2.tgz"
},
"optionalDependencies": {
"@huggingface/transformers": "^3.0.2",
"@lvce-editor/ripgrep": "^1.4.0",
"pdfjs-dist": "4.8.69",
"playwright": "^1.49.0",
Expand Down Expand Up @@ -94,8 +95,8 @@
"zx": "^8.2.2"
},
"scripts": {
"compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas && node ../../scripts/patch-cli.mjs",
"compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas",
"compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers && node ../../scripts/patch-cli.mjs",
"compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers",
"postcompile": "node built/genaiscript.cjs info help > ../../docs/src/content/docs/reference/cli/commands.md",
"vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
"vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"@anthropic-ai/sdk": "^0.32.1",
"@azure/identity": "^4.5.0",
"@huggingface/jinja": "^0.3.2",
"@huggingface/transformers": "^3.0.2",
"@octokit/plugin-paginate-rest": "^11.3.5",
"@octokit/plugin-retry": "^7.1.2",
"@octokit/plugin-throttling": "^9.3.2",
Expand Down
24 changes: 1 addition & 23 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,29 +106,7 @@ export function toChatCompletionUserMessage(
content: expanded,
}
}
/*
function encodeMessagesForLlama(req: CreateChatCompletionRequest) {
return (
req.messages
.map((msg) => {
switch (msg.role) {
case "user":
return `[INST]\n${msg.content}\n[/INST]`
case "system":
return `[INST] <<SYS>>\n${msg.content}\n<</SYS>>\n[/INST]`
case "assistant":
return msg.content
case "function":
return "???function"
default:
return "???role " + msg.role
}
})
.join("\n")
.replace(/\[\/INST\]\n\[INST\]/g, "\n") + "\n"
)
}
*/

export type ChatCompletionHandler = (
req: CreateChatCompletionRequest,
connection: LanguageModelConfiguration,
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
OLLAMA_DEFAUT_PORT,
MODEL_PROVIDER_GOOGLE,
GOOGLE_API_BASE,
MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { fileExists, readText, writeText } from "./fs"
import {
Expand Down Expand Up @@ -407,6 +408,16 @@ export async function parseTokenFromEnv(
}
}

if (provider === MODEL_PROVIDER_TRANSFORMERS) {
return {
provider,
model,
base: undefined,
token: "transformers",
source: "default",
}
}

if (provider === MODEL_PROVIDER_CLIENT && host.clientLanguageModel) {
return {
provider,
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ export const MODEL_PROVIDER_AICI = "aici"
export const MODEL_PROVIDER_CLIENT = "client"
export const MODEL_PROVIDER_ANTHROPIC = "anthropic"
export const MODEL_PROVIDER_HUGGINGFACE = "huggingface"
export const MODEL_PROVIDER_TRANSFORMERS = "transformers"

export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240

Expand Down Expand Up @@ -210,6 +211,8 @@ export const DOCS_CONFIGURATION_GOOGLE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#google"
export const DOCS_CONFIGURATION_HUGGINGFACE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#huggingface"
export const DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#transformers"
export const DOCS_CONFIGURATION_CONTENT_SAFETY_URL =
"https://microsoft.github.io/genaiscript/reference/scripts/content-safety"
export const DOCS_DEF_FILES_IS_EMPTY_URL =
Expand Down Expand Up @@ -262,6 +265,11 @@ export const MODEL_PROVIDERS = Object.freeze([
detail: "Hugging Face models",
url: DOCS_CONFIGURATION_HUGGINGFACE_URL,
},
{
id: MODEL_PROVIDER_TRANSFORMERS,
detail: "Hugging Face Transformers",
url: DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL,
},
{
id: MODEL_PROVIDER_OLLAMA,
detail: "Ollama local model",
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/lm.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Transform } from "stream"
import { AICIModel } from "./aici"
import { AnthropicModel } from "./anthropic"
import { LanguageModel } from "./chat"
Expand All @@ -6,10 +7,12 @@ import {
MODEL_PROVIDER_ANTHROPIC,
MODEL_PROVIDER_CLIENT,
MODEL_PROVIDER_OLLAMA,
MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { host } from "./host"
import { OllamaModel } from "./ollama"
import { OpenAIModel } from "./openai"
import { TransformersModel } from "./transformers"

export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_CLIENT) {
Expand All @@ -20,5 +23,6 @@ export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel
if (provider === MODEL_PROVIDER_AICI) return AICIModel
if (provider === MODEL_PROVIDER_ANTHROPIC) return AnthropicModel
if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel
return OpenAIModel
}
2 changes: 1 addition & 1 deletion packages/core/src/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { URL } from "url"
* @returns The result of the chat completion.
* @throws Will throw an error if the model cannot be pulled or any other request error occurs.
*/
export const OllamaCompletion: ChatCompletionHandler = async (
const OllamaCompletion: ChatCompletionHandler = async (
req,
cfg,
options,
Expand Down
173 changes: 173 additions & 0 deletions packages/core/src/transformers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import { ChatCompletionHandler, LanguageModel } from "./chat"
import { renderMessageContent } from "./chatrender"
import { MODEL_PROVIDER_TRANSFORMERS } from "./constants"
import type {
Chat,
Message,
TextGenerationOutput,
TextGenerationPipeline,
} from "@huggingface/transformers"
import { NotSupportedError } from "./error"
import { ChatCompletionMessageParam, ChatCompletionResponse } from "./chattypes"
import { deleteUndefinedValues, dotGenaiscriptPath, logVerbose } from "./util"
import { parseModelIdentifier } from "./models"
import prettyBytes from "pretty-bytes"
import { hash } from "./crypto"
import {
ChatCompletionRequestCacheKey,
getChatCompletionCache,
} from "./chatcache"
import { PLimitPromiseQueue } from "./concurrency"

// Progress events emitted by Transformers.js while a pipeline/model loads.
// "initiate" marks the start of a download, "progress" reports per-file
// download status, "ready" signals the pipeline can be used.
type GeneratorProgress =
    | {
          status: "initiate"
      }
    | {
          status: "progress"
          file: string // file being downloaded
          name: string
          progress: number // percentage 0-100
          loaded: number // bytes downloaded so far
          total: number // total bytes
      }
    | { status: "ready"; task: string; model: string }

/**
 * Creates a progress logger for Transformers.js model downloads.
 *
 * Tracks the last logged percentage per file and only logs again once a
 * file has advanced by more than 5 percentage points, keeping verbose
 * output readable for large multi-file downloads.
 *
 * @returns a callback suitable for the pipeline `progress_callback` option.
 */
function progressBar() {
    const progress: Record<string, number> = {}
    return (cb: GeneratorProgress) => {
        switch (cb.status) {
            case "progress": {
                // braces scope the lexical declarations to this case
                const p = progress[cb.file] || 0
                const cp = Math.floor(cb.progress)
                if (cp > p + 5) {
                    progress[cb.file] = cp
                    logVerbose(`${cb.file}: ${cp}% (${prettyBytes(cb.loaded)})`)
                }
                break
            }
            case "ready": {
                logVerbose(`model ${cb.model} ready`)
                logVerbose(``)
                break
            }
        }
    }
}

// Memoized pipeline loads keyed by a hash of (family, options) — see loadGenerator.
const generators: Record<string, Promise<TextGenerationPipeline>> = {}
// Serializes pipeline load/generation to one at a time — presumably because
// concurrent on-device generation is unsafe or too memory-hungry; confirm.
const generationQueue = new PLimitPromiseQueue(1)

/**
 * Loads a Transformers.js text-generation pipeline, memoizing in-flight and
 * completed loads per (family, options) so each model is loaded once.
 *
 * @param family - Hugging Face model repository (e.g. "onnx-community/...").
 * @param options - pipeline options (dtype, device, ...).
 * @returns a promise for the text-generation pipeline.
 */
async function loadGenerator(family: string, options: object) {
    const h = await hash({ family, options })
    let p = generators[h]
    if (!p) {
        const { pipeline } = await import("@huggingface/transformers")
        p = generators[h] = pipeline("text-generation", family, {
            ...options,
            cache_dir: dotGenaiscriptPath("cache", "transformers"),
            progress_callback: progressBar(),
        })
        // Evict a failed load from the memo so a transient error (e.g. a
        // network failure mid-download) does not poison all future calls.
        p.catch(() => {
            if (generators[h] === p) delete generators[h]
        })
    }
    return p
}

/**
 * Chat completion handler that runs models on-device via Hugging Face
 * Transformers.js. The model id syntax is `transformers:<repo>:<dtype>`,
 * so `family` resolves to the Hugging Face repository and `tag` to the
 * quantization dtype.
 *
 * Cache reads and writes only happen when caching was requested; previously
 * `cacheStore.get`/`set` were invoked with an `undefined` key even when
 * caching was disabled.
 */
export const TransformersCompletion: ChatCompletionHandler = async (
    req,
    cfg,
    options,
    trace
) => {
    const { messages, temperature, top_p, max_tokens } = req
    const { partialCb, inner, cache: cacheOrName, cacheName } = options
    const { model, tag, family } = parseModelIdentifier(req.model)

    trace.itemValue("model", model)

    const cache = !!cacheOrName || !!cacheName
    const cacheStore = getChatCompletionCache(
        typeof cacheOrName === "string" ? cacheOrName : cacheName
    )
    // `...req` already carries model/temperature/top_p/max_tokens/logit_bias
    const cachedKey = cache
        ? <ChatCompletionRequestCacheKey>{ ...req }
        : undefined
    trace.itemValue(`caching`, cache)
    trace.itemValue(`cache`, cacheStore?.name)
    if (cachedKey) {
        const { text: cached, finishReason: cachedFinishReason } =
            (await cacheStore.get(cachedKey)) || {}
        if (cached !== undefined) {
            partialCb?.({
                tokensSoFar: 0, // TODO: token accounting for cached responses
                responseSoFar: cached,
                responseChunk: cached,
                inner,
            })
            trace.itemValue(`cache hit`, await cacheStore.getKeySHA(cachedKey))
            return {
                text: cached,
                finishReason: cachedFinishReason,
                cached: true,
            }
        }
    }

    // pipeline loads and generations are serialized through the shared queue
    const generator = await generationQueue.add(() =>
        loadGenerator(family, {
            dtype: tag,
            device: "cpu",
        })
    )
    const msgs: Chat = chatMessagesToTranformerMessages(messages)
    trace.detailsFenced("messages", msgs, "yaml")
    const output = (await generator(
        msgs,
        deleteUndefinedValues({
            max_new_tokens: max_tokens || 4000,
            temperature,
            top_p,
            early_stopping: true,
        })
    )) as TextGenerationOutput
    // the pipeline echoes the whole chat; keep only the last (generated) turn
    const text = output
        .map((msg) => (msg.generated_text.at(-1) as Message).content)
        .join("")
    trace.fence(text, "markdown")
    partialCb?.({
        responseSoFar: text,
        responseChunk: text,
        tokensSoFar: 0, // TODO: report actual token usage
        inner,
    })

    const finishReason = "stop"
    if (cachedKey) await cacheStore.set(cachedKey, { text, finishReason })

    return {
        text,
        finishReason,
    } satisfies ChatCompletionResponse
}

// Language model binding for the Hugging Face Transformers provider:
// pairs the on-device completion handler with the provider id.
export const TransformersModel = Object.freeze<LanguageModel>({
    completer: TransformersCompletion,
    id: MODEL_PROVIDER_TRANSFORMERS,
})

/**
 * Converts GenAIScript chat messages into the Transformers.js `Chat` shape.
 *
 * @param messages - messages from the chat completion request.
 * @returns messages with their content rendered to plain text.
 * @throws NotSupportedError for roles Transformers.js cannot represent
 *         (function, aici, tool).
 */
function chatMessagesToTranformerMessages(
    messages: ChatCompletionMessageParam[]
): Chat {
    return messages.map((message) => {
        switch (message.role) {
            case "function":
            case "aici":
            case "tool":
                throw new NotSupportedError(
                    `role ${message.role} not supported`
                )
            default: {
                const content = renderMessageContent(message)
                return { role: message.role, content } satisfies Message
            }
        }
    })
}
1 change: 1 addition & 0 deletions packages/core/src/types/prompt_template.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type ModelType = OptionsOrString<
| "google:gemini-1.5-pro"
| "google:gemini-1.5-pro-002"
| "google:gemini-1-pro"
| "transformers:onnx-community/Qwen2.5-0.5B-Instruct:q4"
>

type ModelSmallType = OptionsOrString<
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ export class GenerationStats {
* @param indent - The indentation used for logging.
*/
private logTokens(indent: string) {
if (!this.resolvedModel) return

const unknowns = new Set<string>()
const c = this.cost()
if (this.model && isNaN(c) && isCosteable(this.model))
Expand Down
Loading
Loading