From 9572d047abeb3208598d96d2f3897a26c51f5a63 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Sat, 22 Jun 2024 18:41:25 -0700 Subject: [PATCH] more tweaking of format --- packages/cli/src/server.ts | 17 ++++---- packages/core/src/chat.ts | 2 +- packages/core/src/promptcontext.ts | 6 +-- packages/core/src/promptrunner.ts | 5 --- packages/core/src/server/client.ts | 59 ++++++++++++++++++++++++++-- packages/core/src/server/messages.ts | 20 ++++++++-- 6 files changed, 83 insertions(+), 26 deletions(-) diff --git a/packages/cli/src/server.ts b/packages/cli/src/server.ts index 8183538b18..2e2b5c2122 100644 --- a/packages/cli/src/server.ts +++ b/packages/cli/src/server.ts @@ -16,11 +16,11 @@ import { MarkdownTrace, TRACE_CHUNK, TraceChunkEvent, - TraceChunkResponseEvent, - PromptScriptEndResponseEvent, UNHANDLED_ERROR_CODE, isCancelError, USER_CANCELLED_ERROR_CODE, + PromptScriptProgressResponseEvent, + PromptScriptEndResponseEvent, } from "genaiscript-core" import { runPromptScriptTests } from "./test" import { PROMPTFOO_VERSION } from "./version" @@ -134,22 +134,23 @@ export async function startServer(options: { port: string }) { case "script.start": { cancelAll() - const { script, files, options, id } = data - const runId = id + const { script, files, options, runId } = data const canceller = new AbortSignalCancellationController() const trace = new MarkdownTrace() trace.addEventListener(TRACE_CHUNK, (ev) => { const tev = ev as TraceChunkEvent ws?.send( - JSON.stringify({ - type: "trace.chunk", - chunk: tev.chunk, + JSON.stringify(< + PromptScriptProgressResponseEvent + >{ + type: "script.progress", runId, + chunk: tev.chunk, }) ) }) - console.log(`run ${id} starting`) + console.log(`run ${runId} starting`) const runner = runScript(script, files, { ...options, trace, diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index a6b5a8bf30..c94ce7d509 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -107,7 +107,7 @@ export interface ChatCompletionsProgressReport { } export interface ChatCompletionsOptions { - partialCb?: (progres: ChatCompletionsProgressReport) => void + partialCb?: (progress: ChatCompletionsProgressReport) => void requestOptions?: Partial maxCachedTemperature?: number maxCachedTopP?: number diff --git a/packages/core/src/promptcontext.ts b/packages/core/src/promptcontext.ts index f1773fc901..4616bc64e8 100644 --- a/packages/core/src/promptcontext.ts +++ b/packages/core/src/promptcontext.ts @@ -358,11 +358,7 @@ export interface GenerationOptions ModelOptions, ScriptRuntimeOptions { cancellationToken?: CancellationToken - infoCb?: (partialResponse: { - text: string - label?: string - vars?: Partial - }) => void + infoCb?: (partialResponse: { text: string }) => void trace: MarkdownTrace maxCachedTemperature?: number maxCachedTopP?: number diff --git a/packages/core/src/promptrunner.ts b/packages/core/src/promptrunner.ts index ed651ed590..9ec2398acb 100644 --- a/packages/core/src/promptrunner.ts +++ b/packages/core/src/promptrunner.ts @@ -441,11 +441,6 @@ export async function runTemplate( schemas, json, } - options?.infoCb?.({ - label: res.label, - vars: res.vars, - text: undefined, - }) return res } finally { await host.removeContainers() diff --git a/packages/core/src/server/client.ts b/packages/core/src/server/client.ts index d2ab7c0276..17f924518e 100644 --- a/packages/core/src/server/client.ts +++ b/packages/core/src/server/client.ts @@ -1,4 +1,6 @@ +import { ChatCompletionsProgressReport } from "../chat" import { CLIENT_RECONNECT_DELAY } from "../constants" +import { randomHex } from "../crypto" import { ModelService, ParsePdfResponse, @@ -10,7 +12,7 @@ import { RetrievalUpsertOptions, host, } from "../host" -import { TraceOptions } from "../trace" +import { MarkdownTrace, TraceOptions } from "../trace" import { assert, logError } from "../util" import { ParsePdfMessage, @@ -32,6 +34,8 @@ import { PromptScriptRunOptions, PromptScriptStart, PromptScriptAbort, + PromptScriptProgressResponseEvent, + ResponseEvents, } from "./messages" export class WebSocketClient @@ -46,6 +50,15 @@ export class WebSocketClient private _pendingMessages: string[] = [] private _reconnectTimeout: ReturnType | undefined + private runs: Record< + string, + { + trace: MarkdownTrace + infoCb: (partialResponse: { text: string }) => void + partialCb: (progress: ChatCompletionsProgressReport) => void + } + > = {} + constructor(readonly url: string) {} private installPolyfill() { @@ -97,12 +110,39 @@ export class WebSocketClient this._ws.addEventListener("message", < (event: MessageEvent) => void >(async (event) => { - const data: RequestMessages = JSON.parse(event.data) - const { id } = data + const data = JSON.parse(event.data) + // handle responses + const req: RequestMessages = data + const { id } = req const awaiter = this.awaiters[id] if (awaiter) { delete this.awaiters[id] - await awaiter.resolve(data) + await awaiter.resolve(req) + } + + // handle run progress + const ev: ResponseEvents = data + const { runId, type } = ev + const run = this.runs[runId] + if (run) { + switch (type) { + case "script.progress": { + if (ev.trace) run.trace.appendContent(ev.trace) + if (ev.progress) run.infoCb({ text: ev.progress }) + if (ev.response || ev.tokens !== undefined) + run.partialCb({ + responseChunk: ev.response, + responseSoFar: ev.response, + tokensSoFar: ev.tokens, + }) + break + } + case "script.end": { + // todo: final result message? + delete this.runs[runId] + break + } + } } })) } @@ -203,16 +243,27 @@ export class WebSocketClient files: string[], options: PromptScriptRunOptions ): Promise { + const runId = randomHex(6) + this.runs[runId] = { + trace: new MarkdownTrace(), + infoCb: options.infoCb, + partialCb: options.partialCb, + } const res = await this.queue({ type: "script.start", + runId, script, files, options, }) + if (!res.response?.ok) { + delete this.runs[runId] // failed to start + } return res.response } async abortScript(runId: string, reason?: string): Promise { + delete this.runs[runId] const res = await this.queue({ type: "script.abort", runId, diff --git a/packages/core/src/server/messages.ts b/packages/core/src/server/messages.ts index d2a860d698..ff0538895a 100644 --- a/packages/core/src/server/messages.ts +++ b/packages/core/src/server/messages.ts @@ -1,3 +1,4 @@ +import { GenerationResult } from "../expander" import { ParsePdfResponse, ResponseStatus, @@ -106,6 +107,7 @@ export interface PromptScriptRunOptions { export interface PromptScriptStart extends RequestMessage { type: "script.start" + runId: string script: string files: string[] options: PromptScriptRunOptions @@ -119,6 +121,7 @@ export interface PromptScriptEndResponseEvent { type: "script.end" runId: string exitCode: number + result: GenerationResult } export interface PromptScriptAbort extends RequestMessage { @@ -127,10 +130,17 @@ export interface PromptScriptAbort extends RequestMessage { runId: string } -export interface TraceChunkResponseEvent { - type: "trace.chunk" +export interface PromptScriptProgressResponseEvent { + type: "script.progress" runId: string - chunk: string + + trace?: string + + progress?: string + + tokens?: number + response?: string + responseChunk?: string } export interface ShellExecResponse extends ResponseStatus { @@ -177,3 +187,7 @@ export type RequestMessages = | ContainerRemove | PromptScriptStart | PromptScriptAbort + +export type ResponseEvents = + | PromptScriptProgressResponseEvent + | PromptScriptEndResponseEvent