Skip to content

Commit

Permalink
feat: 🤖 integrate transformers model support
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Nov 22, 2024
1 parent 5fe5744 commit 36d054f
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 27 deletions.
5 changes: 3 additions & 2 deletions packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"xlsx": "https://cdn.sheetjs.com/xlsx-0.20.2/xlsx-0.20.2.tgz"
},
"optionalDependencies": {
"@huggingface/transformers": "^3.0.2",
"@lvce-editor/ripgrep": "^1.4.0",
"pdfjs-dist": "4.8.69",
"playwright": "^1.49.0",
Expand Down Expand Up @@ -94,8 +95,8 @@
"zx": "^8.2.2"
},
"scripts": {
"compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas && node ../../scripts/patch-cli.mjs",
"compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas",
"compile": "esbuild src/main.ts --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers && node ../../scripts/patch-cli.mjs",
"compile-debug": "esbuild src/main.ts --sourcemap --metafile=./esbuild.meta.json --bundle --platform=node --target=node20 --outfile=built/genaiscript.cjs --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:dockerode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:promptfoo --external:typescript --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:mammoth --external:xlsx --external:mathjs --external:@azure/identity --external:gpt-tokenizer --external:playwright --external:@inquirer/prompts --external:jimp --external:turndown --external:vectra --external:tabletojson --external:html-to-text --external:@octokit/rest --external:@octokit/plugin-throttling --external:@octokit/plugin-retry --external:@octokit/plugin-paginate-rest --external:skia-canvas --external:@huggingface/transformers",
"postcompile": "node built/genaiscript.cjs info help > ../../docs/src/content/docs/reference/cli/commands.md",
"vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
"vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"@anthropic-ai/sdk": "^0.32.1",
"@azure/identity": "^4.5.0",
"@huggingface/jinja": "^0.3.2",
"@huggingface/transformers": "^3.0.2",
"@octokit/plugin-paginate-rest": "^11.3.5",
"@octokit/plugin-retry": "^7.1.2",
"@octokit/plugin-throttling": "^9.3.2",
Expand Down
24 changes: 1 addition & 23 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,29 +106,7 @@ export function toChatCompletionUserMessage(
content: expanded,
}
}
/*
function encodeMessagesForLlama(req: CreateChatCompletionRequest) {
return (
req.messages
.map((msg) => {
switch (msg.role) {
case "user":
return `[INST]\n${msg.content}\n[/INST]`
case "system":
return `[INST] <<SYS>>\n${msg.content}\n<</SYS>>\n[/INST]`
case "assistant":
return msg.content
case "function":
return "???function"
default:
return "???role " + msg.role
}
})
.join("\n")
.replace(/\[\/INST\]\n\[INST\]/g, "\n") + "\n"
)
}
*/

export type ChatCompletionHandler = (
req: CreateChatCompletionRequest,
connection: LanguageModelConfiguration,
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
OLLAMA_DEFAUT_PORT,
MODEL_PROVIDER_GOOGLE,
GOOGLE_API_BASE,
MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { fileExists, readText, writeText } from "./fs"
import {
Expand Down Expand Up @@ -407,6 +408,16 @@ export async function parseTokenFromEnv(
}
}

if (provider === MODEL_PROVIDER_TRANSFORMERS) {
return {
provider,
model,
base: undefined,
token: "transformers",
source: "default",
}
}

if (provider === MODEL_PROVIDER_CLIENT && host.clientLanguageModel) {
return {
provider,
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ export const MODEL_PROVIDER_AICI = "aici"
export const MODEL_PROVIDER_CLIENT = "client"
export const MODEL_PROVIDER_ANTHROPIC = "anthropic"
export const MODEL_PROVIDER_HUGGINGFACE = "huggingface"
export const MODEL_PROVIDER_TRANSFORMERS = "transformers"

export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240

Expand Down Expand Up @@ -210,6 +211,8 @@ export const DOCS_CONFIGURATION_GOOGLE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#google"
export const DOCS_CONFIGURATION_HUGGINGFACE_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#huggingface"
export const DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL =
"https://microsoft.github.io/genaiscript/getting-started/configuration/#transformers"
export const DOCS_CONFIGURATION_CONTENT_SAFETY_URL =
"https://microsoft.github.io/genaiscript/reference/scripts/content-safety"
export const DOCS_DEF_FILES_IS_EMPTY_URL =
Expand Down Expand Up @@ -262,6 +265,11 @@ export const MODEL_PROVIDERS = Object.freeze([
detail: "Hugging Face models",
url: DOCS_CONFIGURATION_HUGGINGFACE_URL,
},
{
id: MODEL_PROVIDER_TRANSFORMERS,
detail: "Hugging Face Transformers",
url: DOCS_CONFIGURATION_HUGGINGFACE_TRANSFORMERS_URL,
},
{
id: MODEL_PROVIDER_OLLAMA,
detail: "Ollama local model",
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/lm.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Transform } from "stream"
import { AICIModel } from "./aici"
import { AnthropicModel } from "./anthropic"
import { LanguageModel } from "./chat"
Expand All @@ -6,10 +7,12 @@ import {
MODEL_PROVIDER_ANTHROPIC,
MODEL_PROVIDER_CLIENT,
MODEL_PROVIDER_OLLAMA,
MODEL_PROVIDER_TRANSFORMERS,
} from "./constants"
import { host } from "./host"
import { OllamaModel } from "./ollama"
import { OpenAIModel } from "./openai"
import { TransformersModel } from "./transformers"

export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_CLIENT) {
Expand All @@ -20,5 +23,6 @@ export function resolveLanguageModel(provider: string): LanguageModel {
if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel
if (provider === MODEL_PROVIDER_AICI) return AICIModel
if (provider === MODEL_PROVIDER_ANTHROPIC) return AnthropicModel
if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel
return OpenAIModel
}
2 changes: 1 addition & 1 deletion packages/core/src/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { URL } from "url"
* @returns The result of the chat completion.
* @throws Will throw an error if the model cannot be pulled or any other request error occurs.
*/
export const OllamaCompletion: ChatCompletionHandler = async (
const OllamaCompletion: ChatCompletionHandler = async (
req,
cfg,
options,
Expand Down
119 changes: 119 additions & 0 deletions packages/core/src/transformers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import { ChatCompletionHandler, LanguageModel } from "./chat"
import { renderMessageContent } from "./chatrender"
import { MODEL_PROVIDER_TRANSFORMERS } from "./constants"
import type {
Chat,
Message,
TextGenerationOutput,
} from "@huggingface/transformers"
import { NotSupportedError } from "./error"
import { ChatCompletionMessageParam, ChatCompletionResponse } from "./chattypes"
import { deleteUndefinedValues, dotGenaiscriptPath, logVerbose } from "./util"
import { parseModelIdentifier } from "./models"
import prettyBytes from "pretty-bytes"

type GeneratorProgress =
| {
status: "initiate"
}
| {
status: "progress"
file: string
name: string
progress: number
loaded: number
total: number
}

function progressBar() {
const progress: Record<string, number> = {}
return (cb: GeneratorProgress) => {
if (cb.status !== "progress") return
const p = progress[cb.file] || 0
const cp = Math.floor(cb.progress)
if (cp > p) {
progress[cb.file] = cp
logVerbose(`${cb.file}: ${cp}% (${prettyBytes(cb.loaded)})`)
}
}
}

export const TransformersCompletion: ChatCompletionHandler = async (
req,
cfg,
options,
trace
) => {
const { messages, temperature, top_p, max_tokens } = req
const { partialCb, inner } = options
const { model } = parseModelIdentifier(req.model)
try {
trace.startDetails(`transformer`)
trace.itemValue("model", model)

const { pipeline } = await import("@huggingface/transformers")
// Create a text generation pipeline
const generator = await pipeline("text-generation", model, {
dtype: "q4",
cache_dir: dotGenaiscriptPath("transformers"),
progress_callback: progressBar(),
device: "cpu"
})
logVerbose(`transformers model ${model} loaded`)
logVerbose(``)
const msgs: Chat = chatMessagesToTranformerMessages(messages)
trace.detailsFenced("messages", msgs, "yaml")
const output = (await generator(
msgs,
deleteUndefinedValues({
max_new_tokens: max_tokens || 4000,
temperature,
top_p,
early_stopping: true,
})
)) as TextGenerationOutput
const text = output
.map((msg) => (msg.generated_text.at(-1) as Message).content)
.join("\n")
trace.fence(text, "markdown")
partialCb?.({
responseSoFar: text,
responseChunk: text,
tokensSoFar: 0,
inner,
})
return {
text,
finishReason: "stop",
} satisfies ChatCompletionResponse
} catch (e) {
logVerbose(e)
throw e
} finally {
trace.endDetails()
}
}

// Define the Ollama model with its completion handler and model listing function
export const TransformersModel = Object.freeze<LanguageModel>({
completer: TransformersCompletion,
id: MODEL_PROVIDER_TRANSFORMERS,
})

function chatMessagesToTranformerMessages(
messages: ChatCompletionMessageParam[]
): Chat {
return messages.map((msg) => {
switch (msg.role) {
case "function":
case "aici":
case "tool":
throw new NotSupportedError(`role ${msg.role} not supported`)
default:
return {
role: msg.role,
content: renderMessageContent(msg),
} satisfies Message
}
})
}
1 change: 1 addition & 0 deletions packages/core/src/types/prompt_template.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type ModelType = OptionsOrString<
| "google:gemini-1.5-pro"
| "google:gemini-1.5-pro-002"
| "google:gemini-1-pro"
| "transformers:onnx-community/Qwen2.5-0.5B-Instruct"
>

type ModelSmallType = OptionsOrString<
Expand Down
4 changes: 4 additions & 0 deletions packages/sample/genaisrc/transformers.genai.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
script({
model: "transformers:onnx-community/Qwen2.5-0.5B-Instruct",
})
$`Write a poem with 2 paragraphs.`
2 changes: 1 addition & 1 deletion packages/vscode/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
"vscode:update-dts": "npx @vscode/dts dev && mv vscode.*.d.ts src/",
"vscode:prepublish": "yarn run compile",
"compile:icons": "node updatefonts.mjs",
"compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas",
"compile:extension": "esbuild src/extension.ts --sourcemap --metafile=./esbuild.meta.json --bundle --format=cjs --platform=node --target=node20 --outfile=built/extension.js --external:tsx --external:esbuild --external:get-tsconfig --external:resolve-pkg-maps --external:vscode --external:pdfjs-dist --external:web-tree-sitter --external:tree-sitter-wasms --external:@lvce-editor/ripgrep --external:gpt-3-encoder --external:skia-canvas --external:@huggingface/transformers",
"compile": "yarn compile:icons && yarn compile:extension",
"vis:treemap": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.treemap.html",
"vis:network": "npx --yes esbuild-visualizer --metadata esbuild.meta.json --filename esbuild.network.html --template network",
Expand Down

0 comments on commit 36d054f

Please sign in to comment.