
support for vscode language models (#595)
* add client provider

* don't check for languageModels proposal

* plumbings for chunks

* finished server side

* clean phase

* added cancellation

* renaming

* client messages

* more plumbing

* updated display

* more routing

* bail out the default client configuration

* disable lm access

* upgrade deps

* working proto

* don't ask twice
pelikhan authored Jul 30, 2024
1 parent 9244608 commit 890c8ce
Showing 26 changed files with 364 additions and 234 deletions.
2 changes: 1 addition & 1 deletion docs/package.json
@@ -19,7 +19,7 @@
     },
     "dependencies": {
         "@astrojs/check": "^0.8.3",
-        "@astrojs/starlight": "^0.25.2",
+        "@astrojs/starlight": "^0.25.3",
         "astro": "^4.12.2",
         "typescript": "5.5.4"
     },
2 changes: 1 addition & 1 deletion docs/src/content/docs/reference/cli/commands.md
@@ -87,7 +87,7 @@ Options:
   -td, --test-delay <string>          delay between tests in seconds
   --no-cache                          disable LLM result cache
   -v, --verbose                       verbose output
-  -pv, --promptfoo-version [version]  promptfoo version, default is ^0.73.6
+  -pv, --promptfoo-version [version]  promptfoo version, default is ^0.73.8
   -os, --out-summary <file>           append output summary in file
   -h, --help                          display help for command
 ```
4 changes: 2 additions & 2 deletions packages/cli/package.json
@@ -32,7 +32,7 @@
         "mammoth": "^1.8.0",
         "mathjs": "^13.0.3",
         "pdfjs-dist": "4.4.168",
-        "promptfoo": "^0.73.6",
+        "promptfoo": "^0.73.8",
         "tree-sitter-wasms": "^0.1.11",
         "tsx": "^4.16.2",
         "typescript": "5.5.4",
@@ -63,7 +63,7 @@
         "glob": "^11.0.0",
         "memorystream": "^0.3.1",
         "node-sarif-builder": "^3.1.0",
-        "openai": "^4.53.1",
+        "openai": "^4.53.2",
         "ora": "^8.0.1",
         "pretty-bytes": "^6.1.1",
         "prompts": "^2.4.2",
12 changes: 11 additions & 1 deletion packages/cli/src/nodehost.ts
@@ -20,7 +20,6 @@ import {
     DEFAULT_MODEL,
     DEFAULT_TEMPERATURE,
     MODEL_PROVIDER_AZURE,
-    AZURE_OPENAI_TOKEN_SCOPES,
     SHELL_EXEC_TIMEOUT,
     DOT_ENV_FILENAME,
     MODEL_PROVIDER_OLLAMA,
@@ -43,6 +42,7 @@ import { AbortSignalOptions, TraceOptions } from "../../core/src/trace"
 import { logVerbose, unique } from "../../core/src/util"
 import { parseModelIdentifier } from "../../core/src/models"
 import { createAzureToken } from "./azuretoken"
+import { LanguageModel } from "../../core/src/chat"
 
 class NodeServerManager implements ServerManager {
     async start(): Promise<void> {
@@ -135,6 +135,7 @@ export class NodeHost implements RuntimeHost {
     private async parseDefaults() {
         await parseDefaultsFromEnv(process.env)
     }
+    clientLanguageModel: LanguageModel
 
     private _azureToken: string
     async getLanguageModelConfiguration(
@@ -155,6 +156,15 @@ export class NodeHost implements RuntimeHost {
             if (!this._azureToken) throw new Error("Azure token not available")
             tok.token = "Bearer " + this._azureToken
         }
+        if (!tok && this.clientLanguageModel) {
+            logVerbose(`model: using client language model`)
+            return <LanguageModelConfiguration>{
+                model: modelId,
+                provider: this.clientLanguageModel.id,
+                source: "client",
+            }
+        }
+
         return tok
     }
9 changes: 2 additions & 7 deletions packages/cli/src/run.ts
@@ -76,11 +76,7 @@ export async function runScript(
         partialCb?: (progress: ChatCompletionsProgressReport) => void
     }
 ): Promise<{ exitCode: number; result?: GenerationResult }> {
-    const {
-        trace = new MarkdownTrace(),
-        infoCb,
-        partialCb,
-    } = options || {}
+    const { trace = new MarkdownTrace(), infoCb, partialCb } = options || {}
     let result: GenerationResult
     const excludedFiles = options.excludedFiles
     const excludeGitIgnore = !!options.excludeGitIgnore
@@ -173,8 +169,7 @@ export async function runScript(
     if (options.label) trace.heading(2, options.label)
     const { info } = await resolveModelConnectionInfo(script, {
         trace,
-        model:
-            options.model ?? script.model ?? host.defaultModelOptions.model,
+        model: options.model,
     })
     if (info.error) {
         trace.error(undefined, info.error)
94 changes: 91 additions & 3 deletions packages/cli/src/server.ts
@@ -8,30 +8,40 @@ import {
     TRACE_CHUNK,
     USER_CANCELLED_ERROR_CODE,
     UNHANDLED_ERROR_CODE,
     DOCKER_DEFAULT_IMAGE,
+    MODEL_PROVIDER_CLIENT,
 } from "../../core/src/constants"
 import {
     isCancelError,
     errorMessage,
     serializeError,
 } from "../../core/src/error"
 import {
+    LanguageModelConfiguration,
     ResponseStatus,
     ServerResponse,
     host,
+    runtimeHost,
 } from "../../core/src/host"
 import { MarkdownTrace, TraceChunkEvent } from "../../core/src/trace"
 import { logVerbose, logError, assert } from "../../core/src/util"
 import { CORE_VERSION } from "../../core/src/version"
 import { YAMLStringify } from "../../core/src/yaml"
 import {
     RequestMessages,
     PromptScriptProgressResponseEvent,
     PromptScriptEndResponseEvent,
     ShellExecResponse,
+    ChatStart,
+    ChatChunk,
+    ChatCancel,
 } from "../../core/src/server/messages"
 import { envInfo } from "./info"
 import { estimateTokens } from "../../core/src/tokens"
+import { LanguageModel } from "../../core/src/chat"
+import {
+    ChatCompletionResponse,
+    ChatCompletionsOptions,
+    CreateChatCompletionRequest,
+} from "../../core/src/chattypes"
+import { randomHex } from "../../core/src/crypto"
 
 export async function startServer(options: { port: string }) {
     const port = parseInt(options.port) || SERVER_PORT
@@ -45,15 +55,88 @@ export async function startServer(options: { port: string }) {
             runner: Promise<void>
         }
     > = {}
+    const chats: Record<string, (chunk: ChatChunk) => Promise<void>> = {}
 
     const cancelAll = () => {
         for (const [runId, run] of Object.entries(runs)) {
             console.log(`abort run ${runId}`)
             run.canceller.abort("closing")
             delete runs[runId]
         }
+        for (const [chatId, chat] of Object.entries(chats)) {
+            console.log(`abort chat ${chat}`)
+            for (const ws of wss.clients) {
+                ws.send(
+                    JSON.stringify(<ChatCancel>{
+                        type: "chat.cancel",
+                        chatId,
+                    })
+                )
+                break
+            }
+
+            delete chats[chatId]
+        }
     }
 
+    const handleChunk = async (chunk: ChatChunk) => {
+        const handler = chats[chunk.chatId]
+        if (handler) {
+            if (chunk.finishReason) delete chats[chunk.chatId]
+            await handler(chunk)
+        }
+    }
+
+    host.clientLanguageModel = Object.freeze<LanguageModel>({
+        id: MODEL_PROVIDER_CLIENT,
+        completer: async (
+            req: CreateChatCompletionRequest,
+            connection: LanguageModelConfiguration,
+            options: ChatCompletionsOptions,
+            trace: MarkdownTrace
+        ): Promise<ChatCompletionResponse> => {
+            const { messages, model } = req
+            const { partialCb } = options
+            if (!wss.clients.size) throw new Error("no llm clients connected")
+
+            return new Promise<ChatCompletionResponse>((resolve, reject) => {
+                let responseSoFar: string = ""
+                let tokensSoFar: number = 0
+                let finishReason: ChatCompletionResponse["finishReason"]
+
+                // add handler
+                const chatId = randomHex(6)
+                chats[chatId] = async (chunk) => {
+                    responseSoFar += chunk.chunk ?? ""
+                    tokensSoFar += chunk.tokens ?? 0
+                    partialCb?.({
+                        tokensSoFar,
+                        responseSoFar,
+                        responseChunk: chunk.chunk,
+                    })
+                    finishReason = chunk.finishReason as any
+                    if (finishReason) {
+                        delete chats[chatId]
+                        resolve({ text: responseSoFar, finishReason })
+                    }
+                }
+
+                // ask for LLM
+                const msg = JSON.stringify(<ChatStart>{
+                    type: "chat.start",
+                    chatId,
+                    model,
+                    messages,
+                })
+                for (const ws of wss.clients) {
+                    trace.log(`chat: sending request to client`)
+                    ws.send(msg)
+                    break
+                }
+            })
+        },
+    })
+
     // cleanup runs
     wss.on("close", () => {
         cancelAll()
@@ -231,6 +314,11 @@ export async function startServer(options: { port: string }) {
                     }
                     break
                 }
+                case "chat.chunk": {
+                    await handleChunk(data)
+                    response = <ResponseStatus>{ ok: true }
+                    break
+                }
                 default:
                     throw new Error(`unknown message type ${type}`)
             }
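A connected client is expected to answer each `chat.start` with a stream of `chat.chunk` messages. A minimal sketch of that counterpart, assuming the `ws` package and a stubbed `streamVsCodeLm` helper standing in for the VS Code language-model API; the port and the message shapes are read off the diff above, not a published client:

```ts
import WebSocket from "ws"

// stub standing in for the VS Code language-model API; an assumption,
// not part of this commit
async function* streamVsCodeLm(
    model: string,
    messages: unknown[]
): AsyncIterable<string> {
    yield `echo from ${model} (${messages.length} messages)`
}

// the port is an assumption; the server listens on its SERVER_PORT default
const ws = new WebSocket("ws://127.0.0.1:8003")
ws.on("message", async (raw) => {
    const data = JSON.parse(raw.toString())
    if (data.type !== "chat.start") return
    const { chatId, model, messages } = data
    // stream partial completions; the server-side handler accumulates them
    for await (const chunk of streamVsCodeLm(model, messages)) {
        ws.send(JSON.stringify({ type: "chat.chunk", chatId, chunk }))
    }
    // a finishReason resolves the completer promise and unregisters the chat
    ws.send(
        JSON.stringify({ type: "chat.chunk", chatId, finishReason: "stop" })
    )
})
```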
4 changes: 2 additions & 2 deletions packages/core/package.json
@@ -30,7 +30,7 @@
         "csv-parse": "^5.5.6",
         "dotenv": "^16.4.5",
         "esbuild": "^0.23.0",
-        "fast-xml-parser": "^4.4.0",
+        "fast-xml-parser": "^4.4.1",
         "fetch-retry": "^6.0.0",
         "fflate": "^0.8.2",
         "file-type": "^19.3.0",
@@ -49,7 +49,7 @@
         "mime-types": "^2.1.35",
         "minimatch": "^10.0.1",
         "minisearch": "^7.1.0",
-        "openai": "^4.53.1",
+        "openai": "^4.53.2",
         "parse-diff": "^0.11.1",
         "prettier": "^3.3.3",
         "pretty-bytes": "^6.1.1",
10 changes: 10 additions & 0 deletions packages/core/src/connection.ts
@@ -14,6 +14,7 @@ import {
     LOCALAI_API_BASE,
     MODEL_PROVIDER_AICI,
     MODEL_PROVIDER_AZURE,
+    MODEL_PROVIDER_CLIENT,
     MODEL_PROVIDER_LITELLM,
     MODEL_PROVIDER_LLAMAFILE,
     MODEL_PROVIDER_OLLAMA,
@@ -217,6 +218,15 @@ export async function parseTokenFromEnv(
         }
     }
 
+    if (provider === MODEL_PROVIDER_CLIENT) {
+        return {
+            provider,
+            model,
+            base: undefined,
+            token: "client",
+        }
+    }
+
     return undefined
 }
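Given this fallback, a script can opt into the client provider with no `.env` configuration at all. A sketch, assuming the usual `provider:model` identifier syntax; the model name after the colon is illustrative:

```ts
script({
    // "client" is MODEL_PROVIDER_CLIENT: the request is routed to the
    // connected client (e.g. the VS Code extension) instead of an API key
    model: "client:gpt-4",
})
```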
1 change: 1 addition & 0 deletions packages/core/src/constants.ts
@@ -94,6 +94,7 @@ export const MODEL_PROVIDER_OLLAMA = "ollama"
 export const MODEL_PROVIDER_LLAMAFILE = "llamafile"
 export const MODEL_PROVIDER_LITELLM = "litellm"
 export const MODEL_PROVIDER_AICI = "aici"
+export const MODEL_PROVIDER_CLIENT = "client"
 
 export const TRACE_FILE_PREVIEW_MAX_LENGTH = 240
1 change: 0 additions & 1 deletion packages/core/src/generation.ts
@@ -82,7 +82,6 @@ export interface GenerationOptions
     cliInfo?: {
         files: string[]
     }
-    languageModel?: LanguageModel
     vars?: PromptParameters
     stats: GenerationStats
 }
5 changes: 1 addition & 4 deletions packages/core/src/host.ts
@@ -70,10 +70,6 @@ export interface RetrievalService {
     ): Promise<RetrievalSearchResponse>
 }
 
-export interface ParsePdfResponse extends ResponseStatus {
-    pages?: string[]
-}
-
 export interface ServerResponse extends ResponseStatus {
     version: string
     node: string
@@ -111,6 +107,7 @@ export interface Host {
         options?: { token?: boolean } & AbortSignalOptions & TraceOptions
     ): Promise<LanguageModelConfiguration | undefined>
     log(level: LogLevel, msg: string): void
+    clientLanguageModel?: LanguageModel
 
     // fs
     readFile(name: string): Promise<Uint8Array>
26 changes: 12 additions & 14 deletions packages/core/src/lm.ts
@@ -1,22 +1,20 @@
 import { AICIModel } from "./aici"
 import { LanguageModel } from "./chat"
-import { MODEL_PROVIDER_AICI, MODEL_PROVIDER_OLLAMA } from "./constants"
-import { LanguageModelConfiguration } from "./host"
+import {
+    MODEL_PROVIDER_AICI,
+    MODEL_PROVIDER_CLIENT,
+    MODEL_PROVIDER_OLLAMA,
+} from "./constants"
+import { host } from "./host"
 import { OllamaModel } from "./ollama"
 import { OpenAIModel } from "./openai"
-import { parseModelIdentifier } from "./models"
 
-export function resolveLanguageModel(
-    options: {
-        model?: string
-        languageModel?: LanguageModel
-    },
-    configuration: LanguageModelConfiguration
-): LanguageModel {
-    const { model, languageModel } = options || {}
-    if (languageModel) return languageModel
-
-    const { provider } = parseModelIdentifier(model)
+export function resolveLanguageModel(provider: string): LanguageModel {
+    if (provider === MODEL_PROVIDER_CLIENT) {
+        const m = host.clientLanguageModel
+        if (!m) throw new Error("Client language model not available")
+        return m
+    }
     if (provider === MODEL_PROVIDER_OLLAMA) return OllamaModel
     if (provider === MODEL_PROVIDER_AICI) return AICIModel
     return OpenAIModel
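With the new signature, callers parse the provider out of the model identifier themselves. A hedged sketch of the call shape; the identifier is illustrative:

```ts
import { parseModelIdentifier } from "./models"
import { resolveLanguageModel } from "./lm"

// "client:gpt-4" → provider "client", which resolves to host.clientLanguageModel
const { provider } = parseModelIdentifier("client:gpt-4")
const lm = resolveLanguageModel(provider)
```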
9 changes: 3 additions & 6 deletions packages/core/src/models.ts
@@ -53,17 +53,14 @@ export function traceLanguageModelConnection(
     trace.startDetails(`⚙️ configuration`)
     try {
-        trace.itemValue(`model`, model)
-        trace.itemValue(`version`, version)
-        trace.itemValue(`source`, source)
-        trace.itemValue(`provider`, provider)
-        trace.itemValue(`temperature`, temperature)
-        trace.itemValue(`topP`, topP)
         trace.itemValue(`maxTokens`, maxTokens)
         trace.itemValue(`base`, base)
         trace.itemValue(`type`, type)
+        trace.itemValue(`version`, version)
+        trace.itemValue(`source`, source)
+        trace.itemValue(`provider`, provider)
         trace.itemValue(`model`, model)
         trace.itemValue(`temperature`, temperature)
         trace.itemValue(`top_p`, topP)
         trace.itemValue(`seed`, seed)
         trace.itemValue(`cache name`, cacheName)
         trace.itemValue(`response type`, responseType)
4 changes: 2 additions & 2 deletions packages/core/src/pdf.ts
@@ -1,5 +1,5 @@
 import type { TextItem } from "pdfjs-dist/types/src/display/api"
-import { ParsePdfResponse, host } from "./host"
+import { host } from "./host"
 import { TraceOptions } from "./trace"
 import { installImport } from "./import"
 import { PDFJS_DIST_VERSION } from "./version"
@@ -63,7 +63,7 @@ async function PDFTryParse(
     fileOrUrl: string,
     content?: Uint8Array,
     options?: { disableCleanup?: boolean } & TraceOptions
-): Promise<ParsePdfResponse> {
+) {
     const { disableCleanup, trace } = options || {}
     try {
         const pdfjs = await tryImportPdfjs(options)
