remember language chat model choices
pelikhan committed Jul 30, 2024
1 parent f4db019 commit 30b3961
Showing 4 changed files with 80 additions and 54 deletions.
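
In short: the commit replaces the `useLanguageModels` boolean on `ExtensionState` with a `languageChatModels: Record<string, string>` map from a GenAIScript model id to the VS Code chat model id the user picked, so `pickChatModel` only prompts once per model and reuses the stored choice afterwards. The cache is cleared when the set of available chat models changes, the remembered pairs are listed in the status bar tooltip, and the server manager now obtains its chat handler from the new `createChatModelRunner(state)` factory. Below is a condensed sketch of the remember-the-choice flow (TypeScript; `ChatModelState` stands in for the full `ExtensionState`, and the quick pick items are trimmed), not the verbatim implementation:

import * as vscode from "vscode"

// Stand-in for the slice of ExtensionState this commit introduces:
// GenAIScript model id -> id of the VS Code chat model the user picked.
interface ChatModelState {
    languageChatModels: Record<string, string>
}

async function pickChatModel(
    state: ChatModelState,
    model: string
): Promise<vscode.LanguageModelChat | undefined> {
    const chatModels = await vscode.lm.selectChatModels()

    // Reuse a remembered choice for this model id if it is still available.
    const rememberedId = state.languageChatModels[model]
    let chatModel = rememberedId
        ? chatModels.find((m) => m.id === rememberedId)
        : undefined

    if (!chatModel) {
        // First request for this model id: ask the user, then remember the answer.
        const picked = await vscode.window.showQuickPick(
            chatModels.map((m) => ({
                label: m.name,
                description: `${m.vendor} ${m.family}`,
                chatModel: m,
            })),
            { title: `Pick a Chat Model for ${model}` }
        )
        chatModel = picked?.chatModel
        if (chatModel) state.languageChatModels[model] = chatModel.id
    }
    return chatModel
}
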
packages/vscode/src/lmaccess.ts (54 additions, 43 deletions)
@@ -34,7 +34,7 @@ async function generateLanguageModelConfiguration(
         return { provider }
     }
 
-    if (state.useLanguageModels)
+    if (Object.keys(state.languageChatModels).length)
         return { provider: MODEL_PROVIDER_CLIENT, model: "*" }
 
     const items: (vscode.QuickPickItem & {
@@ -93,28 +93,36 @@ async function generateLanguageModelConfiguration(
             apiType?: APIType
         }
     >(items, {
-        title: `Pick a Language Model for ${modelId}`,
+        title: `Configure a Language Model for ${modelId}`,
     })
 
-    if (res.provider === MODEL_PROVIDER_CLIENT) state.useLanguageModels = true
-
     return res
 }
 
-async function pickChatModel(model: string): Promise<vscode.LanguageModelChat> {
+async function pickChatModel(
+    state: ExtensionState,
+    model: string
+): Promise<vscode.LanguageModelChat> {
     const chatModels = await vscode.lm.selectChatModels()
-    const items: (vscode.QuickPickItem & {
-        chatModel?: vscode.LanguageModelChat
-    })[] = chatModels.map((chatModel) => ({
-        label: chatModel.name,
-        description: `${chatModel.vendor} ${chatModel.family}`,
-        detail: `${chatModel.version}, ${chatModel.maxInputTokens}t.`,
-        chatModel,
-    }))
-    const res = await vscode.window.showQuickPick(items, {
-        title: `Pick a Chat Model for ${model}`,
-    })
-    return res?.chatModel
+
+    const chatModelId = state.languageChatModels[model]
+    let chatModel = chatModelId && chatModels.find((m) => m.id === chatModelId)
+    if (!chatModel) {
+        const items: (vscode.QuickPickItem & {
+            chatModel?: vscode.LanguageModelChat
+        })[] = chatModels.map((chatModel) => ({
+            label: chatModel.name,
+            description: `${chatModel.vendor} ${chatModel.family}`,
+            detail: `${chatModel.version}, ${chatModel.maxInputTokens}t.`,
+            chatModel,
+        }))
+        const res = await vscode.window.showQuickPick(items, {
+            title: `Pick a Chat Model for ${model}`,
+        })
+        chatModel = res?.chatModel
+        if (chatModel) state.languageChatModels[model] = chatModel.id
+    }
+    return chatModel
 }
 
 export async function pickLanguageModel(
@@ -178,34 +186,37 @@ function messagesToChatMessages(messages: ChatCompletionMessageParam[]) {
     return res
 }
 
-export const runChatModel: LanguageModelChatRequest = async (
-    req: ChatStart,
-    onChunk
-) => {
-    const token = new vscode.CancellationTokenSource().token
-    const { model, messages, modelOptions } = req
-    const chatModel = await pickChatModel(model)
-    if (!chatModel) throw new Error("No chat model selected.")
-    const chatMessages = messagesToChatMessages(messages)
-    const request = await chatModel.sendRequest(
-        chatMessages,
-        {
-            justification: `Run GenAIScript`,
-            modelOptions,
-        },
-        token
-    )
+export function createChatModelRunner(
+    state: ExtensionState
+): LanguageModelChatRequest {
+    if (!isLanguageModelsAvailable()) return undefined
+
+    return async (req: ChatStart, onChunk) => {
+        const token = new vscode.CancellationTokenSource().token
+        const { model, messages, modelOptions } = req
+        const chatModel = await pickChatModel(state, model)
+        if (!chatModel) throw new Error("No chat model selected.")
+        const chatMessages = messagesToChatMessages(messages)
+        const request = await chatModel.sendRequest(
+            chatMessages,
+            {
+                justification: `Run GenAIScript`,
+                modelOptions,
+            },
+            token
+        )
 
-    let text = ""
-    for await (const fragment of request.text) {
-        text += fragment
+        let text = ""
+        for await (const fragment of request.text) {
+            text += fragment
+            onChunk({
+                chunk: fragment,
+                tokens: await chatModel.countTokens(text),
+                finishReason: undefined,
+            })
+        }
         onChunk({
-            chunk: fragment,
-            tokens: await chatModel.countTokens(text),
-            finishReason: undefined,
+            finishReason: "stop",
         })
     }
-    onChunk({
-        finishReason: "stop",
-    })
 }
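
Two things are easy to miss in this hunk: `createChatModelRunner` returns `undefined` when `isLanguageModelsAvailable()` is false, which is what lets servermanager.ts below assign `this.client.chatRequest` unconditionally, and the streaming body is unchanged apart from being re-indented into the returned closure. For reference, a self-contained sketch of the streaming contract that body relies on (the `streamChat` helper and single user message are illustrative; the real runner builds its messages with `messagesToChatMessages` and finishes by emitting `finishReason: "stop"`):

import * as vscode from "vscode"

// Minimal sketch: send one user message to a chat model, accumulate the
// streamed fragments, and report progress (chunk + token count) per fragment.
async function streamChat(
    chatModel: vscode.LanguageModelChat,
    prompt: string,
    onChunk: (fragment: string, tokensSoFar: number) => void
): Promise<string> {
    const token = new vscode.CancellationTokenSource().token
    const response = await chatModel.sendRequest(
        [vscode.LanguageModelChatMessage.User(prompt)],
        { justification: "Run GenAIScript" },
        token
    )
    let text = ""
    for await (const fragment of response.text) {
        text += fragment
        onChunk(fragment, await chatModel.countTokens(text))
    }
    return text
}
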
packages/vscode/src/servermanager.ts (2 additions, 2 deletions)
@@ -15,7 +15,7 @@ import { ServerManager, host } from "../../core/src/host"
 import { logError, logVerbose } from "../../core/src/util"
 import { WebSocketClient } from "../../core/src/server/client"
 import { CORE_VERSION } from "../../core/src/version"
-import { isLanguageModelsAvailable, runChatModel } from "./lmaccess"
+import { createChatModelRunner } from "./lmaccess"
 
 export class TerminalServerManager implements ServerManager {
     private _terminal: vscode.Terminal
@@ -44,7 +44,7 @@ export class TerminalServerManager implements ServerManager {
         )
 
         this.client = new WebSocketClient(`http://localhost:${SERVER_PORT}`)
-        if (isLanguageModelsAvailable()) this.client.chatRequest = runChatModel
+        this.client.chatRequest = createChatModelRunner(this.state)
         this.client.addEventListener(OPEN, () => {
             // client connected to a rogue server
             if (!this._terminal) {

packages/vscode/src/state.ts (19 additions, 7 deletions)
@@ -107,7 +107,8 @@ export class ExtensionState extends EventTarget {
         AIRequestSnapshot
     > = undefined
     readonly output: vscode.LogOutputChannel
-    useLanguageModels = false
+    // modelid -> vscode language mode id
+    languageChatModels: Record<string, string> = {}
 
     constructor(public readonly context: ExtensionContext) {
         super()
@@ -128,13 +129,24 @@
         >(AI_REQUESTS_CACHE)
 
         // clear errors when file edited (remove me?)
-        vscode.workspace.onDidChangeTextDocument(
-            (ev) => {
-                this._diagColl.set(ev.document.uri, [])
-            },
-            undefined,
-            subscriptions
+        subscriptions.push(
+            vscode.workspace.onDidChangeTextDocument(
+                (ev) => {
+                    this._diagColl.set(ev.document.uri, [])
+                },
+                undefined,
+                subscriptions
+            )
         )
+        if (
+            typeof vscode.lm !== "undefined" &&
+            typeof vscode.lm.onDidChangeChatModels === "function"
+        )
+            subscriptions.push(
+                vscode.lm.onDidChangeChatModels(
+                    () => (this.languageChatModels = {})
+                )
+            )
     }
 
     private async saveScripts() {
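
The cache added above is deliberately short-lived: whenever VS Code reports that the set of chat models changed, `languageChatModels` is reset so stale model ids are never reused, and the subscription is wrapped in `typeof` checks because `vscode.lm.onDidChangeChatModels` only exists in builds that ship the language model API. A compact sketch of that defensive registration (the `watchChatModels` helper is hypothetical):

import * as vscode from "vscode"

// Feature-detect the language model API before subscribing; older VS Code
// builds expose neither vscode.lm nor onDidChangeChatModels.
function watchChatModels(
    subscriptions: vscode.Disposable[],
    onModelsChanged: () => void
) {
    if (
        typeof vscode.lm !== "undefined" &&
        typeof vscode.lm.onDidChangeChatModels === "function"
    )
        subscriptions.push(vscode.lm.onDidChangeChatModels(onModelsChanged))
}

// Usage: drop remembered picks when the available models change.
// watchChatModels(context.subscriptions, () => (state.languageChatModels = {}))
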
packages/vscode/src/statusbar.ts (5 additions, 2 deletions)
@@ -13,7 +13,7 @@ export function activateStatusBar(state: ExtensionState) {
     )
     statusBarItem.command = "genaiscript.request.status"
     const updateStatusBar = async () => {
-        const { parsing, aiRequest } = state
+        const { parsing, aiRequest, languageChatModels } = state
         const { computing, progress, options } = aiRequest || {}
         const { template, fragment } = options || {}
         const { tokensSoFar } = progress || {}
@@ -30,7 +30,10 @@
                 fragment?.files?.[0],
                 template
                     ? `- tool: ${template.title} (${template.id})`
-                    : undefined
+                    : undefined,
+                ...Object.entries(languageChatModels).map(
+                    ([m, c]) => `- language chat model: ${m} -> ${c}`
+                )
             ),
             true
         )
