diff --git a/docs/src/content/docs/reference/scripts/context.md b/docs/src/content/docs/reference/scripts/context.md
index 42adec7bf..32ca82db4 100644
--- a/docs/src/content/docs/reference/scripts/context.md
+++ b/docs/src/content/docs/reference/scripts/context.md
@@ -207,6 +207,17 @@ You can schedule a check for prompt injection/jai break with your configured [co
 def("FILE", env.files, { detectPromptInjection: true })
 ```
 
+### Predicted output
+
+Some models, such as OpenAI's gpt-4o and gpt-4o-mini, support specifying a [predicted output](https://platform.openai.com/docs/guides/predicted-outputs). Providing a prediction reduces latency when much of the response is known ahead of time,
+which is typically the case when asking the LLM to edit specific files.
+
+Set the `prediction: true` flag on a `def` call to enable it. Note that only a single file can be predicted per prompt.
+
+```js
+def("FILE", env.files[0], { prediction: true })
+```
+
 ## Data definition (`defData`)
 
 The `defData` function offers additional formatting options for converting a data object into a textual representation. It supports rendering objects as YAML, JSON, or CSV (formatted as a Markdown table).
diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index bdd15830c..8cd7594c4 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -1,6 +1,6 @@
 // cspell: disable
 import { MarkdownTrace } from "./trace"
-import { PromptImage, renderPromptNode } from "./promptdom"
+import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
 import { LanguageModelConfiguration, host } from "./host"
 import { GenerationOptions } from "./generation"
 import {
@@ -801,6 +801,7 @@ export async function executeChatSession(
     fileOutputs: FileOutput[],
     outputProcessors: PromptOutputProcessorHandler[],
     fileMerges: FileMergeHandler[],
+    prediction: PromptPrediction,
     completer: ChatCompletionHandler,
     chatParticipants: ChatParticipant[],
     genOptions: GenerationOptions
@@ -878,6 +879,10 @@ export async function executeChatSession(
                 top_logprobs,
                 messages,
                 tools: fallbackTools ? undefined : tools,
+                // https://platform.openai.com/docs/guides/predicted-outputs
+                prediction: prediction?.content
+                    ? prediction
+                    : undefined,
                 response_format:
                     responseType === "json_object"
                         ? { type: responseType }
diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts
index 9434e3d90..bd6020f9b 100644
--- a/packages/core/src/expander.ts
+++ b/packages/core/src/expander.ts
@@ -8,7 +8,12 @@ import {
     MODEL_PROVIDER_AICI,
     PROMPTY_REGEX,
 } from "./constants"
-import { finalizeMessages, PromptImage, renderPromptNode } from "./promptdom"
+import {
+    finalizeMessages,
+    PromptImage,
+    PromptPrediction,
+    renderPromptNode,
+} from "./promptdom"
 import { createPromptContext } from "./promptcontext"
 import { evalPrompt } from "./evalprompt"
 import { renderAICI } from "./aici"
@@ -48,6 +53,7 @@ export async function callExpander(
     let outputProcessors: PromptOutputProcessorHandler[] = []
     let chatParticipants: ChatParticipant[] = []
     let fileOutputs: FileOutput[] = []
+    let prediction: PromptPrediction
     let aici: AICIRequest
 
     const logCb = (msg: any) => {
@@ -79,6 +85,7 @@ export async function callExpander(
             outputProcessors: ops,
             chatParticipants: cps,
             fileOutputs: fos,
+            prediction: pred,
         } = await renderPromptNode(model, node, {
             flexTokens: options.flexTokens,
             trace,
@@ -91,6 +98,7 @@ export async function callExpander(
             outputProcessors = ops
             chatParticipants = cps
             fileOutputs = fos
+            prediction = pred
             if (errors?.length) {
                 for (const error of errors) trace.error(``, error)
                 status = "error"
@@ -127,6 +135,7 @@ export async function callExpander(
         outputProcessors,
         chatParticipants,
         fileOutputs,
+        prediction,
         aici,
     })
 }
@@ -232,7 +241,8 @@ export async function expandTemplate(
     const outputProcessors = prompt.outputProcessors.slice(0)
     const chatParticipants = prompt.chatParticipants.slice(0)
     const fileOutputs = prompt.fileOutputs.slice(0)
-
+    const prediction = prompt.prediction
+
     if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs)
     if (prompt.aici) trace.fence(prompt.aici, "yaml")
     trace.endDetails()
@@ -380,6 +390,7 @@ ${schemaTs}
         responseType,
         responseSchema,
         fileMerges,
+        prediction,
         outputProcessors,
         chatParticipants,
         fileOutputs,
diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts
index 70fb47404..65c488c12 100644
--- a/packages/core/src/promptdom.ts
+++ b/packages/core/src/promptdom.ts
@@ -94,6 +94,11 @@ export interface PromptDefNode extends PromptNode, DefOptions {
     resolved?: WorkspaceFile // Resolved file content
 }
 
+export interface PromptPrediction {
+    type: "content"
+    content: string
+}
+
 // Interface for an assistant node.
 export interface PromptAssistantNode extends PromptNode {
     type: "assistant"
@@ -552,6 +557,7 @@ export interface PromptNodeRender {
     chatParticipants: ChatParticipant[] // Chat participants
     messages: ChatCompletionMessageParam[] // Messages for chat completion
     fileOutputs: FileOutput[] // File outputs
+    prediction: PromptPrediction // Predicted output for the prompt
 }
 
 /**
@@ -1054,6 +1060,7 @@ export async function renderPromptNode(
     const outputProcessors: PromptOutputProcessorHandler[] = []
     const chatParticipants: ChatParticipant[] = []
     const fileOutputs: FileOutput[] = []
+    let prediction: PromptPrediction
 
     await visitNode(node, {
         error: (n) => {
@@ -1065,7 +1072,17 @@ export async function renderPromptNode(
         },
         def: async (n) => {
             const value = n.resolved
-            if (value !== undefined) appendUser(renderDefNode(n))
+            if (value !== undefined) {
+                appendUser(renderDefNode(n))
+                if (n.prediction) {
+                    if (prediction) n.error = "duplicate prediction"
+                    else
+                        prediction = {
+                            type: "content",
+                            content: extractRange(value.content, n),
+                        }
+                }
+            }
         },
         assistant: async (n) => {
             const value = await n.resolved
@@ -1173,6 +1190,7 @@ ${trimNewlines(schemaText)}
         errors,
         messages,
         fileOutputs,
+        prediction,
     })
     return res
 }
diff --git a/packages/core/src/promptrunner.ts b/packages/core/src/promptrunner.ts
index 830d9bef2..ebf297606 100644
--- a/packages/core/src/promptrunner.ts
+++ b/packages/core/src/promptrunner.ts
@@ -138,6 +138,7 @@ export async function runTemplate(
         outputProcessors,
         chatParticipants,
         fileOutputs,
+        prediction,
         status,
         statusText,
         temperature,
@@ -221,6 +222,7 @@ export async function runTemplate(
         fileOutputs,
         outputProcessors,
         fileMerges,
+        prediction,
         completer,
         chatParticipants,
         genOptions
diff --git a/packages/core/src/runpromptcontext.ts b/packages/core/src/runpromptcontext.ts
index 4b5838dcb..b342fed99 100644
--- a/packages/core/src/runpromptcontext.ts
+++ b/packages/core/src/runpromptcontext.ts
@@ -20,6 +20,7 @@ import {
     createSystemNode,
     finalizeMessages,
     PromptImage,
+    PromptPrediction,
 } from "./promptdom"
 import { MarkdownTrace } from "./trace"
 import { GenerationOptions } from "./generation"
@@ -626,6 +627,7 @@ export function createChatGenerationContext(
         const fileMerges: FileMergeHandler[] = []
         const outputProcessors: PromptOutputProcessorHandler[] = []
         const fileOutputs: FileOutput[] = []
+        let prediction: PromptPrediction
 
         // expand template
         const { provider } = parseModelIdentifier(genOptions.model)
@@ -644,6 +646,7 @@ export function createChatGenerationContext(
                 outputProcessors: ops,
                 fileOutputs: fos,
                 images: imgs,
+                prediction: pred,
             } = await renderPromptNode(genOptions.model, node, {
                 flexTokens: genOptions.flexTokens,
                 trace: runTrace,
@@ -657,6 +660,7 @@ export function createChatGenerationContext(
             outputProcessors.push(...ops)
             fileOutputs.push(...fos)
             images.push(...imgs)
+            prediction = pred
 
             if (errors?.length) {
                 logError(errors.map((err) => errorMessage(err)).join("\n"))
@@ -773,6 +777,7 @@ export function createChatGenerationContext(
             fileOutputs,
             outputProcessors,
             fileMerges,
+            prediction,
             completer,
             chatParticipants,
             genOptions
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index 3c04b0d7a..50b923b84 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -930,6 +930,12 @@ interface DefOptions
      * By default, throws an error if the value in def is empty.
      */
     ignoreEmpty?: boolean
+
+    /**
+     * The content of the def is a predicted output.
+     * @see https://platform.openai.com/docs/guides/predicted-outputs
+     */
+    prediction?: boolean
 }
 
 /**
diff --git a/packages/core/src/usage.ts b/packages/core/src/usage.ts
index 5cc35c5fa..a8a1df686 100644
--- a/packages/core/src/usage.ts
+++ b/packages/core/src/usage.ts
@@ -129,6 +129,8 @@ export class GenerationStats {
             completion_tokens_details: {
                 audio_tokens: 0,
                 reasoning_tokens: 0,
+                accepted_prediction_tokens: 0,
+                rejected_prediction_tokens: 0,
             },
             prompt_tokens_details: {
                 audio_tokens: 0,
@@ -169,6 +171,12 @@ export class GenerationStats {
             res.completion_tokens += childUsage.completion_tokens
             res.prompt_tokens += childUsage.prompt_tokens
             res.total_tokens += childUsage.total_tokens
+            res.completion_tokens_details.accepted_prediction_tokens +=
+                childUsage.completion_tokens_details
+                    .accepted_prediction_tokens ?? 0
+            res.completion_tokens_details.rejected_prediction_tokens +=
+                childUsage.completion_tokens_details
+                    .rejected_prediction_tokens ?? 0
             res.completion_tokens_details.audio_tokens +=
                 childUsage.completion_tokens_details.audio_tokens
             res.completion_tokens_details.reasoning_tokens +=
@@ -232,7 +240,16 @@ export class GenerationStats {
                 "reasoning tokens",
                 this.usage.completion_tokens_details.reasoning_tokens
             )
-
+        if (this.usage.completion_tokens_details?.accepted_prediction_tokens)
+            trace.itemValue(
+                "accepted prediction tokens",
+                this.usage.completion_tokens_details.accepted_prediction_tokens
+            )
+        if (this.usage.completion_tokens_details?.rejected_prediction_tokens)
+            trace.itemValue(
+                "rejected prediction tokens",
+                this.usage.completion_tokens_details.rejected_prediction_tokens
+            )
         if (this.chatTurns.length > 1) {
             trace.startDetails("chat turns")
             try {
diff --git a/packages/sample/genaisrc/prediction.genai.mjs b/packages/sample/genaisrc/prediction.genai.mjs
new file mode 100644
index 000000000..0adcc7463
--- /dev/null
+++ b/packages/sample/genaisrc/prediction.genai.mjs
@@ -0,0 +1,10 @@
+script({
+    files: "src/greeter.ts",
+    tests: {
+        files: "src/greeter.ts",
+    },
+})
+
+def("FILE", env.files[0], { prediction: true })
+
+$`Add comments to every line of code. Respond only with code.`