From 890c8ceb31ce799561f64112493ddbd66e4d1087 Mon Sep 17 00:00:00 2001
From: Peli de Halleux
Date: Tue, 30 Jul 2024 15:21:06 +0200
Subject: [PATCH] support for vscode language models (#595)

* add client provider
* don't check for languageModels proposal
* plumbings for chunks
* finished server side
* clean phase
* added cancellation
* renaming
* client messages
* more plumbing
* updated display
* more routing
* bail out the default client configuration
* disable lm access
* upgrade deps
* working proto
* don't ask twice
---
 docs/package.json                            |   2 +-
 .../content/docs/reference/cli/commands.md   |   2 +-
 packages/cli/package.json                    |   4 +-
 packages/cli/src/nodehost.ts                 |  12 +-
 packages/cli/src/run.ts                      |   9 +-
 packages/cli/src/server.ts                   |  94 +++++++-
 packages/core/package.json                   |   4 +-
 packages/core/src/connection.ts              |  10 +
 packages/core/src/constants.ts               |   1 +
 packages/core/src/generation.ts              |   1 -
 packages/core/src/host.ts                    |   5 +-
 packages/core/src/lm.ts                      |  26 ++-
 packages/core/src/models.ts                  |   9 +-
 packages/core/src/pdf.ts                     |   4 +-
 packages/core/src/promptcontext.ts           |   5 +-
 packages/core/src/promptrunner.ts            |   3 +-
 packages/core/src/server/client.ts           |  37 +++-
 packages/core/src/server/messages.ts         |  31 ++-
 packages/core/src/testhost.ts                |   2 +
 packages/sample/genaisrc/poem.genai.mts      |   1 +
 packages/vscode/package.json                 |   2 +-
 packages/vscode/src/lmaccess.ts              | 201 +++++++++---------
 packages/vscode/src/servermanager.ts         |   2 +
 packages/vscode/src/state.ts                 |  53 +----
 packages/vscode/src/vshost.ts                |   2 +
 yarn.lock                                    |  76 ++++---
 26 files changed, 364 insertions(+), 234 deletions(-)
 create mode 100644 packages/sample/genaisrc/poem.genai.mts

diff --git a/docs/package.json b/docs/package.json
index 82608bec99..b29f743b92 100644
--- a/docs/package.json
+++ b/docs/package.json
@@ -19,7 +19,7 @@
     },
     "dependencies": {
         "@astrojs/check": "^0.8.3",
-        "@astrojs/starlight": "^0.25.2",
+        "@astrojs/starlight": "^0.25.3",
         "astro": "^4.12.2",
         "typescript": "5.5.4"
     },
diff --git a/docs/src/content/docs/reference/cli/commands.md b/docs/src/content/docs/reference/cli/commands.md
index c905f393b5..9a6a719095 100644
--- a/docs/src/content/docs/reference/cli/commands.md
+++ b/docs/src/content/docs/reference/cli/commands.md
@@ -87,7 +87,7 @@ Options:
   -td, --test-delay delay between tests in seconds
   --no-cache disable LLM result cache
   -v, --verbose verbose output
-  -pv, --promptfoo-version [version] promptfoo version, default is ^0.73.6
+  -pv, --promptfoo-version [version] promptfoo version, default is ^0.73.8
   -os, --out-summary append output summary in file
   -h, --help display help for command
 ```
diff --git a/packages/cli/package.json b/packages/cli/package.json
index 3e34cae107..245162fa58 100644
--- a/packages/cli/package.json
+++ b/packages/cli/package.json
@@ -32,7 +32,7 @@
         "mammoth": "^1.8.0",
         "mathjs": "^13.0.3",
         "pdfjs-dist": "4.4.168",
-        "promptfoo": "^0.73.6",
+        "promptfoo": "^0.73.8",
         "tree-sitter-wasms": "^0.1.11",
         "tsx": "^4.16.2",
         "typescript": "5.5.4",
@@ -63,7 +63,7 @@
         "glob": "^11.0.0",
         "memorystream": "^0.3.1",
         "node-sarif-builder": "^3.1.0",
-        "openai": "^4.53.1",
+        "openai": "^4.53.2",
         "ora": "^8.0.1",
         "pretty-bytes": "^6.1.1",
         "prompts": "^2.4.2",
diff --git a/packages/cli/src/nodehost.ts b/packages/cli/src/nodehost.ts
index f7b3dd240e..f1533957da 100644
--- a/packages/cli/src/nodehost.ts
+++ b/packages/cli/src/nodehost.ts
@@ -20,7 +20,6 @@ import {
     DEFAULT_MODEL,
     DEFAULT_TEMPERATURE,
     MODEL_PROVIDER_AZURE,
-    AZURE_OPENAI_TOKEN_SCOPES,
    SHELL_EXEC_TIMEOUT,
     DOT_ENV_FILENAME,
     MODEL_PROVIDER_OLLAMA,
@@ -43,6 +42,7 @@ import {
AbortSignalOptions, TraceOptions } from "../../core/src/trace" import { logVerbose, unique } from "../../core/src/util" import { parseModelIdentifier } from "../../core/src/models" import { createAzureToken } from "./azuretoken" +import { LanguageModel } from "../../core/src/chat" class NodeServerManager implements ServerManager { async start(): Promise { @@ -135,6 +135,7 @@ export class NodeHost implements RuntimeHost { private async parseDefaults() { await parseDefaultsFromEnv(process.env) } + clientLanguageModel: LanguageModel private _azureToken: string async getLanguageModelConfiguration( @@ -155,6 +156,15 @@ export class NodeHost implements RuntimeHost { if (!this._azureToken) throw new Error("Azure token not available") tok.token = "Bearer " + this._azureToken } + if (!tok && this.clientLanguageModel) { + logVerbose(`model: using client language model`) + return { + model: modelId, + provider: this.clientLanguageModel.id, + source: "client", + } + } + return tok } diff --git a/packages/cli/src/run.ts b/packages/cli/src/run.ts index ede84aeb3e..23310b53cf 100644 --- a/packages/cli/src/run.ts +++ b/packages/cli/src/run.ts @@ -76,11 +76,7 @@ export async function runScript( partialCb?: (progress: ChatCompletionsProgressReport) => void } ): Promise<{ exitCode: number; result?: GenerationResult }> { - const { - trace = new MarkdownTrace(), - infoCb, - partialCb, - } = options || {} + const { trace = new MarkdownTrace(), infoCb, partialCb } = options || {} let result: GenerationResult const excludedFiles = options.excludedFiles const excludeGitIgnore = !!options.excludeGitIgnore @@ -173,8 +169,7 @@ export async function runScript( if (options.label) trace.heading(2, options.label) const { info } = await resolveModelConnectionInfo(script, { trace, - model: - options.model ?? script.model ?? 
host.defaultModelOptions.model, + model: options.model, }) if (info.error) { trace.error(undefined, info.error) diff --git a/packages/cli/src/server.ts b/packages/cli/src/server.ts index 8ad69a6df8..1373996fc0 100644 --- a/packages/cli/src/server.ts +++ b/packages/cli/src/server.ts @@ -8,7 +8,7 @@ import { TRACE_CHUNK, USER_CANCELLED_ERROR_CODE, UNHANDLED_ERROR_CODE, - DOCKER_DEFAULT_IMAGE, + MODEL_PROVIDER_CLIENT, } from "../../core/src/constants" import { isCancelError, @@ -16,22 +16,32 @@ import { serializeError, } from "../../core/src/error" import { + LanguageModelConfiguration, ResponseStatus, ServerResponse, + host, runtimeHost, } from "../../core/src/host" import { MarkdownTrace, TraceChunkEvent } from "../../core/src/trace" import { logVerbose, logError, assert } from "../../core/src/util" import { CORE_VERSION } from "../../core/src/version" -import { YAMLStringify } from "../../core/src/yaml" import { RequestMessages, PromptScriptProgressResponseEvent, PromptScriptEndResponseEvent, ShellExecResponse, + ChatStart, + ChatChunk, + ChatCancel, } from "../../core/src/server/messages" import { envInfo } from "./info" -import { estimateTokens } from "../../core/src/tokens" +import { LanguageModel } from "../../core/src/chat" +import { + ChatCompletionResponse, + ChatCompletionsOptions, + CreateChatCompletionRequest, +} from "../../core/src/chattypes" +import { randomHex } from "../../core/src/crypto" export async function startServer(options: { port: string }) { const port = parseInt(options.port) || SERVER_PORT @@ -45,6 +55,7 @@ export async function startServer(options: { port: string }) { runner: Promise } > = {} + const chats: Record Promise> = {} const cancelAll = () => { for (const [runId, run] of Object.entries(runs)) { @@ -52,8 +63,80 @@ export async function startServer(options: { port: string }) { run.canceller.abort("closing") delete runs[runId] } + for (const [chatId, chat] of Object.entries(chats)) { + console.log(`abort chat ${chat}`) + for (const ws of wss.clients) { + ws.send( + JSON.stringify({ + type: "chat.cancel", + chatId, + }) + ) + break + } + + delete chats[chatId] + } } + const handleChunk = async (chunk: ChatChunk) => { + const handler = chats[chunk.chatId] + if (handler) { + if (chunk.finishReason) delete chats[chunk.chatId] + await handler(chunk) + } + } + + host.clientLanguageModel = Object.freeze({ + id: MODEL_PROVIDER_CLIENT, + completer: async ( + req: CreateChatCompletionRequest, + connection: LanguageModelConfiguration, + options: ChatCompletionsOptions, + trace: MarkdownTrace + ): Promise => { + const { messages, model } = req + const { partialCb } = options + if (!wss.clients.size) throw new Error("no llm clients connected") + + return new Promise((resolve, reject) => { + let responseSoFar: string = "" + let tokensSoFar: number = 0 + let finishReason: ChatCompletionResponse["finishReason"] + + // add handler + const chatId = randomHex(6) + chats[chatId] = async (chunk) => { + responseSoFar += chunk.chunk ?? "" + tokensSoFar += chunk.tokens ?? 
0 + partialCb?.({ + tokensSoFar, + responseSoFar, + responseChunk: chunk.chunk, + }) + finishReason = chunk.finishReason as any + if (finishReason) { + delete chats[chatId] + resolve({ text: responseSoFar, finishReason }) + } + } + + // ask for LLM + const msg = JSON.stringify({ + type: "chat.start", + chatId, + model, + messages, + }) + for (const ws of wss.clients) { + trace.log(`chat: sending request to client`) + ws.send(msg) + break + } + }) + }, + }) + // cleanup runs wss.on("close", () => { cancelAll() @@ -231,6 +314,11 @@ export async function startServer(options: { port: string }) { } break } + case "chat.chunk": { + await handleChunk(data) + response = { ok: true } + break + } default: throw new Error(`unknown message type ${type}`) } diff --git a/packages/core/package.json b/packages/core/package.json index 9d39970f0b..474f09dca0 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -30,7 +30,7 @@ "csv-parse": "^5.5.6", "dotenv": "^16.4.5", "esbuild": "^0.23.0", - "fast-xml-parser": "^4.4.0", + "fast-xml-parser": "^4.4.1", "fetch-retry": "^6.0.0", "fflate": "^0.8.2", "file-type": "^19.3.0", @@ -49,7 +49,7 @@ "mime-types": "^2.1.35", "minimatch": "^10.0.1", "minisearch": "^7.1.0", - "openai": "^4.53.1", + "openai": "^4.53.2", "parse-diff": "^0.11.1", "prettier": "^3.3.3", "pretty-bytes": "^6.1.1", diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index 216bb1139e..0c035e0aaa 100644 --- a/packages/core/src/connection.ts +++ b/packages/core/src/connection.ts @@ -14,6 +14,7 @@ import { LOCALAI_API_BASE, MODEL_PROVIDER_AICI, MODEL_PROVIDER_AZURE, + MODEL_PROVIDER_CLIENT, MODEL_PROVIDER_LITELLM, MODEL_PROVIDER_LLAMAFILE, MODEL_PROVIDER_OLLAMA, @@ -217,6 +218,15 @@ export async function parseTokenFromEnv( } } + if (provider === MODEL_PROVIDER_CLIENT) { + return { + provider, + model, + base: undefined, + token: "client", + } + } + return undefined } diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index 387bb0f5bd..dad1cb2d88 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -94,6 +94,7 @@ export const MODEL_PROVIDER_OLLAMA = "ollama" export const MODEL_PROVIDER_LLAMAFILE = "llamafile" export const MODEL_PROVIDER_LITELLM = "litellm" export const MODEL_PROVIDER_AICI = "aici" +export const MODEL_PROVIDER_CLIENT = "client" export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240 diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index e21b21dd81..8561c5236e 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -82,7 +82,6 @@ export interface GenerationOptions cliInfo?: { files: string[] } - languageModel?: LanguageModel vars?: PromptParameters stats: GenerationStats } diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts index a33392ecb7..2784b15834 100644 --- a/packages/core/src/host.ts +++ b/packages/core/src/host.ts @@ -70,10 +70,6 @@ export interface RetrievalService { ): Promise } -export interface ParsePdfResponse extends ResponseStatus { - pages?: string[] -} - export interface ServerResponse extends ResponseStatus { version: string node: string @@ -111,6 +107,7 @@ export interface Host { options?: { token?: boolean } & AbortSignalOptions & TraceOptions ): Promise log(level: LogLevel, msg: string): void + clientLanguageModel?: LanguageModel // fs readFile(name: string): Promise diff --git a/packages/core/src/lm.ts b/packages/core/src/lm.ts index 4aa613a11a..54e773a1a4 100644 --- a/packages/core/src/lm.ts 
+++ b/packages/core/src/lm.ts @@ -1,22 +1,20 @@ import { AICIModel } from "./aici" import { LanguageModel } from "./chat" -import { MODEL_PROVIDER_AICI, MODEL_PROVIDER_OLLAMA } from "./constants" -import { LanguageModelConfiguration } from "./host" +import { + MODEL_PROVIDER_AICI, + MODEL_PROVIDER_CLIENT, + MODEL_PROVIDER_OLLAMA, +} from "./constants" +import { host } from "./host" import { OllamaModel } from "./ollama" import { OpenAIModel } from "./openai" -import { parseModelIdentifier } from "./models" -export function resolveLanguageModel( - options: { - model?: string - languageModel?: LanguageModel - }, - configuration: LanguageModelConfiguration -): LanguageModel { - const { model, languageModel } = options || {} - if (languageModel) return languageModel - - const { provider } = parseModelIdentifier(model) +export function resolveLanguageModel(provider: string): LanguageModel { + if (provider === MODEL_PROVIDER_CLIENT) { + const m = host.clientLanguageModel + if (!m) throw new Error("Client language model not available") + return m + } if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel if (provider === MODEL_PROVIDER_AICI) return AICIModel return OpenAIModel diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index bb1720654b..5c7ce7e823 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -53,17 +53,14 @@ export function traceLanguageModelConnection( trace.startDetails(`⚙️ configuration`) try { trace.itemValue(`model`, model) + trace.itemValue(`version`, version) + trace.itemValue(`source`, source) + trace.itemValue(`provider`, provider) trace.itemValue(`temperature`, temperature) trace.itemValue(`topP`, topP) trace.itemValue(`maxTokens`, maxTokens) trace.itemValue(`base`, base) trace.itemValue(`type`, type) - trace.itemValue(`version`, version) - trace.itemValue(`source`, source) - trace.itemValue(`provider`, provider) - trace.itemValue(`model`, model) - trace.itemValue(`temperature`, temperature) - trace.itemValue(`top_p`, topP) trace.itemValue(`seed`, seed) trace.itemValue(`cache name`, cacheName) trace.itemValue(`response type`, responseType) diff --git a/packages/core/src/pdf.ts b/packages/core/src/pdf.ts index 31efabb39d..55540c9573 100644 --- a/packages/core/src/pdf.ts +++ b/packages/core/src/pdf.ts @@ -1,5 +1,5 @@ import type { TextItem } from "pdfjs-dist/types/src/display/api" -import { ParsePdfResponse, host } from "./host" +import { host } from "./host" import { TraceOptions } from "./trace" import { installImport } from "./import" import { PDFJS_DIST_VERSION } from "./version" @@ -63,7 +63,7 @@ async function PDFTryParse( fileOrUrl: string, content?: Uint8Array, options?: { disableCleanup?: boolean } & TraceOptions -): Promise { +) { const { disableCleanup, trace } = options || {} try { const pdfjs = await tryImportPdfjs(options) diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts index 7d49cdb6a9..fffde02656 100644 --- a/packages/core/src/promptcontext.ts +++ b/packages/core/src/promptcontext.ts @@ -294,10 +294,7 @@ export async function createPromptContext( ) if (!connection.configuration) throw new Error("model connection error " + connection.info) - const { completer } = await resolveLanguageModel( - genOptions, - connection.configuration - ) + const { completer } = await resolveLanguageModel(connection.configuration.provider) if (!completer) throw new Error( "model driver not found for " + connection.info diff --git a/packages/core/src/promptrunner.ts 
b/packages/core/src/promptrunner.ts index d5a73f6c37..e41a301407 100644 --- a/packages/core/src/promptrunner.ts +++ b/packages/core/src/promptrunner.ts @@ -196,8 +196,7 @@ export async function runTemplate( connection.info ) const { completer } = await resolveLanguageModel( - genOptions, - connection.configuration + connection.configuration.provider ) const output = await executeChatSession( connection.configuration, diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 9e7623c829..04b7ef061e 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -1,4 +1,6 @@ -import { ChatCompletionsProgressReport } from "../chattypes" +import { + ChatCompletionsProgressReport, +} from "../chattypes" import { CLIENT_RECONNECT_DELAY, OPEN, RECONNECT } from "../constants" import { randomHex } from "../crypto" import { errorMessage } from "../error" @@ -18,10 +20,18 @@ import { PromptScriptRunOptions, PromptScriptStart, PromptScriptAbort, - ResponseEvents, + PromptScriptResponseEvents, ServerEnv, + ChatEvents, + ChatChunk, + ChatStart, } from "./messages" +export type LanguageModelChatRequest = ( + request: ChatStart, + onChunk: (param: Omit) => void +) => Promise + export class WebSocketClient extends EventTarget { private awaiters: Record< string, @@ -34,6 +44,8 @@ export class WebSocketClient extends EventTarget { connectedOnce = false reconnectAttempts = 0 + chatRequest: LanguageModelChatRequest + private runs: Record< string, { @@ -125,7 +137,7 @@ export class WebSocketClient extends EventTarget { } // handle run progress - const ev: ResponseEvents = data + const ev: PromptScriptResponseEvents = data const { runId, type } = ev const run = this.runs[runId] if (run) { @@ -152,6 +164,25 @@ export class WebSocketClient extends EventTarget { break } } + } else { + const cev: ChatEvents = data + const { chatId, type } = cev + switch (type) { + case "chat.start": { + if (!this.chatRequest) + throw new Error( + "client language model not supported" + ) + await this.chatRequest(cev, (chunk) => { + this.queue({ + ...chunk, + chatId, + type: "chat.chunk", + }) + }) + // done + } + } } })) } diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index a0be2856e4..1c9e3a2a72 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -1,5 +1,6 @@ +import { ChatCompletionAssistantMessageParam } from "../chattypes" import { GenerationResult } from "../generation" -import { ParsePdfResponse, ResponseStatus } from "../host" +import { ResponseStatus } from "../host" export interface RequestMessage { type: string @@ -126,6 +127,29 @@ export interface ShellExec extends RequestMessage { response?: ShellExecResponse } +export interface ChatStart { + type: "chat.start" + chatId: string + messages: ChatCompletionAssistantMessageParam[] + model: string + modelOptions?: { + temperature?: number + } +} + +export interface ChatCancel { + type: "chat.cancel" + chatId: string +} + +export interface ChatChunk extends RequestMessage { + type: "chat.chunk" + chatId: string + finishReason?: string + chunk?: string + tokens?: number +} + export type RequestMessages = | ServerKill | ServerVersion @@ -135,7 +159,10 @@ export type RequestMessages = | ShellExec | PromptScriptStart | PromptScriptAbort + | ChatChunk -export type ResponseEvents = +export type PromptScriptResponseEvents = | PromptScriptProgressResponseEvent | PromptScriptEndResponseEvent + +export type ChatEvents = ChatStart | ChatCancel 
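The chat.* message types above define the full round trip for client-provided language models: the server emits chat.start, the client streams back chat.chunk messages, and a non-empty finishReason closes the exchange. The sketch below shows a minimal conforming client handler; it uses only WebSocketClient, LanguageModelChatRequest, and SERVER_PORT as exported in this patch, while echoChatModel itself is a hypothetical stand-in for the real runChatModel in packages/vscode/src/lmaccess.ts:

    import { SERVER_PORT } from "../../core/src/constants"
    import {
        LanguageModelChatRequest,
        WebSocketClient,
    } from "../../core/src/server/client"

    // Hypothetical handler: streams a canned completion instead of calling
    // vscode.lm. Each onChunk call is relayed to the server as a chat.chunk
    // message; the trailing finishReason resolves the server-side completer.
    const echoChatModel: LanguageModelChatRequest = async (req, onChunk) => {
        for (const chunk of ["hello ", "world"])
            onChunk({ chunk, tokens: 1, finishReason: undefined })
        onChunk({ finishReason: "stop" })
    }

    const client = new WebSocketClient(`http://localhost:${SERVER_PORT}`)
    client.chatRequest = echoChatModel

When a chat.start arrives, WebSocketClient invokes chatRequest and queues each chunk back as { ...chunk, chatId, type: "chat.chunk" }, which is exactly the shape the server's handleChunk dispatches on.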
\ No newline at end of file diff --git a/packages/core/src/testhost.ts b/packages/core/src/testhost.ts index f0c453c98f..c14c746a31 100644 --- a/packages/core/src/testhost.ts +++ b/packages/core/src/testhost.ts @@ -25,6 +25,7 @@ import { resolve, isAbsolute, } from "node:path" +import { LanguageModel } from "./chat" export function createNodePath(): Path { return Object.freeze({ @@ -81,6 +82,7 @@ export class TestHost implements RuntimeHost { ): Promise { throw new Error("Method not implemented.") } + clientLanguageModel?: LanguageModel log(level: LogLevel, msg: string): void { throw new Error("Method not implemented.") } diff --git a/packages/sample/genaisrc/poem.genai.mts b/packages/sample/genaisrc/poem.genai.mts new file mode 100644 index 0000000000..301a86ffdd --- /dev/null +++ b/packages/sample/genaisrc/poem.genai.mts @@ -0,0 +1 @@ +$`Write a very short poem using emojis.` \ No newline at end of file diff --git a/packages/vscode/package.json b/packages/vscode/package.json index 42de28bec4..633e6ff222 100644 --- a/packages/vscode/package.json +++ b/packages/vscode/package.json @@ -365,7 +365,7 @@ "@types/vscode": "^1.90.0", "@vscode/vsce": "^2.31.1", "assert": "^2.1.0", - "eslint": "^9.7.0", + "eslint": "^9.8.0", "markdown-it-github-alerts": "^0.3.0", "process": "^0.11.10", "typescript": "5.5.4", diff --git a/packages/vscode/src/lmaccess.ts b/packages/vscode/src/lmaccess.ts index 1ff1de4b4e..5371f81ff4 100644 --- a/packages/vscode/src/lmaccess.ts +++ b/packages/vscode/src/lmaccess.ts @@ -10,12 +10,14 @@ import { MODEL_PROVIDER_AZURE, MODEL_PROVIDER_LITELLM, MODEL_PROVIDER_OPENAI, - DOT_ENV_FILENAME, + MODEL_PROVIDER_CLIENT, } from "../../core/src/constants" import { APIType } from "../../core/src/host" import { parseModelIdentifier } from "../../core/src/models" -import { GenerationOptions } from "../../core/src/generation" import { updateConnectionConfiguration } from "../../core/src/connection" +import { ChatCompletionMessageParam } from "../../core/src/chattypes" +import { LanguageModelChatRequest } from "../../core/src/server/client" +import { ChatStart } from "../../core/src/server/messages" async function generateLanguageModelConfiguration( state: ExtensionState, @@ -32,28 +34,24 @@ async function generateLanguageModelConfiguration( return { provider } } - let models: vscode.LanguageModelChat[] = [] - if (isLanguageModelsAvailable(state.context)) - models = await vscode.lm.selectChatModels() + if (state.useLanguageModels) + return { provider: MODEL_PROVIDER_CLIENT, model: "*" } + const items: (vscode.QuickPickItem & { model?: string provider?: string apiType?: APIType - })[] = models.map((model) => ({ - label: model.name, - description: model.vendor, - detail: `Use the language model ${model} through your GitHub Copilot subscription.`, - model: model.id, - })) - if (items.length) - items.unshift({ - kind: vscode.QuickPickItemKind.Separator, - label: "Visual Studio Code Language Model", - }) - items.push({ - kind: vscode.QuickPickItemKind.Separator, - label: DOT_ENV_FILENAME, - }) + })[] = [] + if (isLanguageModelsAvailable()) { + const models = await vscode.lm.selectChatModels() + if (models.length) + items.push({ + label: "Visual Studio Language Models", + detail: `Use a registered Language Model (e.g. 
GitHub Copilot).`, + model: "*", + provider: MODEL_PROVIDER_CLIENT, + }) + } items.push( { label: "OpenAI", @@ -97,9 +95,28 @@ async function generateLanguageModelConfiguration( >(items, { title: `Pick a Language Model for ${modelId}`, }) + + if (res.provider === MODEL_PROVIDER_CLIENT) state.useLanguageModels = true + return res } +async function pickChatModel(model: string): Promise { + const chatModels = await vscode.lm.selectChatModels() + const items: (vscode.QuickPickItem & { + chatModel?: vscode.LanguageModelChat + })[] = chatModels.map((chatModel) => ({ + label: chatModel.name, + description: `${chatModel.vendor} ${chatModel.family}`, + detail: `${chatModel.version}, ${chatModel.maxInputTokens}t.`, + chatModel, + })) + const res = await vscode.window.showQuickPick(items, { + title: `Pick a Chat Model for ${model}`, + }) + return res?.chatModel +} + export async function pickLanguageModel( state: ExtensionState, modelId: string @@ -118,93 +135,77 @@ export async function pickLanguageModel( } } -export function isLanguageModelsAvailable(context: vscode.ExtensionContext) { +export function isLanguageModelsAvailable() { return ( - isApiProposalEnabled( - context, - "languageModels", - "github.copilot-chat" - ) && typeof vscode.lm !== "undefined" && typeof vscode.lm.selectChatModels !== "undefined" ) } -export async function configureLanguageModelAccess( - context: vscode.ExtensionContext, - options: AIRequestOptions, - genOptions: GenerationOptions, - chatModelId: string -): Promise { - const { template } = options - const { partialCb } = genOptions - - const chatModel = (await vscode.lm.selectChatModels({ id: chatModelId }))[0] - - genOptions.cache = false - genOptions.languageModel = Object.freeze({ - id: "vscode", - completer: async (req, connection, chatOptions, trace) => { - const token = new vscode.CancellationTokenSource().token - const { model, temperature, top_p, seed, ...rest } = req - - trace.itemValue(`script model`, model) - trace.itemValue(`language model`, chatModel) - const messages: vscode.LanguageModelChatMessage[] = - req.messages.map((m) => { - switch (m.role) { - case "system": - return { - role: vscode.LanguageModelChatMessageRole.User, - content: m.content, - } - case "user": - if ( - Array.isArray(m.content) && - m.content.some((c) => c.type === "image_url") - ) - throw new Error("Vision model not supported") - return { - role: vscode.LanguageModelChatMessageRole.User, - content: - typeof m.content === "string" - ? m.content - : m.content.map((c) => c).join("\n"), - } - case "assistant": - return { - role: vscode.LanguageModelChatMessageRole - .Assistant, - content: m.content, - } - case "function": - case "tool": - throw new Error( - "tools not supported with copilot models" - ) - default: - throw new Error("uknown role") - } - }) - const request = await chatModel.sendRequest( - messages, - { - justification: `Run GenAIScript ${template.title || template.id}`, - modelOptions: { temperature, top_p, seed }, - }, - token - ) +function messagesToChatMessages(messages: ChatCompletionMessageParam[]) { + const res: vscode.LanguageModelChatMessage[] = messages.map((m) => { + switch (m.role) { + case "system": + return { + role: vscode.LanguageModelChatMessageRole.User, + content: m.content, + } + case "user": + if ( + Array.isArray(m.content) && + m.content.some((c) => c.type === "image_url") + ) + throw new Error("Vision model not supported") + return { + role: vscode.LanguageModelChatMessageRole.User, + content: + typeof m.content === "string" + ? 
m.content + : m.content.map((c) => c).join("\n"), + } + case "assistant": + return { + role: vscode.LanguageModelChatMessageRole.Assistant, + content: m.content, + } + case "function": + case "tool": + throw new Error("tools not supported with copilot models") + default: + throw new Error("uknown role") + } + }) + return res +} - let text = "" - for await (const fragment of request.text) { - text += fragment - partialCb?.({ - responseSoFar: text, - responseChunk: fragment, - tokensSoFar: await chatModel.countTokens(text), - }) - } - return { text } +export const runChatModel: LanguageModelChatRequest = async ( + req: ChatStart, + onChunk +) => { + const token = new vscode.CancellationTokenSource().token + const { model, messages, modelOptions } = req + const chatModel = await pickChatModel(model) + if (!chatModel) throw new Error("No chat model selected.") + const chatMessages = messagesToChatMessages(messages) + const request = await chatModel.sendRequest( + chatMessages, + { + justification: `Run GenAIScript`, + modelOptions, }, + token + ) + + let text = "" + for await (const fragment of request.text) { + text += fragment + onChunk({ + chunk: fragment, + tokens: await chatModel.countTokens(text), + finishReason: undefined, + }) + } + onChunk({ + finishReason: "stop", }) } diff --git a/packages/vscode/src/servermanager.ts b/packages/vscode/src/servermanager.ts index 04f3a1ecfe..5e5f9408e5 100644 --- a/packages/vscode/src/servermanager.ts +++ b/packages/vscode/src/servermanager.ts @@ -15,6 +15,7 @@ import { ServerManager, host } from "../../core/src/host" import { logError, logVerbose } from "../../core/src/util" import { WebSocketClient } from "../../core/src/server/client" import { CORE_VERSION } from "../../core/src/version" +import { isLanguageModelsAvailable, runChatModel } from "./lmaccess" export class TerminalServerManager implements ServerManager { private _terminal: vscode.Terminal @@ -43,6 +44,7 @@ export class TerminalServerManager implements ServerManager { ) this.client = new WebSocketClient(`http://localhost:${SERVER_PORT}`) + if (isLanguageModelsAvailable()) this.client.chatRequest = runChatModel this.client.addEventListener(OPEN, () => { // client connected to a rogue server if (!this._terminal) { diff --git a/packages/vscode/src/state.ts b/packages/vscode/src/state.ts index 4ac776c54a..aefb87ec06 100644 --- a/packages/vscode/src/state.ts +++ b/packages/vscode/src/state.ts @@ -107,6 +107,7 @@ export class ExtensionState extends EventTarget { AIRequestSnapshot > = undefined readonly output: vscode.LogOutputChannel + useLanguageModels = false constructor(public readonly context: ExtensionContext) { super() @@ -290,58 +291,10 @@ temp/ r.response = partialResponse reqChange() } - /* - const genOptions: GenerationOptions = { - requestOptions: { signal }, - cancellationToken, - partialCb, - trace, - infoCb: (data) => { - r.response = data - reqChange() - }, - maxCachedTemperature, - maxCachedTopP, - vars: options.parameters, - cache: cache && template.cache, - stats: { toolCalls: 0, repairs: 0, turns: 0 }, - cliInfo: - fragment && !options.notebook - ? { - spec: - this.host.isVirtualFile(fragment.file.filename) && - this.host.path.basename( - fragment.file.filename - ) === "dir.gpspec.md" - ? fragment.file.filename.replace( - /dir\.gpspec\.md$/i, - "**" - ) - : this.host.isVirtualFile( - fragment.file.filename - ) - ? 
fragment.file.filename.replace( - /\.gpspec\.md$/i, - "" - ) - : fragment.file.filename, - } - : undefined, - model: info.model, - } - */ if (!connectionToken) { // we don't have a token so ask user if they want to use copilot - const lmmodel = await pickLanguageModel(this, info.model) - if (!lmmodel) return undefined - /* - await configureLanguageModelAccess( - this.context, - options, - genOptions, - lmmodel - ) - */ + const lm = await pickLanguageModel(this, info.model) + if (!lm) return undefined } if (connectionToken?.type === "localai") await startLocalAI() diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts index c4444b89f8..fae9523e51 100644 --- a/packages/vscode/src/vshost.ts +++ b/packages/vscode/src/vshost.ts @@ -27,6 +27,7 @@ import { } from "../../core/src/host" import { TraceOptions, AbortSignalOptions } from "../../core/src/trace" import { arrayify, unique } from "../../core/src/util" +import { LanguageModel } from "../../core/src/chat" export class VSCodeHost extends EventTarget implements Host { dotEnvPath: string = DOT_ENV_FILENAME @@ -195,6 +196,7 @@ export class VSCodeHost extends EventTarget implements Host { } } + clientLanguageModel?: LanguageModel async getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions diff --git a/yarn.lock b/yarn.lock index 7577c2acee..22773830af 100644 --- a/yarn.lock +++ b/yarn.lock @@ -163,10 +163,10 @@ stream-replace-string "^2.0.0" zod "^3.23.8" -"@astrojs/starlight@^0.25.2": - version "0.25.2" - resolved "https://registry.yarnpkg.com/@astrojs/starlight/-/starlight-0.25.2.tgz#1fb6591b28ce6026d6a9da7eb5705ac86f85e1a1" - integrity sha512-VcFe9yXB0oUIoEU0lLdDA3jnbtjVzSXXpX/oI8OW4ofEHCt8L9V27f+NzRjB0A1+2D6FxAEoAw3NKoP06PLRQA== +"@astrojs/starlight@^0.25.3": + version "0.25.3" + resolved "https://registry.yarnpkg.com/@astrojs/starlight/-/starlight-0.25.3.tgz#f1faa850e649e3c4029d16cdbf3cf2124cbb9364" + integrity sha512-XNpGbZ54ungtzen4wQkPXn50D1ZquB51paWrZftA0jWxUkj4b/pP8PijAFrGFypydhvu7Dhl0DjD07lrnSSUhQ== dependencies: "@astrojs/mdx" "^3.1.0" "@astrojs/sitemap" "^3.1.5" @@ -969,7 +969,7 @@ resolved "https://registry.yarnpkg.com/@eslint-community/regexpp/-/regexpp-4.11.0.tgz#b0ffd0312b4a3fd2d6f77237e7248a5ad3a680ae" integrity sha512-G/M/tIiMrTAxEWRfLfQJMmGNX28IxBg4PBz8XqQhqUHLFI6TL2htpIB1iQCj144V5ee/JaKyT9/WZ0MGZWfA7A== -"@eslint/config-array@^0.17.0": +"@eslint/config-array@^0.17.1": version "0.17.1" resolved "https://registry.yarnpkg.com/@eslint/config-array/-/config-array-0.17.1.tgz#d9b8b8b6b946f47388f32bedfd3adf29ca8f8910" integrity sha512-BlYOpej8AQ8Ev9xVqroV7a02JK3SkBAaN9GfMMH9W6Ch8FlQlkjGw4Ir7+FgYwfirivAf4t+GtzuAxqfukmISA== @@ -993,10 +993,10 @@ minimatch "^3.1.2" strip-json-comments "^3.1.1" -"@eslint/js@9.7.0": - version "9.7.0" - resolved "https://registry.yarnpkg.com/@eslint/js/-/js-9.7.0.tgz#b712d802582f02b11cfdf83a85040a296afec3f0" - integrity sha512-ChuWDQenef8OSFnvuxv0TCVxEwmu3+hPNKvM9B34qpM0rDRbjL8t5QkQeHHeAfsKQjuH9wS82WeCi1J/owatng== +"@eslint/js@9.8.0": + version "9.8.0" + resolved "https://registry.yarnpkg.com/@eslint/js/-/js-9.8.0.tgz#ae9bc14bb839713c5056f5018bcefa955556d3a4" + integrity sha512-MfluB7EUfxXtv3i/++oh89uzAr4PDI4nn201hsp+qaXqsjAWzinlZEHEfPgAX4doIlKvPG/i0A9dpKxOLII8yA== "@eslint/object-schema@^2.1.4": version "2.1.4" @@ -5018,7 +5018,7 @@ drauu@^0.4.0: dependencies: "@drauu/core" "0.4.0" -drizzle-orm@^0.32.0: +drizzle-orm@^0.32.1: version "0.32.1" resolved 
"https://registry.yarnpkg.com/drizzle-orm/-/drizzle-orm-0.32.1.tgz#4e28c22d7f2a60aef3f0837c0a06aa7b3378b082" integrity sha512-Wq1J+lL8PzwR5K3a1FfoWsbs8powjr3pGA4+5+2ueN1VTLDNFYEolUyUWFtqy8DVRvYbL2n7sXZkgVmK9dQkng== @@ -5312,16 +5312,16 @@ eslint-visitor-keys@^4.0.0: resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-4.0.0.tgz#e3adc021aa038a2a8e0b2f8b0ce8f66b9483b1fb" integrity sha512-OtIRv/2GyiF6o/d8K7MYKKbXrOUBIK6SfkIRM4Z0dY3w+LiQ0vy3F57m0Z71bjbyeiWFiHJ8brqnmE6H6/jEuw== -eslint@^9.7.0: - version "9.7.0" - resolved "https://registry.yarnpkg.com/eslint/-/eslint-9.7.0.tgz#bedb48e1cdc2362a0caaa106a4c6ed943e8b09e4" - integrity sha512-FzJ9D/0nGiCGBf8UXO/IGLTgLVzIxze1zpfA8Ton2mjLovXdAPlYDv+MQDcqj3TmrhAGYfOpz9RfR+ent0AgAw== +eslint@^9.8.0: + version "9.8.0" + resolved "https://registry.yarnpkg.com/eslint/-/eslint-9.8.0.tgz#a4f4a090c8ea2d10864d89a6603e02ce9f649f0f" + integrity sha512-K8qnZ/QJzT2dLKdZJVX6W4XOwBzutMYmt0lqUS+JdXgd+HTYFlonFgkJ8s44d/zMPPCnOOk0kMWCApCPhiOy9A== dependencies: "@eslint-community/eslint-utils" "^4.2.0" "@eslint-community/regexpp" "^4.11.0" - "@eslint/config-array" "^0.17.0" + "@eslint/config-array" "^0.17.1" "@eslint/eslintrc" "^3.1.0" - "@eslint/js" "9.7.0" + "@eslint/js" "9.8.0" "@humanwhocodes/module-importer" "^1.0.1" "@humanwhocodes/retry" "^0.3.0" "@nodelib/fs.walk" "^1.2.8" @@ -5661,6 +5661,13 @@ fast-xml-parser@^4.4.0: dependencies: strnum "^1.0.5" +fast-xml-parser@^4.4.1: + version "4.4.1" + resolved "https://registry.yarnpkg.com/fast-xml-parser/-/fast-xml-parser-4.4.1.tgz#86dbf3f18edf8739326447bcaac31b4ae7f6514f" + integrity sha512-xkjOecfnKGkSsOwtZ5Pz7Us/T6mrbPQrq0nh+aCO5V9nk5NLWmasAHumTKjiPJPWANe+kAZ84Jc8ooJkzZ88Sw== + dependencies: + strnum "^1.0.5" + fastest-levenshtein@^1.0.16: version "1.0.16" resolved "https://registry.yarnpkg.com/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz#210e61b6ff181de91ea9b3d1b84fdedd47e034e5" @@ -9173,7 +9180,7 @@ openai@^3.2.1: axios "^0.26.0" form-data "^4.0.0" -openai@^4.53.0, openai@^4.53.1: +openai@^4.53.1: version "4.53.1" resolved "https://registry.yarnpkg.com/openai/-/openai-4.53.1.tgz#5ea38c175a70b685d329fafd6669dc4e789592b2" integrity sha512-BFj9e0jfzqd2GAGRY9hj6PU7VrGyl3LPhUdji7QvZCVxlqusoLR5qBzH5wjrJZ4d1BBDic/t5yvTdk023fM7+w== @@ -9186,6 +9193,19 @@ openai@^4.53.0, openai@^4.53.1: formdata-node "^4.3.2" node-fetch "^2.6.7" +openai@^4.53.2: + version "4.53.2" + resolved "https://registry.yarnpkg.com/openai/-/openai-4.53.2.tgz#86f54a38091a87db36f651cf28c9e5ee7c98d56a" + integrity sha512-ohYEv6OV3jsFGqNrgolDDWN6Ssx1nFg6JDJQuaBFo4SL2i+MBoOQ16n2Pq1iBF5lH1PKnfCIOfqAGkmzPvdB9g== + dependencies: + "@types/node" "^18.11.18" + "@types/node-fetch" "^2.6.4" + abort-controller "^3.0.0" + agentkeepalive "^4.2.1" + form-data-encoder "1.7.2" + formdata-node "^4.3.2" + node-fetch "^2.6.7" + opener@^1.5.2: version "1.5.2" resolved "https://registry.yarnpkg.com/opener/-/opener-1.5.2.tgz#5d37e1f35077b9dcac4301372271afdeb2a13598" @@ -9735,10 +9755,10 @@ process@^0.11.10: resolved "https://registry.yarnpkg.com/process/-/process-0.11.10.tgz#7332300e840161bda3e69a1d1d91a7d4bc16f182" integrity sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A== -promptfoo@^0.73.6: - version "0.73.6" - resolved "https://registry.yarnpkg.com/promptfoo/-/promptfoo-0.73.6.tgz#54acd01416497edacf4422d637ebeecce8dd085a" - integrity sha512-cnUCf8MeHZ90lHgVV/3PQz1q5sV8+xZFpZnnmGM2cG2ldCr4nAJrZCBGeyiWJd6SJNuDDfD3fpGVSN2+HmtTaw== +promptfoo@^0.73.8: + version "0.73.8" + resolved 
"https://registry.yarnpkg.com/promptfoo/-/promptfoo-0.73.8.tgz#46ac9a30fa9fd03c0689187d15a32e34c4384377" + integrity sha512-yKUbeywdoP471nFR1uCp7ZjnpAA8QXznaFQLB5UeR5C4fsm/Ksbl58Q/SNJ+XcU8NGGfBXI1yPhRJbA/HsBV4A== dependencies: "@anthropic-ai/sdk" "^0.24.3" "@apidevtools/json-schema-ref-parser" "^11.6.4" @@ -9761,7 +9781,7 @@ promptfoo@^0.73.6: debounce "^1.2.1" dedent "^1.5.3" dotenv "^16.4.5" - drizzle-orm "^0.32.0" + drizzle-orm "^0.32.1" express "^4.19.2" fast-deep-equal "^3.1.3" fast-xml-parser "^4.4.0" @@ -9772,11 +9792,11 @@ promptfoo@^0.73.6: mathjs "^13.0.3" node-fetch "^2.6.7" nunjucks "^3.2.4" - openai "^4.53.0" + openai "^4.53.1" opener "^1.5.2" proxy-agent "^6.4.0" python-shell "^5.0.0" - replicate "^0.31.1" + replicate "^0.32.0" rfdc "^1.4.1" rouge "git+https://github.com/kenlimmj/rouge.git#f35111b599aca55f1d4dc1d4a3d15e28e7f7c55f" semver "^7.6.3" @@ -10190,10 +10210,10 @@ replace-ext@^2.0.0: resolved "https://registry.yarnpkg.com/replace-ext/-/replace-ext-2.0.0.tgz#9471c213d22e1bcc26717cd6e50881d88f812b06" integrity sha512-UszKE5KVK6JvyD92nzMn9cDapSk6w/CaFZ96CnmDMUqH9oowfxF/ZjRITD25H4DnOQClLA4/j7jLGXXLVKxAug== -replicate@^0.31.1: - version "0.31.1" - resolved "https://registry.yarnpkg.com/replicate/-/replicate-0.31.1.tgz#83e2b809dd093a72a629a4063bebd670d068e7cd" - integrity sha512-klO76pTPzzS9Xri6bWtAp5mNjgcvyvqpVHibhTyrx4pAK7rvXal8rNGspURGCwp8ToxDQNYGEV7l+3d+xiFiwQ== +replicate@^0.32.0: + version "0.32.0" + resolved "https://registry.yarnpkg.com/replicate/-/replicate-0.32.0.tgz#d35821571fb465d10554233d226fdeb2ddc9d8f4" + integrity sha512-XOJBnV/FpRsz/r7DEj8KL4pdDk9BpptkljGOhKmjlZGdNcBvt532GxxmjT4ZaqdExg7STxrh1JHhI91zg+CZTw== optionalDependencies: readable-stream ">=4.0.0"