diff --git a/packages/cli/package.json b/packages/cli/package.json
index 2a107c8ea..98bba50d1 100644
--- a/packages/cli/package.json
+++ b/packages/cli/package.json
@@ -53,6 +53,7 @@
         "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.2/xlsx-0.20.2.tgz"
     },
     "optionalDependencies": {
+        "@huggingface/transformers": "^3.0.2",
         "@lvce-editor/ripgrep": "^1.4.0",
         "pdfjs-dist": "4.8.69",
         "playwright": "^1.49.0",
@@ -94,8 +95,8 @@
         "zx": "^8.2.2"
     },
     "scripts": {
-        "compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas && node ../../scripts/patch-cli.mjs",
-        "compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas",
+        "compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers && node ../../scripts/patch-cli.mjs",
+        "compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers",
         "postcompile": "node built/genaiscript.cjs info help > ../../docs/src/content/docs/reference/cli/commands.md",
         "vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
         "vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
diff --git a/packages/core/package.json b/packages/core/package.json
index 7124224a6..13401244b 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -21,6 +21,7 @@
         "@anthropic-ai/sdk": "^0.32.1",
         "@azure/identity": "^4.5.0",
         "@huggingface/jinja": "^0.3.2",
+        "@huggingface/transformers": "^3.0.2",
         "@octokit/plugin-paginate-rest": "^11.3.5",
         "@octokit/plugin-retry": "^7.1.2",
         "@octokit/plugin-throttling": "^9.3.2",
diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index 8cd7594c4..cfbfcd2ec 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -106,29 +106,7 @@ export function toChatCompletionUserMessage(
         content: expanded,
     }
 }
-/*
-function encodeMessagesForLlama(req: CreateChatCompletionRequest) {
-    return (
-        req.messages
-            .map((msg) => {
-                switch (msg.role) {
-                    case "user":
-                        return `[INST]\n${msg.content}\n[/INST]`
-                    case "system":
-                        return `[INST] <<SYS>>\n${msg.content}\n<</SYS>>\n[/INST]`
-                    case "assistant":
-                        return msg.content
-                    case "function":
-                        return "???function"
-                    default:
-                        return "???role " + msg.role
-                }
-            })
-            .join("\n")
-            .replace(/\[\/INST\]\n\[INST\]/g, "\n") + "\n"
-    )
-}
-*/
+
 export type ChatCompletionHandler = (
     req: CreateChatCompletionRequest,
     connection: LanguageModelConfiguration,
diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts
index 7ae61b846..0c9cdea22 100644
--- a/packages/core/src/connection.ts
+++ b/packages/core/src/connection.ts
@@ -25,6 +25,7 @@
     OLLAMA_DEFAUT_PORT,
     MODEL_PROVIDER_GOOGLE,
     GOOGLE_API_BASE,
+    MODEL_PROVIDER_TRANSFORMERS,
 } from "./constants"
 import { fileExists, readText, writeText } from "./fs"
 import {
@@ -407,6 +408,16 @@ export async function parseTokenFromEnv(
         }
     }
 
+    if (provider === MODEL_PROVIDER_TRANSFORMERS) {
+        return {
+            provider,
+            model,
+            base: undefined,
+            token: "transformers",
+            source: "default",
+        }
+    }
+
     if (provider === MODEL_PROVIDER_CLIENT && host.clientLanguageModel) {
         return {
             provider,
diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index 3a3812f92..f12d40f00 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -172,6 +172,7 @@ export const MODEL_PROVIDER_AICI = "aici"
 export const MODEL_PROVIDER_CLIENT = "client"
 export const MODEL_PROVIDER_ANTHROPIC = "anthropic"
 export const MODEL_PROVIDER_HUGGINGFACE = "huggingface"
+export const MODEL_PROVIDER_TRANSFORMERS = "transformers"
 
 export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240
 
@@ -210,6 +211,8 @@ export const DOCS_CONFIGURATION_GOOGLE_URL =
     "https://microsoft.github.io/genaiscript/getting-started/configuration/#google"
 export const DOCS_CONFIGURATION_HUGGINGFACE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#huggingface" +export const DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL = + "https://microsoft.github.io/genaiscript/getting-started/configuration/#transformers" export const DOCS_CONFIGURATION_CONTENT_SAFETY_URL = "https://microsoft.github.io/genaiscript/reference/scripts/content-safety" export const DOCS_DEF_FILES_IS_EMPTY_URL = @@ -262,6 +265,11 @@ export const MODEL_PROVIDERS = Object.freeze([ detail: "Hugging Face models", url: DOCS_CONFIGURATION_HUGGINGFACE_URL, }, + { + id: MODEL_PROVIDER_TRANSFORMERS, + detail: "Hugging Face Transformers", + url: DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL, + }, { id: MODEL_PROVIDER_OLLAMA, detail: "Ollama local model", diff --git a/packages/core/src/lm.ts b/packages/core/src/lm.ts index 3826dec4c..63f69d800 100644 --- a/packages/core/src/lm.ts +++ b/packages/core/src/lm.ts @@ -1,3 +1,4 @@ +import { Transform } from "stream" import { AICIModel } from "./aici" import { AnthropicModel } from "./anthropic" import { LanguageModel } from "./chat" @@ -6,10 +7,12 @@ import { MODEL_PROVIDER_ANTHROPIC, MODEL_PROVIDER_CLIENT, MODEL_PROVIDER_OLLAMA, + MODEL_PROVIDER_TRANSFORMERS, } from "./constants" import { host } from "./host" import { OllamaModel } from "./ollama" import { OpenAIModel } from "./openai" +import { TransformersModel } from "./transformers" export function resolveLanguageModel(provider: string): LanguageModel { if (provider === MODEL_PROVIDER_CLIENT) { @@ -20,5 +23,6 @@ export function resolveLanguageModel(provider: string): LanguageModel { if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel if (provider === MODEL_PROVIDER_AICI) return AICIModel if (provider === MODEL_PROVIDER_ANTHROPIC) return AnthropicModel + if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel return OpenAIModel } diff --git a/packages/core/src/ollama.ts b/packages/core/src/ollama.ts index 1d0eba3c4..5a3dc3ae1 100644 --- a/packages/core/src/ollama.ts +++ b/packages/core/src/ollama.ts @@ -24,7 +24,7 @@ import { URL } from "url" * @returns The result of the chat completion. * @throws Will throw an error if the model cannot be pulled or any other request error occurs. 
  */
-export const OllamaCompletion: ChatCompletionHandler = async (
+const OllamaCompletion: ChatCompletionHandler = async (
     req,
     cfg,
     options,
diff --git a/packages/core/src/transformers.ts b/packages/core/src/transformers.ts
new file mode 100644
index 000000000..733411fbf
--- /dev/null
+++ b/packages/core/src/transformers.ts
@@ -0,0 +1,119 @@
+import { ChatCompletionHandler, LanguageModel } from "./chat"
+import { renderMessageContent } from "./chatrender"
+import { MODEL_PROVIDER_TRANSFORMERS } from "./constants"
+import type {
+    Chat,
+    Message,
+    TextGenerationOutput,
+} from "@huggingface/transformers"
+import { NotSupportedError } from "./error"
+import { ChatCompletionMessageParam, ChatCompletionResponse } from "./chattypes"
+import { deleteUndefinedValues, dotGenaiscriptPath, logVerbose } from "./util"
+import { parseModelIdentifier } from "./models"
+import prettyBytes from "pretty-bytes"
+
+type GeneratorProgress =
+    | {
+          status: "initiate"
+      }
+    | {
+          status: "progress"
+          file: string
+          name: string
+          progress: number
+          loaded: number
+          total: number
+      }
+
+function progressBar() {
+    const progress: Record<string, number> = {}
+    return (cb: GeneratorProgress) => {
+        if (cb.status !== "progress") return
+        const p = progress[cb.file] || 0
+        const cp = Math.floor(cb.progress)
+        if (cp > p) {
+            progress[cb.file] = cp
+            logVerbose(`${cb.file}: ${cp}% (${prettyBytes(cb.loaded)})`)
+        }
+    }
+}
+
+export const TransformersCompletion: ChatCompletionHandler = async (
+    req,
+    cfg,
+    options,
+    trace
+) => {
+    const { messages, temperature, top_p, max_tokens } = req
+    const { partialCb, inner } = options
+    const { model } = parseModelIdentifier(req.model)
+    try {
+        trace.startDetails(`transformer`)
+        trace.itemValue("model", model)
+
+        const { pipeline } = await import("@huggingface/transformers")
+        // Create a text generation pipeline
+        const generator = await pipeline("text-generation", model, {
+            dtype: "q4",
+            cache_dir: dotGenaiscriptPath("transformers"),
+            progress_callback: progressBar(),
+            device: "cpu"
+        })
+        logVerbose(`transformers model ${model} loaded`)
+        logVerbose(``)
+        const msgs: Chat = chatMessagesToTransformerMessages(messages)
+        trace.detailsFenced("messages", msgs, "yaml")
+        const output = (await generator(
+            msgs,
+            deleteUndefinedValues({
+                max_new_tokens: max_tokens || 4000,
+                temperature,
+                top_p,
+                early_stopping: true,
+            })
+        )) as TextGenerationOutput
+        const text = output
+            .map((msg) => (msg.generated_text.at(-1) as Message).content)
+            .join("\n")
+        trace.fence(text, "markdown")
+        partialCb?.({
+            responseSoFar: text,
+            responseChunk: text,
+            tokensSoFar: 0,
+            inner,
+        })
+        return {
+            text,
+            finishReason: "stop",
+        } satisfies ChatCompletionResponse
+    } catch (e) {
+        logVerbose(e)
+        throw e
+    } finally {
+        trace.endDetails()
+    }
+}
+
+// Define the Transformers model with its completion handler
+export const TransformersModel = Object.freeze<LanguageModel>({
+    completer: TransformersCompletion,
+    id: MODEL_PROVIDER_TRANSFORMERS,
+})
+
+function chatMessagesToTransformerMessages(
+    messages: ChatCompletionMessageParam[]
+): Chat {
+    return messages.map((msg) => {
+        switch (msg.role) {
+            case "function":
+            case "aici":
+            case "tool":
+                throw new NotSupportedError(`role ${msg.role} not supported`)
+            default:
+                return {
+                    role: msg.role,
+                    content: renderMessageContent(msg),
+                } satisfies Message
+        }
+    })
+}
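For context, a minimal standalone sketch (not part of the patch) of the Transformers.js call pattern that TransformersCompletion wraps. It assumes the chat-style text-generation API of @huggingface/transformers v3 and borrows the model id used by the sample script below; option names mirror the handler above.

// Sketch under the assumptions stated above.
import { pipeline } from "@huggingface/transformers"
import type { Chat, Message, TextGenerationOutput } from "@huggingface/transformers"

async function demo() {
    // First call downloads and caches the quantized ONNX weights;
    // the handler above points cache_dir at the GenAIScript dot folder.
    const generator = await pipeline(
        "text-generation",
        "onnx-community/Qwen2.5-0.5B-Instruct",
        { dtype: "q4", device: "cpu" }
    )
    // Chat-style input: the pipeline applies the model's chat template.
    const messages: Chat = [
        { role: "system", content: "You are a concise assistant." },
        { role: "user", content: "Write a poem with 2 paragraphs." },
    ]
    const output = (await generator(messages, {
        max_new_tokens: 256,
    })) as TextGenerationOutput
    // generated_text echoes the input messages plus the generated assistant turn as the last entry.
    const reply = (output[0].generated_text.at(-1) as Message).content
    console.log(reply)
}

demo().catch(console.error)

The handler above follows the same shape and adds the download progress callback, the cache directory, and the trace/partial-response plumbing expected by the chat pipeline.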
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index fa6453e18..25d3c0dde 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -152,6 +152,7 @@ type ModelType = OptionsOrString<
     | "google:gemini-1.5-pro"
     | "google:gemini-1.5-pro-002"
     | "google:gemini-1-pro"
+    | "transformers:onnx-community/Qwen2.5-0.5B-Instruct"
 >
 
 type ModelSmallType = OptionsOrString<
diff --git a/packages/sample/genaisrc/transformers.genai.mjs b/packages/sample/genaisrc/transformers.genai.mjs
new file mode 100644
index 000000000..fa0348761
--- /dev/null
+++ b/packages/sample/genaisrc/transformers.genai.mjs
@@ -0,0 +1,4 @@
+script({
+    model: "transformers:onnx-community/Qwen2.5-0.5B-Instruct",
+})
+$`Write a poem with 2 paragraphs.`
diff --git a/packages/vscode/package.json b/packages/vscode/package.json
index a503fb042..bfdc49bc2 100644
--- a/packages/vscode/package.json
+++ b/packages/vscode/package.json
@@ -415,7 +415,7 @@
         "vscode:update-dts": "npx @vscode/dts dev && mv vscode.*.d.ts src/",
         "vscode:prepublish": "yarn run compile",
         "compile:icons": "node updatefonts.mjs",
-        "compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas",
+        "compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas --external:@huggingface/transformers",
         "compile": "yarn compile:icons && yarn compile:extension",
         "vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
         "vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
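Usage note, inferred from the patch rather than stated in it: scripts select a local ONNX model with the "transformers:<model-id>" prefix, as in packages/sample/genaisrc/transformers.genai.mjs above; the provider needs no API key or .env entry, since parseTokenFromEnv returns a built-in placeholder token, and downloaded weights are cached under the GenAIScript dot folder via dotGenaiscriptPath("transformers"). Assuming the sample id resolves from its filename like the other samples, it can be exercised from packages/sample with: npx genaiscript run transformers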