diff --git a/packages/cli/package.json b/packages/cli/package.json
index 2a107c8ea..98bba50d1 100644
--- a/packages/cli/package.json
+++ b/packages/cli/package.json
@@ -53,6 +53,7 @@
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.2/xlsx-0.20.2.tgz"
},
"optionalDependencies": {
+ "@huggingface/transformers": "^3.0.2",
"@lvce-editor/ripgrep": "^1.4.0",
"pdfjs-dist": "4.8.69",
"playwright": "^1.49.0",
@@ -94,8 +95,8 @@
"zx": "^8.2.2"
},
"scripts": {
- "compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas && node ../../scripts/patch-cli.mjs",
- "compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas",
+ "compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers && node ../../scripts/patch-cli.mjs",
+ "compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers",
"postcompile": "node built/genaiscript.cjs info help > ../../docs/src/content/docs/reference/cli/commands.md",
"vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
"vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
diff --git a/packages/core/package.json b/packages/core/package.json
index 7124224a6..13401244b 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -21,6 +21,7 @@
"@anthropic-ai/sdk": "^0.32.1",
"@azure/identity": "^4.5.0",
"@huggingface/jinja": "^0.3.2",
+ "@huggingface/transformers": "^3.0.2",
"@octokit/plugin-paginate-rest": "^11.3.5",
"@octokit/plugin-retry": "^7.1.2",
"@octokit/plugin-throttling": "^9.3.2",
diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index 8cd7594c4..cfbfcd2ec 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -106,29 +106,7 @@ export function toChatCompletionUserMessage(
content: expanded,
}
}
-/*
-function encodeMessagesForLlama(req: CreateChatCompletionRequest) {
- return (
- req.messages
- .map((msg) => {
- switch (msg.role) {
- case "user":
- return `[INST]\n${msg.content}\n[/INST]`
- case "system":
- return `[INST] <<SYS>>\n${msg.content}\n<</SYS>>\n[/INST]`
- case "assistant":
- return msg.content
- case "function":
- return "???function"
- default:
- return "???role " + msg.role
- }
- })
- .join("\n")
- .replace(/\[\/INST\]\n\[INST\]/g, "\n") + "\n"
- )
-}
-*/
+
export type ChatCompletionHandler = (
req: CreateChatCompletionRequest,
connection: LanguageModelConfiguration,
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index 7ae61b846..0c9cdea22 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -25,6 +25,7 @@ import {
OLLAMA_DEFAUT_PORT,
MODEL_PROVIDER_GOOGLE,
GOOGLE_API_BASE,
+ MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { fileExists, readText, writeText } from "./fs"
import {
@@ -407,6 +408,17 @@ export async function parseTokenFromEnv(
}
}
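+ // local transformers.js models run in-process: no base URL or API key, just a placeholder token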
+ if (provider === MODEL_PROVIDER_TRANSFORMERS) {
+ return {
+ provider,
+ model,
+ base: undefined,
+ token: "transformers",
+ source: "default",
+ }
+ }
+
if (provider === MODEL_PROVIDER_CLIENT && host.clientLanguageModel) {
return {
provider,
diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index 3a3812f92..f12d40f00 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -172,6 +172,7 @@ export const MODEL_PROVIDER_AICI = "aici"
export const MODEL_PROVIDER_CLIENT = "client"
export const MODEL_PROVIDER_ANTHROPIC = "anthropic"
export const MODEL_PROVIDER_HUGGINGFACE = "huggingface"
+export const MODEL_PROVIDER_TRANSFORMERS = "transformers"
export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240
@@ -210,6 +211,8 @@ export const DOCS_CONFIGURATION_GOOGLE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#google"
export const DOCS_CONFIGURATION_HUGGINGFACE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#huggingface"
+export const DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL =
+ "https://microsoft.github.io/genaiscript/getting-started/configuration/#transformers"
export const DOCS_CONFIGURATION_CONTENT_SAFETY_URL =
"https://microsoft.github.io/genaiscript/reference/scripts/content-safety"
export const DOCS_DEF_FILES_IS_EMPTY_URL =
@@ -262,6 +265,11 @@ export const MODEL_PROVIDERS = Object.freeze([
detail: "Hugging Face models",
url: DOCS_CONFIGURATION_HUGGINGFACE_URL,
},
+ {
+ id: MODEL_PROVIDER_TRANSFORMERS,
+ detail: "Hugging Face Transformers",
+ url: DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL,
+ },
{
id: MODEL_PROVIDER_OLLAMA,
detail: "Ollama local model",
diff --git a/packages/core/src/lm.ts b/packages/core/src/lm.ts
index 3826dec4c..63f69d800 100644
--- a/packages/core/src/lm.ts
+++ b/packages/core/src/lm.ts
@@ -6,10 +6,12 @@ import {
MODEL_PROVIDER_ANTHROPIC,
MODEL_PROVIDER_CLIENT,
MODEL_PROVIDER_OLLAMA,
+ MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { host } from "./host"
import { OllamaModel } from "./ollama"
import { OpenAIModel } from "./openai"
+import { TransformersModel } from "./transformers"
export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_CLIENT) {
@@ -20,5 +22,6 @@ export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel
if (provider === MODEL_PROVIDER_AICI) return AICIModel
if (provider === MODEL_PROVIDER_ANTHROPIC) return AnthropicModel
+ if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel
return OpenAIModel
}
diff --git a/packages/core/src/ollama.ts b/packages/core/src/ollama.ts
index 1d0eba3c4..5a3dc3ae1 100644
--- a/packages/core/src/ollama.ts
+++ b/packages/core/src/ollama.ts
@@ -24,7 +24,7 @@ import { URL } from "url"
* @returns The result of the chat completion.
* @throws Will throw an error if the model cannot be pulled or any other request error occurs.
*/
-export const OllamaCompletion: ChatCompletionHandler = async (
+const OllamaCompletion: ChatCompletionHandler = async (
req,
cfg,
options,
diff --git a/packages/core/src/transformers.ts b/packages/core/src/transformers.ts
new file mode 100644
index 000000000..733411fbf
--- /dev/null
+++ b/packages/core/src/transformers.ts
@@ -0,0 +1,124 @@
+import { ChatCompletionHandler, LanguageModel } from "./chat"
+import { renderMessageContent } from "./chatrender"
+import { MODEL_PROVIDER_TRANSFORMERS } from "./constants"
+import type {
+ Chat,
+ Message,
+ TextGenerationOutput,
+} from "@huggingface/transformers"
+import { NotSupportedError } from "./error"
+import { ChatCompletionMessageParam, ChatCompletionResponse } from "./chattypes"
+import { deleteUndefinedValues, dotGenaiscriptPath, logVerbose } from "./util"
+import { parseModelIdentifier } from "./models"
+import prettyBytes from "pretty-bytes"
+
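+// Progress events passed to the transformers.js progress_callback while model files download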
+type GeneratorProgress =
+ | {
+ status: "initiate"
+ }
+ | {
+ status: "progress"
+ file: string
+ name: string
+ progress: number
+ loaded: number
+ total: number
+ }
+
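+// Creates a per-file progress logger that reports whole-percent increments of the model download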
+function progressBar() {
+ const progress: Record<string, number> = {}
+ return (cb: GeneratorProgress) => {
+ if (cb.status !== "progress") return
+ const p = progress[cb.file] || 0
+ const cp = Math.floor(cb.progress)
+ if (cp > p) {
+ progress[cb.file] = cp
+ logVerbose(`${cb.file}: ${cp}% (${prettyBytes(cb.loaded)})`)
+ }
+ }
+}
+
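+// Chat completion handler that runs generation locally through a transformers.js text-generation pipeline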
+export const TransformersCompletion: ChatCompletionHandler = async (
+ req,
+ cfg,
+ options,
+ trace
+) => {
+ const { messages, temperature, top_p, max_tokens } = req
+ const { partialCb, inner } = options
+ const { model } = parseModelIdentifier(req.model)
+ try {
+ trace.startDetails(`transformer`)
+ trace.itemValue("model", model)
+
+ const { pipeline } = await import("@huggingface/transformers")
+ // Create a text generation pipeline
+ const generator = await pipeline("text-generation", model, {
+ dtype: "q4",
+ cache_dir: dotGenaiscriptPath("transformers"),
+ progress_callback: progressBar(),
+ device: "cpu"
+ })
+ logVerbose(`transformers model ${model} loaded`)
+ logVerbose(``)
+ const msgs: Chat = chatMessagesToTranformerMessages(messages)
+ trace.detailsFenced("messages", msgs, "yaml")
+ const output = (await generator(
+ msgs,
+ deleteUndefinedValues({
+ max_new_tokens: max_tokens || 4000,
+ temperature,
+ top_p,
+ early_stopping: true,
+ })
+ )) as TextGenerationOutput
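+ // Each output item holds the full conversation; keep only the newly generated assistant reply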
+ const text = output
+ .map((msg) => (msg.generated_text.at(-1) as Message).content)
+ .join("\n")
+ trace.fence(text, "markdown")
+ partialCb?.({
+ responseSoFar: text,
+ responseChunk: text,
+ tokensSoFar: 0,
+ inner,
+ })
+ return {
+ text,
+ finishReason: "stop",
+ } satisfies ChatCompletionResponse
+ } catch (e) {
+ logVerbose(e)
+ throw e
+ } finally {
+ trace.endDetails()
+ }
+}
+
+// Define the Transformers model provider with its completion handler
+export const TransformersModel = Object.freeze({
+ completer: TransformersCompletion,
+ id: MODEL_PROVIDER_TRANSFORMERS,
+})
+
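+// Converts OpenAI-style chat messages into transformers.js chat messages; function, aici, and tool roles are not supported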
+function chatMessagesToTranformerMessages(
+ messages: ChatCompletionMessageParam[]
+): Chat {
+ return messages.map((msg) => {
+ switch (msg.role) {
+ case "function":
+ case "aici":
+ case "tool":
+ throw new NotSupportedError(`role ${msg.role} not supported`)
+ default:
+ return {
+ role: msg.role,
+ content: renderMessageContent(msg),
+ } satisfies Message
+ }
+ })
+}
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index fa6453e18..25d3c0dde 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -152,6 +152,7 @@ type ModelType = OptionsOrString<
| "google:gemini-1.5-pro"
| "google:gemini-1.5-pro-002"
| "google:gemini-1-pro"
+ | "transformers:onnx-community/Qwen2.5-0.5B-Instruct"
>
type ModelSmallType = OptionsOrString<
diff --git a/packages/sample/genaisrc/transformers.genai.mjs b/packages/sample/genaisrc/transformers.genai.mjs
new file mode 100644
index 000000000..fa0348761
--- /dev/null
+++ b/packages/sample/genaisrc/transformers.genai.mjs
@@ -0,0 +1,5 @@
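+// Minimal local-inference sample using the transformers provider with a small ONNX model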
+script({
+ model: "transformers:onnx-community/Qwen2.5-0.5B-Instruct",
+})
+$`Write a poem with 2 paragraphs.`
diff --git a/packages/vscode/package.json b/packages/vscode/package.json
index a503fb042..bfdc49bc2 100644
--- a/packages/vscode/package.json
+++ b/packages/vscode/package.json
@@ -415,7 +415,7 @@
"vscode:update-dts": "npx @vscode/dts dev && mv vscode.*.d.ts src/",
"vscode:prepublish": "yarn run compile",
"compile:icons": "node updatefonts.mjs",
- "compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas",
+ "compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas --external:@huggingface/transformers",
"compile": "yarn compile:icons && yarn compile:extension",
"vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
"vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",