Skip to content

Commit

Permalink
more tweaking of format
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Jun 23, 2024
1 parent 2adc07a commit 9572d04
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 26 deletions.
17 changes: 9 additions & 8 deletions packages/cli/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import {
MarkdownTrace,
TRACE_CHUNK,
TraceChunkEvent,
TraceChunkResponseEvent,
PromptScriptEndResponseEvent,
UNHANDLED_ERROR_CODE,
isCancelError,
USER_CANCELLED_ERROR_CODE,
PromptScriptProgressResponseEvent,
PromptScriptEndResponseEvent,
} from "genaiscript-core"
import { runPromptScriptTests } from "./test"
import { PROMPTFOO_VERSION } from "./version"
Expand Down Expand Up @@ -134,22 +134,23 @@ export async function startServer(options: { port: string }) {
case "script.start": {
cancelAll()

const { script, files, options, id } = data
const runId = id
const { script, files, options, runId } = data
const canceller =
new AbortSignalCancellationController()
const trace = new MarkdownTrace()
trace.addEventListener(TRACE_CHUNK, (ev) => {
const tev = ev as TraceChunkEvent
ws?.send(
JSON.stringify(<TraceChunkResponseEvent>{
type: "trace.chunk",
chunk: tev.chunk,
JSON.stringify(<
PromptScriptProgressResponseEvent
>{
type: "script.progress",
runId,
chunk: tev.chunk,
})
)
})
console.log(`run ${id} starting`)
console.log(`run ${runId} starting`)
const runner = runScript(script, files, {
...options,
trace,
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export interface ChatCompletionsProgressReport {
}

export interface ChatCompletionsOptions {
partialCb?: (progres: ChatCompletionsProgressReport) => void
partialCb?: (progress: ChatCompletionsProgressReport) => void
requestOptions?: Partial<RequestInit>
maxCachedTemperature?: number
maxCachedTopP?: number
Expand Down
6 changes: 1 addition & 5 deletions packages/core/src/promptcontext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,7 @@ export interface GenerationOptions
ModelOptions,
ScriptRuntimeOptions {
cancellationToken?: CancellationToken
infoCb?: (partialResponse: {
text: string
label?: string
vars?: Partial<ExpansionVariables>
}) => void
infoCb?: (partialResponse: { text: string }) => void
trace: MarkdownTrace
maxCachedTemperature?: number
maxCachedTopP?: number
Expand Down
5 changes: 0 additions & 5 deletions packages/core/src/promptrunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,6 @@ export async function runTemplate(
schemas,
json,
}
options?.infoCb?.({
label: res.label,
vars: res.vars,
text: undefined,
})
return res
} finally {
await host.removeContainers()
Expand Down
59 changes: 55 additions & 4 deletions packages/core/src/server/client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { ChatCompletionsProgressReport } from "../chat"
import { CLIENT_RECONNECT_DELAY } from "../constants"
import { randomHex } from "../crypto"
import {
ModelService,
ParsePdfResponse,
Expand All @@ -10,7 +12,7 @@ import {
RetrievalUpsertOptions,
host,
} from "../host"
import { TraceOptions } from "../trace"
import { MarkdownTrace, TraceOptions } from "../trace"
import { assert, logError } from "../util"
import {
ParsePdfMessage,
Expand All @@ -32,6 +34,8 @@ import {
PromptScriptRunOptions,
PromptScriptStart,
PromptScriptAbort,
PromptScriptProgressResponseEvent,
ResponseEvents,
} from "./messages"

export class WebSocketClient
Expand All @@ -46,6 +50,15 @@ export class WebSocketClient
private _pendingMessages: string[] = []
private _reconnectTimeout: ReturnType<typeof setTimeout> | undefined

private runs: Record<
string,
{
trace: MarkdownTrace
infoCb: (partialResponse: { text: string }) => void
partialCb: (progress: ChatCompletionsProgressReport) => void
}
> = {}

constructor(readonly url: string) {}

private installPolyfill() {
Expand Down Expand Up @@ -97,12 +110,39 @@ export class WebSocketClient
this._ws.addEventListener("message", <
(event: MessageEvent<any>) => void
>(async (event) => {
const data: RequestMessages = JSON.parse(event.data)
const { id } = data
const data = JSON.parse(event.data)
// handle responses
const req: RequestMessages = data
const { id } = req
const awaiter = this.awaiters[id]
if (awaiter) {
delete this.awaiters[id]
await awaiter.resolve(data)
await awaiter.resolve(req)
}

// handle run progress
const ev: ResponseEvents = data
const { runId, type } = ev
const run = this.runs[runId]
if (run) {
switch (type) {
case "script.progress": {
if (ev.trace) run.trace.appendContent(ev.trace)
if (ev.progress) run.infoCb({ text: ev.progress })
if (ev.response || ev.tokens !== undefined)
run.partialCb({
responseChunk: ev.response,
responseSoFar: ev.response,
tokensSoFar: ev.tokens,
})
break
}
case "script.end": {
// todo: final result message?
delete this.runs[runId]
break
}
}
}
}))
}
Expand Down Expand Up @@ -203,16 +243,27 @@ export class WebSocketClient
files: string[],
options: PromptScriptRunOptions
): Promise<PromptScriptTestRunResponse> {
const runId = randomHex(6)
this.runs[runId] = {
trace: new MarkdownTrace(),
infoCb: options.infoCb,
partialCb: options.partialCb,
}
const res = await this.queue<PromptScriptStart>({
type: "script.start",
runId,
script,
files,
options,
})
if (!res.response?.ok) {
delete this.runs[runId] // failed to start
}
return res.response
}

async abortScript(runId: string, reason?: string): Promise<ResponseStatus> {
delete this.runs[runId]
const res = await this.queue<PromptScriptAbort>({
type: "script.abort",
runId,
Expand Down
20 changes: 17 additions & 3 deletions packages/core/src/server/messages.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { GenerationResult } from "../expander"
import {
ParsePdfResponse,
ResponseStatus,
Expand Down Expand Up @@ -106,6 +107,7 @@ export interface PromptScriptRunOptions {

export interface PromptScriptStart extends RequestMessage {
type: "script.start"
runId: string
script: string
files: string[]
options: PromptScriptRunOptions
Expand All @@ -119,6 +121,7 @@ export interface PromptScriptEndResponseEvent {
type: "script.end"
runId: string
exitCode: number
result: GenerationResult
}

export interface PromptScriptAbort extends RequestMessage {
Expand All @@ -127,10 +130,17 @@ export interface PromptScriptAbort extends RequestMessage {
runId: string
}

export interface TraceChunkResponseEvent {
type: "trace.chunk"
export interface PromptScriptProgressResponseEvent {
type: "script.progress"
runId: string
chunk: string

trace?: string

progress?: string

tokens?: number
response?: string
responseChunk?: string
}

export interface ShellExecResponse extends ResponseStatus {
Expand Down Expand Up @@ -177,3 +187,7 @@ export type RequestMessages =
| ContainerRemove
| PromptScriptStart
| PromptScriptAbort

export type ResponseEvents =
| PromptScriptProgressResponseEvent
| PromptScriptEndResponseEvent

0 comments on commit 9572d04

Please sign in to comment.