From 12030185a0592c652601752c864c1ecac55acf7f Mon Sep 17 00:00:00 2001 From: pelikhan Date: Fri, 2 Aug 2024 11:05:16 -0700 Subject: [PATCH 1/5] moving token resolution outside of vscode --- packages/cli/src/info.ts | 24 ++++++++++++++++++------ packages/cli/src/server.ts | 14 +++++++++----- packages/core/src/connection.ts | 8 ++++---- packages/core/src/github.ts | 8 ++++---- packages/core/src/host.ts | 18 +++++++++--------- packages/core/src/server/client.ts | 9 ++++----- packages/core/src/server/messages.ts | 9 ++++++++- packages/core/src/websearch.ts | 4 ++-- packages/vscode/src/vshost.ts | 19 ++----------------- 9 files changed, 60 insertions(+), 53 deletions(-) diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts index 2fad37c7c5..d17391f806 100644 --- a/packages/cli/src/info.ts +++ b/packages/cli/src/info.ts @@ -6,6 +6,7 @@ import { ModelConnectionInfo, resolveModelConnectionInfo, } from "../../core/src/models" +import { ServerEnvResponse } from "../../core/src/server/messages" import { CORE_VERSION } from "../../core/src/version" import { YAMLStringify } from "../../core/src/yaml" import { buildProject } from "./build" @@ -19,10 +20,20 @@ export async function systemInfo() { } export async function envInfo(provider: string, options?: { token?: boolean }) { + const res = await resolveEnv(provider, options) + console.log(YAMLStringify(res)) +} + +export async function resolveEnv( + provider: string, + options?: { token?: boolean } +): Promise { const { token } = options || {} - const res: any = {} - res[".env"] = host.dotEnvPath ?? "" - res.providers = [] + const res: ServerEnvResponse = { + ok: true, + env: host.dotEnvPath ?? 
"", + providers: [], + } const env = process.env for (const modelProvider of MODEL_PROVIDERS.filter( (mp) => !provider || mp.id === provider @@ -30,18 +41,19 @@ export async function envInfo(provider: string, options?: { token?: boolean }) { try { const conn = await parseTokenFromEnv(env, `${modelProvider.id}:*`) if (conn) { - if (!token && conn.token) - conn.token = "***" + if (!token && conn.token) conn.token = "***" res.providers.push(conn) } } catch (e) { res.providers.push({ provider: modelProvider.id, + model: undefined, + base: undefined, error: errorMessage(e), }) } } - console.log(YAMLStringify(res)) + return res } async function resolveScriptsConnectionInfo( diff --git a/packages/cli/src/server.ts b/packages/cli/src/server.ts index 0c92c3f170..d6e1d0a676 100644 --- a/packages/cli/src/server.ts +++ b/packages/cli/src/server.ts @@ -18,7 +18,7 @@ import { import { LanguageModelConfiguration, ResponseStatus, - ServerResponse, + ServerVersionResponse, host, runtimeHost, } from "../../core/src/host" @@ -33,8 +33,9 @@ import { ChatStart, ChatChunk, ChatCancel, + ServerEnvResponse, } from "../../core/src/server/messages" -import { envInfo } from "./info" +import { envInfo, resolveEnv } from "./info" import { LanguageModel } from "../../core/src/chat" import { ChatCompletionResponse, @@ -166,7 +167,7 @@ export async function startServer(options: { port: string }) { switch (type) { case "server.version": { console.log(`server: version ${CORE_VERSION}`) - response = { + response = { ok: true, version: CORE_VERSION, node: process.version, @@ -178,9 +179,12 @@ export async function startServer(options: { port: string }) { } case "server.env": { console.log(`server: env`) - envInfo(undefined) - response = { + const info = await resolveEnv(undefined, { + token: false, + }) + response = { ok: true, + ...info, } break } diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index 7f1f6ecfe7..878e7bc091 100644 --- a/packages/core/src/connection.ts 
+++ b/packages/core/src/connection.ts @@ -28,18 +28,18 @@ import { PLACEHOLDER_API_KEY, } from "./constants" import { fileExists, readText, tryReadText, writeText } from "./fs" -import { APIType, host, LanguageModelConfiguration } from "./host" +import { APIType, host, LanguageModelConfiguration, runtimeHost } from "./host" import { dedent } from "./indent" import { parseModelIdentifier } from "./models" import { normalizeFloat, trimTrailingSlash } from "./util" export async function parseDefaultsFromEnv(env: Record) { if (env.GENAISCRIPT_DEFAULT_MODEL) - host.defaultModelOptions.model = env.GENAISCRIPT_DEFAULT_MODEL + runtimeHost.defaultModelOptions.model = env.GENAISCRIPT_DEFAULT_MODEL const t = normalizeFloat(env.GENAISCRIPT_DEFAULT_TEMPERATURE) - if (!isNaN(t)) host.defaultModelOptions.temperature = t + if (!isNaN(t)) runtimeHost.defaultModelOptions.temperature = t if (env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL) - host.defaultEmbeddingsModelOptions.embeddingsModel = + runtimeHost.defaultEmbeddingsModelOptions.embeddingsModel = env.GENAISCRIPT_DEFAULT_EMBEDDINGS_MODEL } diff --git a/packages/core/src/github.ts b/packages/core/src/github.ts index 0056b1934e..99a46603a5 100644 --- a/packages/core/src/github.ts +++ b/packages/core/src/github.ts @@ -5,7 +5,7 @@ import { GITHUB_TOKEN, } from "./constants" import { createFetch } from "./fetch" -import { host } from "./host" +import { host, runtimeHost } from "./host" import { link, prettifyMarkdown } from "./markdown" import { logError, logVerbose, normalizeInt } from "./util" @@ -70,7 +70,7 @@ export async function githubUpdatePullRequestDescription( assert(commentTag) if (!issue) return { updated: false, statusText: "missing issue number" } - const token = await host.readSecret(GITHUB_TOKEN) + const token = await runtimeHost.readSecret(GITHUB_TOKEN) if (!token) return { updated: false, statusText: "missing github token" } text = prettifyMarkdown(text) @@ -169,7 +169,7 @@ export async function githubCreateIssueComment( 
const { apiUrl, repository, issue } = info if (!issue) return { created: false, statusText: "missing issue number" } - const token = await host.readSecret(GITHUB_TOKEN) + const token = await runtimeHost.readSecret(GITHUB_TOKEN) if (!token) return { created: false, statusText: "missing github token" } const fetch = await createFetch({ retryOn: [] }) @@ -313,7 +313,7 @@ export async function githubCreatePullRequestReviews( logError("missing commit sha") return false } - const token = await host.readSecret(GITHUB_TOKEN) + const token = await runtimeHost.readSecret(GITHUB_TOKEN) if (!token) { logError("missing github token") return false diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts index 2784b15834..e731931395 100644 --- a/packages/core/src/host.ts +++ b/packages/core/src/host.ts @@ -51,8 +51,7 @@ export interface ResponseStatus { status?: number } -export interface RetrievalSearchOptions extends VectorSearchOptions { -} +export interface RetrievalSearchOptions extends VectorSearchOptions {} export interface RetrievalSearchResponse extends ResponseStatus { results: WorkspaceFileWithScore[] @@ -70,7 +69,7 @@ export interface RetrievalService { ): Promise } -export interface ServerResponse extends ResponseStatus { +export interface ServerVersionResponse extends ResponseStatus { version: string node: string platform: string @@ -96,12 +95,6 @@ export interface Host { installFolder(): string resolvePath(...segments: string[]): string - // read a secret from the environment or a .env file - readSecret(name: string): Promise - defaultModelOptions: Required> - defaultEmbeddingsModelOptions: Required< - Pick - > getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions @@ -130,6 +123,13 @@ export interface RuntimeHost extends Host { models: ModelService workspace: Omit + // read a secret from the environment or a .env file + readSecret(name: string): Promise + defaultModelOptions: Required> + 
defaultEmbeddingsModelOptions: Required< + Pick + > + // executes a process exec( containerId: string, diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 04b7ef061e..2640b435c0 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -1,6 +1,4 @@ -import { - ChatCompletionsProgressReport, -} from "../chattypes" +import { ChatCompletionsProgressReport } from "../chattypes" import { CLIENT_RECONNECT_DELAY, OPEN, RECONNECT } from "../constants" import { randomHex } from "../crypto" import { errorMessage } from "../error" @@ -25,6 +23,7 @@ import { ChatEvents, ChatChunk, ChatStart, + ServerEnvResponse, } from "./messages" export type LanguageModelChatRequest = ( @@ -238,9 +237,9 @@ export class WebSocketClient extends EventTarget { return res.version } - async infoEnv(): Promise { + async infoEnv(): Promise { const res = await this.queue({ type: "server.env" }) - return res.response + return res.response as ServerEnvResponse } async startScript( diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index 0a4ffa1954..59e5ac514a 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -1,6 +1,6 @@ import { ChatCompletionAssistantMessageParam } from "../chattypes" import { GenerationResult } from "../generation" -import { ResponseStatus } from "../host" +import { LanguageModelConfiguration, ResponseStatus } from "../host" export interface RequestMessage { type: string @@ -21,6 +21,13 @@ export interface ServerEnv extends RequestMessage { type: "server.env" } +export interface ServerEnvResponse extends ResponseStatus { + env: string + providers: (Omit & { + error?: string + })[] +} + export interface PromptScriptTestRunOptions { testProvider?: string models?: string[] diff --git a/packages/core/src/websearch.ts b/packages/core/src/websearch.ts index c9db0f698c..22cccca376 100644 --- a/packages/core/src/websearch.ts +++ 
b/packages/core/src/websearch.ts @@ -1,6 +1,6 @@ import { BING_SEARCH_ENDPOINT } from "./constants" import { createFetch } from "./fetch" -import { host } from "./host" +import { runtimeHost } from "./host" import { MarkdownTrace } from "./trace" function toURLSearchParams(o: any) { @@ -47,7 +47,7 @@ export async function bingSearch( } = options || {} if (!q) return {} - const apiKey = await host.readSecret("BING_SEARCH_API_KEY") + const apiKey = await runtimeHost.readSecret("BING_SEARCH_API_KEY") if (!apiKey) throw new Error( "BING_SEARCH_API_KEY secret is required to use bing search. See https://microsoft.github.io/genaiscript/reference/scripts/web-search/#bing-web-search-configuration." diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts index fae9523e51..15da6e2a1e 100644 --- a/packages/vscode/src/vshost.ts +++ b/packages/vscode/src/vshost.ts @@ -8,7 +8,6 @@ import { Utils } from "vscode-uri" import { checkFileExists, readFileText } from "./fs" import { filterGitIgnore } from "../../core/src/gitignore" import { - parseDefaultsFromEnv, parseTokenFromEnv, } from "../../core/src/connection" import { @@ -51,9 +50,6 @@ export class VSCodeHost extends EventTarget implements Host { } async activate() { - const dotenv = await readFileText(this.projectUri, DOT_ENV_FILENAME) - const env = dotEnvTryParse(dotenv) ?? 
{} - await parseDefaultsFromEnv(env) } get azure() { @@ -186,25 +182,14 @@ export class VSCodeHost extends EventTarget implements Host { await vscode.workspace.fs.delete(uri, { recursive: true }) } - async readSecret(name: string): Promise { - try { - const dotenv = await readFileText(this.projectUri, DOT_ENV_FILENAME) - const env = dotEnvTryParse(dotenv) - return env?.[name] - } catch (e) { - return undefined - } - } - clientLanguageModel?: LanguageModel async getLanguageModelConfiguration( modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions ): Promise { const { signal, token: askToken } = options || {} - const dotenv = await readFileText(this.projectUri, DOT_ENV_FILENAME) - const env = dotEnvTryParse(dotenv) ?? {} - await parseDefaultsFromEnv(env) + const res = await this.server.client.infoEnv() + // TODO const tok = await parseTokenFromEnv(env, modelId) if ( askToken && From 2347fa47ff63ffed2e18dd1f54154507bc87bcc6 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Fri, 2 Aug 2024 11:37:26 -0700 Subject: [PATCH 2/5] call server env to get configuration --- packages/core/src/server/messages.ts | 2 +- packages/vscode/src/vshost.ts | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index 59e5ac514a..c59b723815 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -23,7 +23,7 @@ export interface ServerEnv extends RequestMessage { export interface ServerEnvResponse extends ResponseStatus { env: string - providers: (Omit & { + providers: (LanguageModelConfiguration & { error?: string })[] } diff --git a/packages/vscode/src/vshost.ts b/packages/vscode/src/vshost.ts index 15da6e2a1e..9ece1b8a35 100644 --- a/packages/vscode/src/vshost.ts +++ b/packages/vscode/src/vshost.ts @@ -7,17 +7,12 @@ import { ExtensionState } from "./state" import { Utils } from "vscode-uri" import { 
checkFileExists, readFileText } from "./fs" import { filterGitIgnore } from "../../core/src/gitignore" -import { - parseTokenFromEnv, -} from "../../core/src/connection" import { DEFAULT_EMBEDDINGS_MODEL, DEFAULT_MODEL, DEFAULT_TEMPERATURE, DOT_ENV_FILENAME, - MODEL_PROVIDER_AZURE, } from "../../core/src/constants" -import { dotEnvTryParse } from "../../core/src/dotenv" import { setHost, LanguageModelConfiguration, @@ -27,6 +22,7 @@ import { import { TraceOptions, AbortSignalOptions } from "../../core/src/trace" import { arrayify, unique } from "../../core/src/util" import { LanguageModel } from "../../core/src/chat" +import { parseModelIdentifier } from "../../core/src/models" export class VSCodeHost extends EventTarget implements Host { dotEnvPath: string = DOT_ENV_FILENAME @@ -49,8 +45,7 @@ export class VSCodeHost extends EventTarget implements Host { this.state.context.subscriptions.push(this) } - async activate() { - } + async activate() {} get azure() { if (!this._azure) this._azure = new AzureManager(this.state) @@ -167,7 +162,10 @@ export class VSCodeHost extends EventTarget implements Host { } let files = Array.from(uris.values()) - if (applyGitIgnore && (await checkFileExists(this.projectUri, ".gitignore"))) { + if ( + applyGitIgnore && + (await checkFileExists(this.projectUri, ".gitignore")) + ) { const gitignore = await readFileText(this.projectUri, ".gitignore") files = await filterGitIgnore(gitignore, files) } @@ -187,9 +185,11 @@ export class VSCodeHost extends EventTarget implements Host { modelId: string, options?: { token?: boolean } & AbortSignalOptions & TraceOptions ): Promise { - const { signal, token: askToken } = options || {} - const res = await this.server.client.infoEnv() - // TODO + const { provider } = parseModelIdentifier(modelId) + const { error, ...tok } = ( + await this.server.client.infoEnv() + ).providers.find((m) => m.provider === provider) + /* const tok = await parseTokenFromEnv(env, modelId) if ( askToken && @@ -204,6 +204,7 
@@ export class VSCodeHost extends EventTarget implements Host { Authorization: "Bearer ***", } } + */ return tok } From 51c1ebca02d7a319ce060cfd5205ca730a6060a4 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Fri, 2 Aug 2024 13:42:11 -0700 Subject: [PATCH 3/5] more extension of protocol --- packages/core/src/server/client.ts | 15 +++++++++++++-- packages/core/src/server/messages.ts | 7 +++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 2640b435c0..595547395e 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -24,6 +24,8 @@ import { ChatChunk, ChatStart, ServerEnvResponse, + ClientRequeMessages, + ClientRequestMessages, } from "./messages" export type LanguageModelChatRequest = ( @@ -164,10 +166,19 @@ export class WebSocketClient extends EventTarget { } } } else { - const cev: ChatEvents = data - const { chatId, type } = cev + const cev: ClientRequestMessages = data + const { type } = cev switch (type) { + case "authentication.session": { + const resp = await this.authenticationSession(cev.model) + this.queue({ + ...resp, + type: "authentication.session", + }) + break + } case "chat.start": { + const { chatId } = cev if (!this.chatRequest) throw new Error( "client language model not supported" diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index c59b723815..afedf8a426 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -159,6 +159,11 @@ export interface ChatChunk extends RequestMessage { error?: SerializedError } +export interface AuthenticationSessionRequest { + type: "authentication.session" + model: string +} + export type RequestMessages = | ServerKill | ServerVersion @@ -175,3 +180,5 @@ export type PromptScriptResponseEvents = | PromptScriptEndResponseEvent export type ChatEvents = ChatStart | ChatCancel + +export type ClientRequestMessages 
= ChatEvents | AuthenticationSessionRequest From 19666de1a8674fddbe6a9b9aa52f3e9386634e68 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Mon, 5 Aug 2024 07:48:56 -0700 Subject: [PATCH 4/5] Refactor token resolution in vscode extension --- packages/cli/src/info.ts | 4 +- packages/cli/src/server.ts | 2 +- packages/core/src/chat.ts | 13 +++--- packages/core/src/connection.ts | 4 +- packages/core/src/constants.ts | 1 + packages/core/src/expander.ts | 4 +- packages/core/src/models.ts | 4 +- packages/core/src/promptcontext.ts | 6 ++- packages/core/src/server/client.ts | 62 +++++++------------------- packages/core/src/server/messages.ts | 6 +-- packages/core/src/server/rpc.ts | 66 ++++++++++++++++++++++++++++ packages/core/src/test.ts | 6 +-- 12 files changed, 109 insertions(+), 69 deletions(-) create mode 100644 packages/core/src/server/rpc.ts diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts index d17391f806..c5d98833c5 100644 --- a/packages/cli/src/info.ts +++ b/packages/cli/src/info.ts @@ -1,7 +1,7 @@ import { parseTokenFromEnv } from "../../core/src/connection" import { MODEL_PROVIDERS } from "../../core/src/constants" import { errorMessage } from "../../core/src/error" -import { host } from "../../core/src/host" +import { host, runtimeHost } from "../../core/src/host" import { ModelConnectionInfo, resolveModelConnectionInfo, @@ -63,7 +63,7 @@ async function resolveScriptsConnectionInfo( const models: Record = {} for (const template of templates) { const conn: ModelConnectionOptions = { - model: template.model ?? host.defaultModelOptions.model, + model: template.model ?? 
runtimeHost.defaultModelOptions.model, } const key = JSON.stringify(conn) if (!models[key]) models[key] = conn diff --git a/packages/cli/src/server.ts b/packages/cli/src/server.ts index d6e1d0a676..0c67877cf5 100644 --- a/packages/cli/src/server.ts +++ b/packages/cli/src/server.ts @@ -35,7 +35,7 @@ import { ChatCancel, ServerEnvResponse, } from "../../core/src/server/messages" -import { envInfo, resolveEnv } from "./info" +import { resolveEnv } from "./info" import { LanguageModel } from "../../core/src/chat" import { ChatCompletionResponse, diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index ddc163775d..8db44aed2c 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -1,6 +1,6 @@ import { MarkdownTrace } from "./trace" import { PromptImage, renderPromptNode } from "./promptdom" -import { LanguageModelConfiguration, host } from "./host" +import { LanguageModelConfiguration, host, runtimeHost } from "./host" import { GenerationOptions } from "./generation" import { JSON5TryParse, JSON5parse, isJSONObjectOrArray } from "./json5" import { CancellationToken, checkCancelled } from "./cancellation" @@ -424,13 +424,14 @@ export function mergeGenerationOptions( model: runOptions?.model ?? options?.model ?? - host.defaultModelOptions.model, + runtimeHost.defaultModelOptions.model, temperature: - runOptions?.temperature ?? host.defaultModelOptions.temperature, + runOptions?.temperature ?? + runtimeHost.defaultModelOptions.temperature, embeddingsModel: runOptions?.embeddingsModel ?? options?.embeddingsModel ?? 
- host.defaultEmbeddingsModelOptions.embeddingsModel, + runtimeHost.defaultEmbeddingsModelOptions.embeddingsModel, } } @@ -447,8 +448,8 @@ export async function executeChatSession( ): Promise { const { trace, - model = host.defaultModelOptions.model, - temperature = host.defaultModelOptions.temperature, + model = runtimeHost.defaultModelOptions.model, + temperature = runtimeHost.defaultModelOptions.temperature, topP, maxTokens, seed, diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index 878e7bc091..b7d79720d7 100644 --- a/packages/core/src/connection.ts +++ b/packages/core/src/connection.ts @@ -28,7 +28,7 @@ import { PLACEHOLDER_API_KEY, } from "./constants" import { fileExists, readText, tryReadText, writeText } from "./fs" -import { APIType, host, LanguageModelConfiguration, runtimeHost } from "./host" +import { APIType, LanguageModelConfiguration, runtimeHost } from "./host" import { dedent } from "./indent" import { parseModelIdentifier } from "./models" import { normalizeFloat, trimTrailingSlash } from "./util" @@ -48,7 +48,7 @@ export async function parseTokenFromEnv( modelId: string ): Promise { const { provider, model, tag } = parseModelIdentifier( - modelId ?? host.defaultModelOptions.model + modelId ?? 
runtimeHost.defaultModelOptions.model ) if (provider === MODEL_PROVIDER_OPENAI) { if (env.OPENAI_API_KEY || env.OPENAI_API_BASE || env.OPENAI_API_TYPE) { diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index 6b40314780..bdc678bfd2 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -2,6 +2,7 @@ export const CHANGE = "change" export const TRACE_CHUNK = "traceChunk" export const RECONNECT = "reconnect" export const OPEN = "open" +export const MESSAGE = "message" export const MAX_CACHED_TEMPERATURE = 0.5 export const MAX_CACHED_TOP_P = 0.5 export const MAX_TOOL_CALLS = 100 diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts index 4f9cbe96f6..832d07913d 100644 --- a/packages/core/src/expander.ts +++ b/packages/core/src/expander.ts @@ -16,7 +16,7 @@ import { toChatCompletionUserMessage } from "./chat" import { importPrompt } from "./importprompt" import { parseModelIdentifier } from "./models" import { JSONSchemaStringifyToTypeScript } from "./schema" -import { host } from "./host" +import { host, runtimeHost } from "./host" import { resolveSystems } from "./systems" import { GenerationOptions, GenerationStatus } from "./generation" import { @@ -175,7 +175,7 @@ export async function expandTemplate( options.temperature ?? normalizeFloat(env.vars["temperature"]) ?? template.temperature ?? - host.defaultModelOptions.temperature + runtimeHost.defaultModelOptions.temperature const topP = options.topP ?? normalizeFloat(env.vars["top_p"]) ?? 
template.topP const max_tokens = diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index 39dcfb96af..6c9611b8ab 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -1,6 +1,6 @@ import { MODEL_PROVIDER_LLAMAFILE, MODEL_PROVIDER_OPENAI } from "./constants" import { errorMessage } from "./error" -import { LanguageModelConfiguration, host } from "./host" +import { LanguageModelConfiguration, host, runtimeHost } from "./host" import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace" import { assert } from "./util" @@ -86,7 +86,7 @@ export async function resolveModelConnectionInfo( }> { const { trace, token: askToken, signal } = options || {} const hasModel = options?.model ?? conn.model - const model = options?.model ?? conn.model ?? host.defaultModelOptions.model + const model = options?.model ?? conn.model ?? runtimeHost.defaultModelOptions.model try { const configuration = await host.getLanguageModelConfiguration(model, { token: askToken, diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts index fffde02656..90f9de1015 100644 --- a/packages/core/src/promptcontext.ts +++ b/packages/core/src/promptcontext.ts @@ -175,7 +175,7 @@ export async function createPromptContext( searchOptions.embeddingsModel = searchOptions?.embeddingsModel ?? options?.embeddingsModel ?? 
- host.defaultEmbeddingsModelOptions.embeddingsModel + runtimeHost.defaultEmbeddingsModelOptions.embeddingsModel const key = await sha256string( JSON.stringify({ files, searchOptions }) ) @@ -294,7 +294,9 @@ export async function createPromptContext( ) if (!connection.configuration) throw new Error("model connection error " + connection.info) - const { completer } = await resolveLanguageModel(connection.configuration.provider) + const { completer } = await resolveLanguageModel( + connection.configuration.provider + ) if (!completer) throw new Error( "model driver not found for " + connection.info diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 595547395e..719c2a1686 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -20,13 +20,12 @@ import { PromptScriptAbort, PromptScriptResponseEvents, ServerEnv, - ChatEvents, ChatChunk, ChatStart, ServerEnvResponse, - ClientRequeMessages, ClientRequestMessages, } from "./messages" +import { MessageQueue } from "./rpc" export type LanguageModelChatRequest = ( request: ChatStart, @@ -34,13 +33,8 @@ export type LanguageModelChatRequest = ( ) => Promise export class WebSocketClient extends EventTarget { - private awaiters: Record< - string, - { resolve: (data: any) => void; reject: (error: unknown) => void } - > = {} - private _nextId = 1 + private messages: MessageQueue private _ws: WebSocket - private _pendingMessages: string[] = [] private _reconnectTimeout: ReturnType | undefined connectedOnce = false reconnectAttempts = 0 @@ -65,6 +59,10 @@ export class WebSocketClient extends EventTarget { constructor(readonly url: string) { super() + this.messages = new MessageQueue({ + readyState: () => this._ws?.readyState, + send: (msg) => this._ws?.send(msg), + }) } private installPolyfill() { @@ -104,12 +102,7 @@ export class WebSocketClient extends EventTarget { this.connectedOnce = true this.reconnectAttempts = 0 // flush cached messages - let m: string 
- while ( - this._ws?.readyState === WebSocket.OPEN && - (m = this._pendingMessages.pop()) - ) - this._ws.send(m) + this.messages.flush() this.dispatchEvent(new Event(OPEN)) }) this._ws.addEventListener("error", (ev) => { @@ -127,15 +120,7 @@ export class WebSocketClient extends EventTarget { (event: MessageEvent) => void >(async (event) => { const data = JSON.parse(event.data) - // handle responses - const req: RequestMessages = data - const { id } = req - const awaiter = this.awaiters[id] - if (awaiter) { - delete this.awaiters[id] - await awaiter.resolve(req) - return - } + if (this.messages.receive(data)) return // handle run progress const ev: PromptScriptResponseEvents = data @@ -174,6 +159,7 @@ export class WebSocketClient extends EventTarget { this.queue({ ...resp, type: "authentication.session", + id: cev.id, }) break } @@ -197,24 +183,11 @@ export class WebSocketClient extends EventTarget { })) } - private queue(msg: Omit): Promise { - const id = this._nextId++ + "" - const mo: any = { ...msg, id } - // avoid pollution - delete mo.trace - if (mo.options) delete mo.options.trace - const m = JSON.stringify({ ...msg, id }) - - this.init() - return new Promise((resolve, reject) => { - this.awaiters[id] = { - resolve: (data) => resolve(data), - reject, - } - if (this._ws?.readyState === WebSocket.OPEN) { - this._ws.send(m) - } else this._pendingMessages.push(m) - }) + private async queue( + msg: Omit + ): Promise { + await this.init() + return this.messages.queue(msg) } stop() { @@ -237,10 +210,7 @@ export class WebSocketClient extends EventTarget { cancel(reason?: string) { this.reconnectAttempts = 0 - this._pendingMessages = [] - const cancellers = Object.values(this.awaiters) - this.awaiters = {} - cancellers.forEach((a) => a.reject(reason || "cancelled")) + this.messages.cancel(reason) } async version(): Promise { @@ -353,7 +323,7 @@ export class WebSocketClient extends EventTarget { this._ws?.readyState === WebSocket.OPEN ) this._ws.send( - 
JSON.stringify({ type: "server.kill", id: this._nextId++ + "" }) + JSON.stringify({ type: "server.kill", id: randomHex(6) }) ) this.stop() } diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index afedf8a426..5b44872e97 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -134,7 +134,7 @@ export interface ShellExec extends RequestMessage { response?: ShellExecResponse } -export interface ChatStart { +export interface ChatStart extends RequestMessage { type: "chat.start" chatId: string messages: ChatCompletionAssistantMessageParam[] @@ -144,7 +144,7 @@ export interface ChatStart { } } -export interface ChatCancel { +export interface ChatCancel extends RequestMessage { type: "chat.cancel" chatId: string } @@ -159,7 +159,7 @@ export interface ChatChunk extends RequestMessage { error?: SerializedError } -export interface AuthenticationSessionRequest { +export interface AuthenticationSessionRequest extends RequestMessage { type: "authentication.session" model: string } diff --git a/packages/core/src/server/rpc.ts b/packages/core/src/server/rpc.ts new file mode 100644 index 0000000000..d6f81c009d --- /dev/null +++ b/packages/core/src/server/rpc.ts @@ -0,0 +1,66 @@ +export class MessageQueue extends EventTarget { + private awaiters: Record< + string, + { resolve: (data: any) => void; reject: (error: unknown) => void } + > = {} + private _nextId = 1 + private _pendingMessages: string[] = [] + + constructor( + readonly options: { + readonly readyState: () => number + readonly send: (msg: string) => void + } + ) { + super() + } + + flush() { + let m: string + while ( + this.options.readyState() === WebSocket.OPEN && + (m = this._pendingMessages.pop()) + ) + this.options.send(m) + } + + async receive(data: any) { + const req: { id: string } = data + const { id } = req + const awaiter = this.awaiters[id] + if (awaiter) { + delete this.awaiters[id] + await awaiter.resolve(req) + return true + } 
else { + this.dispatchEvent(new CustomEvent("message", { detail: data })) + return false + } + } + + queue(msg: Omit): Promise { + const id = this._nextId++ + "" + const mo: any = { ...msg, id } + // avoid pollution + delete mo.trace + if (mo.options) delete mo.options.trace + const m = JSON.stringify({ ...msg, id }) + + return new Promise((resolve, reject) => { + this.awaiters[id] = { + resolve: (data) => resolve(data), + reject, + } + if (this.options.readyState() === WebSocket.OPEN) { + this.options.send(m) + } else this._pendingMessages.push(m) + }) + } + + cancel(reason?: string) { + this._pendingMessages = [] + const cancellers = Object.values(this.awaiters) + this.awaiters = {} + cancellers.forEach((a) => a.reject(reason || "cancelled")) + } +} diff --git a/packages/core/src/test.ts b/packages/core/src/test.ts index 39cdef123e..c36ae21132 100644 --- a/packages/core/src/test.ts +++ b/packages/core/src/test.ts @@ -1,6 +1,6 @@ import { HTTPS_REGEX } from "./constants" import { arrayify } from "./util" -import { host } from "./host" +import { runtimeHost } from "./host" function cleanUndefined(obj: Record) { return obj @@ -37,10 +37,10 @@ export function generatePromptFooConfiguration( prompts: [id], providers: models .map(({ model, temperature, topP }) => ({ - model: model ?? host.defaultModelOptions.model, + model: model ?? runtimeHost.defaultModelOptions.model, temperature: !isNaN(temperature) ? 
temperature - : host.defaultModelOptions.temperature, + : runtimeHost.defaultModelOptions.temperature, top_p: topP, })) .map(({ model, temperature, top_p }) => ({ From a6666cad81d2ec3d227c0b98f1af5ce258caf47d Mon Sep 17 00:00:00 2001 From: pelikhan Date: Mon, 5 Aug 2024 07:53:38 -0700 Subject: [PATCH 5/5] Refactor token resolution in vscode extension --- packages/core/src/server/client.ts | 17 ++++++++++++++--- packages/core/src/server/rpc.ts | 6 +++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index 719c2a1686..50a04624aa 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -32,6 +32,10 @@ export type LanguageModelChatRequest = ( onChunk: (param: Omit) => void ) => Promise +export type AuthenticationSessionRequest = (model: string) => Promise<{ + token: string +}> + export class WebSocketClient extends EventTarget { private messages: MessageQueue private _ws: WebSocket @@ -40,6 +44,7 @@ export class WebSocketClient extends EventTarget { reconnectAttempts = 0 chatRequest: LanguageModelChatRequest + authenticationSessionRequest: AuthenticationSessionRequest private runs: Record< string, @@ -155,11 +160,17 @@ export class WebSocketClient extends EventTarget { const { type } = cev switch (type) { case "authentication.session": { - const resp = await this.authenticationSession(cev.model) + if (!this.chatRequest) + throw new Error( + "authentication session not supported" + ) + const resp = await this.authenticationSessionRequest( + cev.model + ) this.queue({ + ...cev, ...resp, - type: "authentication.session", - id: cev.id, + type: "authentication.session", }) break } diff --git a/packages/core/src/server/rpc.ts b/packages/core/src/server/rpc.ts index d6f81c009d..43d56a7574 100644 --- a/packages/core/src/server/rpc.ts +++ b/packages/core/src/server/rpc.ts @@ -38,13 +38,13 @@ export class MessageQueue extends EventTarget { } } - queue(msg: 
Omit): Promise { - const id = this._nextId++ + "" + queue(msg: T): Promise { + const id = msg.id ?? this._nextId++ + "" const mo: any = { ...msg, id } // avoid pollution delete mo.trace if (mo.options) delete mo.options.trace - const m = JSON.stringify({ ...msg, id }) + const m = JSON.stringify(mo) return new Promise((resolve, reject) => { this.awaiters[id] = {