From 30b396112a3d73841cc51e841ffa35316ce9c231 Mon Sep 17 00:00:00 2001
From: pelikhan
Date: Tue, 30 Jul 2024 16:49:32 +0200
Subject: [PATCH] remember language chat model choices

---
 packages/vscode/src/lmaccess.ts      | 97 ++++++++++++++++------
 packages/vscode/src/servermanager.ts |  4 +-
 packages/vscode/src/state.ts         | 26 ++++++--
 packages/vscode/src/statusbar.ts     |  7 +-
 4 files changed, 80 insertions(+), 54 deletions(-)

diff --git a/packages/vscode/src/lmaccess.ts b/packages/vscode/src/lmaccess.ts
index 5371f81ff4..0c930d88de 100644
--- a/packages/vscode/src/lmaccess.ts
+++ b/packages/vscode/src/lmaccess.ts
@@ -34,7 +34,7 @@ async function generateLanguageModelConfiguration(
         return { provider }
     }
 
-    if (state.useLanguageModels)
+    if (Object.keys(state.languageChatModels).length)
         return { provider: MODEL_PROVIDER_CLIENT, model: "*" }
 
     const items: (vscode.QuickPickItem & {
@@ -93,28 +93,36 @@ async function generateLanguageModelConfiguration(
             apiType?: APIType
         }
     >(items, {
-        title: `Pick a Language Model for ${modelId}`,
+        title: `Configure a Language Model for ${modelId}`,
     })
 
-    if (res.provider === MODEL_PROVIDER_CLIENT) state.useLanguageModels = true
-
     return res
 }
 
-async function pickChatModel(model: string): Promise<vscode.LanguageModelChat> {
+async function pickChatModel(
+    state: ExtensionState,
+    model: string
+): Promise<vscode.LanguageModelChat> {
     const chatModels = await vscode.lm.selectChatModels()
-    const items: (vscode.QuickPickItem & {
-        chatModel?: vscode.LanguageModelChat
-    })[] = chatModels.map((chatModel) => ({
-        label: chatModel.name,
-        description: `${chatModel.vendor} ${chatModel.family}`,
-        detail: `${chatModel.version}, ${chatModel.maxInputTokens}t.`,
-        chatModel,
-    }))
-    const res = await vscode.window.showQuickPick(items, {
-        title: `Pick a Chat Model for ${model}`,
-    })
-    return res?.chatModel
+
+    const chatModelId = state.languageChatModels[model]
+    let chatModel = chatModelId && chatModels.find((m) => m.id === chatModelId)
+    if (!chatModel) {
+        const items: (vscode.QuickPickItem & {
+            chatModel?: vscode.LanguageModelChat
+        })[] = chatModels.map((chatModel) => ({
+            label: chatModel.name,
+            description: `${chatModel.vendor} ${chatModel.family}`,
+            detail: `${chatModel.version}, ${chatModel.maxInputTokens}t.`,
+            chatModel,
+        }))
+        const res = await vscode.window.showQuickPick(items, {
+            title: `Pick a Chat Model for ${model}`,
+        })
+        chatModel = res?.chatModel
+        if (chatModel) state.languageChatModels[model] = chatModel.id
+    }
+    return chatModel
 }
 
 export async function pickLanguageModel(
@@ -178,34 +186,37 @@ function messagesToChatMessages(messages: ChatCompletionMessageParam[]) {
     return res
 }
 
-export const runChatModel: LanguageModelChatRequest = async (
-    req: ChatStart,
-    onChunk
-) => {
-    const token = new vscode.CancellationTokenSource().token
-    const { model, messages, modelOptions } = req
-    const chatModel = await pickChatModel(model)
-    if (!chatModel) throw new Error("No chat model selected.")
-    const chatMessages = messagesToChatMessages(messages)
-    const request = await chatModel.sendRequest(
-        chatMessages,
-        {
-            justification: `Run GenAIScript`,
-            modelOptions,
-        },
-        token
-    )
+export function createChatModelRunner(
+    state: ExtensionState
+): LanguageModelChatRequest {
+    if (!isLanguageModelsAvailable()) return undefined
+
+    return async (req: ChatStart, onChunk) => {
+        const token = new vscode.CancellationTokenSource().token
+        const { model, messages, modelOptions } = req
+        const chatModel = await pickChatModel(state, model)
+        if (!chatModel) throw new Error("No chat model selected.")
+        const chatMessages = messagesToChatMessages(messages)
+        const request = await chatModel.sendRequest(
+            chatMessages,
+            {
+                justification: `Run GenAIScript`,
+                modelOptions,
+            },
+            token
+        )
 
-    let text = ""
-    for await (const fragment of request.text) {
-        text += fragment
+        let text = ""
+        for await (const fragment of request.text) {
+            text += fragment
+            onChunk({
+                chunk: fragment,
+                tokens: await chatModel.countTokens(text),
+                finishReason: undefined,
+            })
+        }
         onChunk({
-            chunk: fragment,
-            tokens: await chatModel.countTokens(text),
-            finishReason: undefined,
+            finishReason: "stop",
         })
     }
-    onChunk({
-        finishReason: "stop",
-    })
 }
diff --git a/packages/vscode/src/servermanager.ts b/packages/vscode/src/servermanager.ts
index 5e5f9408e5..c717a94243 100644
--- a/packages/vscode/src/servermanager.ts
+++ b/packages/vscode/src/servermanager.ts
@@ -15,7 +15,7 @@ import { ServerManager, host } from "../../core/src/host"
 import { logError, logVerbose } from "../../core/src/util"
 import { WebSocketClient } from "../../core/src/server/client"
 import { CORE_VERSION } from "../../core/src/version"
-import { isLanguageModelsAvailable, runChatModel } from "./lmaccess"
+import { createChatModelRunner } from "./lmaccess"
 
 export class TerminalServerManager implements ServerManager {
     private _terminal: vscode.Terminal
@@ -44,7 +44,7 @@ export class TerminalServerManager implements ServerManager {
         )
 
         this.client = new WebSocketClient(`http://localhost:${SERVER_PORT}`)
-        if (isLanguageModelsAvailable()) this.client.chatRequest = runChatModel
+        this.client.chatRequest = createChatModelRunner(this.state)
         this.client.addEventListener(OPEN, () => {
             // client connected to a rogue server
             if (!this._terminal) {
diff --git a/packages/vscode/src/state.ts b/packages/vscode/src/state.ts
index aefb87ec06..cd9d58a9bb 100644
--- a/packages/vscode/src/state.ts
+++ b/packages/vscode/src/state.ts
@@ -107,7 +107,8 @@ export class ExtensionState extends EventTarget {
         AIRequestSnapshot
     > = undefined
     readonly output: vscode.LogOutputChannel
-    useLanguageModels = false
+    // modelid -> vscode language mode id
+    languageChatModels: Record<string, string> = {}
 
     constructor(public readonly context: ExtensionContext) {
         super()
@@ -128,13 +129,24 @@ export class ExtensionState extends EventTarget {
         >(AI_REQUESTS_CACHE)
 
         // clear errors when file edited (remove me?)
-        vscode.workspace.onDidChangeTextDocument(
-            (ev) => {
-                this._diagColl.set(ev.document.uri, [])
-            },
-            undefined,
-            subscriptions
+        subscriptions.push(
+            vscode.workspace.onDidChangeTextDocument(
+                (ev) => {
+                    this._diagColl.set(ev.document.uri, [])
+                },
+                undefined,
+                subscriptions
+            )
+        )
+        if (
+            typeof vscode.lm !== "undefined" &&
+            typeof vscode.lm.onDidChangeChatModels === "function"
         )
+            subscriptions.push(
+                vscode.lm.onDidChangeChatModels(
+                    () => (this.languageChatModels = {})
+                )
+            )
     }
 
     private async saveScripts() {
diff --git a/packages/vscode/src/statusbar.ts b/packages/vscode/src/statusbar.ts
index 4c0a406f30..f9df1fa6e8 100644
--- a/packages/vscode/src/statusbar.ts
+++ b/packages/vscode/src/statusbar.ts
@@ -13,7 +13,7 @@ export function activateStatusBar(state: ExtensionState) {
     )
     statusBarItem.command = "genaiscript.request.status"
    const updateStatusBar = async () => {
-        const { parsing, aiRequest } = state
+        const { parsing, aiRequest, languageChatModels } = state
         const { computing, progress, options } = aiRequest || {}
         const { template, fragment } = options || {}
         const { tokensSoFar } = progress || {}
@@ -30,7 +30,10 @@ export function activateStatusBar(state: ExtensionState) {
             fragment?.files?.[0],
             template
                 ? `- tool: ${template.title} (${template.id})`
-                : undefined
+                : undefined,
+            ...Object.entries(languageChatModels).map(
+                ([m, c]) => `- language chat model: ${m} -> ${c}`
+            )
         ),
         true
     )
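Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the caching pattern this change introduces in pickChatModel and ExtensionState: look up a previously chosen VS Code chat model by its stored id, prompt the user only when no valid choice is cached, and forget the choices when the set of available chat models changes. The VS Code API calls (vscode.lm.selectChatModels, vscode.window.showQuickPick, vscode.lm.onDidChangeChatModels) and the languageChatModels name come from the patch; the standalone helper function and module-level map are assumptions made only for illustration.

import * as vscode from "vscode"

// GenAIScript model id -> chosen VS Code chat model id (stands in for state.languageChatModels)
const languageChatModels: Record<string, string> = {}

// hypothetical helper: resolve a chat model, remembering the user's pick per model id
async function resolveChatModel(
    modelId: string
): Promise<vscode.LanguageModelChat | undefined> {
    const chatModels = await vscode.lm.selectChatModels()
    // reuse the remembered choice if that chat model is still installed
    const rememberedId = languageChatModels[modelId]
    let chatModel = chatModels.find((m) => m.id === rememberedId)
    if (!chatModel) {
        // otherwise ask once and remember the answer for next time
        const picked = await vscode.window.showQuickPick(
            chatModels.map((m) => ({
                label: m.name,
                description: `${m.vendor} ${m.family}`,
                chatModel: m,
            })),
            { title: `Pick a Chat Model for ${modelId}` }
        )
        chatModel = picked?.chatModel
        if (chatModel) languageChatModels[modelId] = chatModel.id
    }
    return chatModel
}

// drop remembered choices when the installed chat models change (guarded as in the patch,
// for VS Code builds that lack the onDidChangeChatModels event)
if (typeof vscode.lm?.onDidChangeChatModels === "function")
    vscode.lm.onDidChangeChatModels(() => {
        for (const key of Object.keys(languageChatModels)) delete languageChatModels[key]
    })

In the patch itself the map lives on ExtensionState, is consulted by createChatModelRunner for every chat request forwarded from the server client, and is surfaced in the status bar tooltip.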