Add prediction feature for model outputs (#879)
* add prediction

* feat: ✨ add prediction option and update documentation
pelikhan authored Nov 20, 2024
1 parent 18db390 commit d631679
Showing 9 changed files with 104 additions and 7 deletions.
17 changes: 17 additions & 0 deletions docs/src/content/docs/reference/scripts/context.md
@@ -207,6 +207,23 @@ You can schedule a check for prompt injection/jailbreak with your configured [co

```js
def("FILE", env.files, { detectPromptInjection: true })
```
### Predicted output

Some models, such as OpenAI's gpt-4o and gpt-4o-mini, support specifying a [predicted output](https://platform.openai.com/docs/guides/predicted-outputs) (with some [limitations](https://platform.openai.com/docs/guides/predicted-outputs#limitations)). This helps reduce latency when much of the response is known ahead of time, which is particularly useful when asking the LLM to edit specific files.

Set the `prediction: true` flag on a `def` call to enable it. Note that only a single file can be predicted.

```js
def("FILE", env.files[0], { prediction: true })
```

:::note
This feature disables line number insertion.
:::
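Under the hood, the def'ed file content is forwarded as the `prediction` field of the chat completion request. A minimal sketch of the resulting payload, assuming the field shape documented in OpenAI's predicted outputs guide (model name and file content are placeholders):

```js
// Sketch of the request body produced when `prediction: true` is set.
const file = { content: "export const answer = 42\n" } // the def'ed file
const body = {
    model: "gpt-4o-mini",
    messages: [{ role: "user", content: `Add a comment:\n${file.content}` }],
    // https://platform.openai.com/docs/guides/predicted-outputs
    prediction: { type: "content", content: file.content },
}
```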
## Data definition (`defData`)

The `defData` function offers additional formatting options for converting a data object into a textual representation. It supports rendering objects as YAML, JSON, or CSV (formatted as a Markdown table).
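For example, a minimal sketch (assuming YAML is the default rendering and the `format` option selects the alternatives listed above):

```js
const rows = [
    { name: "earth", radiusKm: 6371 },
    { name: "mars", radiusKm: 3390 },
]
defData("ROWS", rows) // YAML by default
defData("ROWS_CSV", rows, { format: "csv" }) // rendered as a Markdown table
```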
7 changes: 6 additions & 1 deletion packages/core/src/chat.ts
@@ -1,6 +1,6 @@
// cspell: disable
import { MarkdownTrace } from "./trace"
-import { PromptImage, renderPromptNode } from "./promptdom"
+import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
import { LanguageModelConfiguration, host } from "./host"
import { GenerationOptions } from "./generation"
import {
@@ -801,6 +801,7 @@ export async function executeChatSession(
fileOutputs: FileOutput[],
outputProcessors: PromptOutputProcessorHandler[],
fileMerges: FileMergeHandler[],
prediction: PromptPrediction,
completer: ChatCompletionHandler,
chatParticipants: ChatParticipant[],
genOptions: GenerationOptions
@@ -878,6 +879,10 @@
top_logprobs,
messages,
tools: fallbackTools ? undefined : tools,
// https://platform.openai.com/docs/guides/predicted-outputs
prediction: prediction?.content
? prediction
: undefined,
response_format:
responseType === "json_object"
? { type: responseType }
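For reference, the equivalent raw call with the official `openai` npm package looks roughly like this — a sketch assuming the client exposes the `prediction` request field (model, prompt, and file content are placeholders):

```js
import OpenAI from "openai"

const client = new OpenAI()
const fileContent = "export function greet(name) {\n    return `Hello ${name}`\n}\n"

const res = await client.chat.completions.create({
    model: "gpt-4o-mini",
    messages: [
        { role: "user", content: `Add a doc comment to this file:\n${fileContent}` },
    ],
    // omitted when there is no content, mirroring the `prediction?.content` guard above
    prediction: { type: "content", content: fileContent },
})
console.log(res.choices[0].message.content)
```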
15 changes: 13 additions & 2 deletions packages/core/src/expander.ts
@@ -8,7 +8,12 @@ import {
MODEL_PROVIDER_AICI,
PROMPTY_REGEX,
} from "./constants"
-import { finalizeMessages, PromptImage, renderPromptNode } from "./promptdom"
+import {
+    finalizeMessages,
+    PromptImage,
+    PromptPrediction,
+    renderPromptNode,
+} from "./promptdom"
import { createPromptContext } from "./promptcontext"
import { evalPrompt } from "./evalprompt"
import { renderAICI } from "./aici"
@@ -48,6 +53,7 @@ export async function callExpander(
let outputProcessors: PromptOutputProcessorHandler[] = []
let chatParticipants: ChatParticipant[] = []
let fileOutputs: FileOutput[] = []
let prediction: PromptPrediction
let aici: AICIRequest

const logCb = (msg: any) => {
@@ -79,6 +85,7 @@
outputProcessors: ops,
chatParticipants: cps,
fileOutputs: fos,
prediction: pred,
} = await renderPromptNode(model, node, {
flexTokens: options.flexTokens,
trace,
@@ -91,6 +98,7 @@
outputProcessors = ops
chatParticipants = cps
fileOutputs = fos
prediction = pred
if (errors?.length) {
for (const error of errors) trace.error(``, error)
status = "error"
@@ -127,6 +135,7 @@
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
aici,
})
}
@@ -232,7 +241,8 @@ export async function expandTemplate(
const outputProcessors = prompt.outputProcessors.slice(0)
const chatParticipants = prompt.chatParticipants.slice(0)
const fileOutputs = prompt.fileOutputs.slice(0)

const prediction = prompt.prediction

if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs)
if (prompt.aici) trace.fence(prompt.aici, "yaml")
trace.endDetails()
@@ -380,6 +390,7 @@ ${schemaTs}
responseType,
responseSchema,
fileMerges,
prediction,
outputProcessors,
chatParticipants,
fileOutputs,
25 changes: 22 additions & 3 deletions packages/core/src/promptdom.ts
@@ -94,6 +94,11 @@ export interface PromptDefNode extends PromptNode, DefOptions {
resolved?: WorkspaceFile // Resolved file content
}

export interface PromptPrediction {
type: "content"
content: string
}

// Interface for an assistant node.
export interface PromptAssistantNode extends PromptNode {
type: "assistant"
@@ -235,7 +240,7 @@ export function createDefDiff(
// Function to render a definition node to a string.
function renderDefNode(def: PromptDefNode): string {
const { name, resolved: file } = def
-const { language, lineNumbers, schema } = def || {}
+const { language, lineNumbers, schema, prediction } = def || {}

file.content = extractRange(file.content, def)

@@ -245,7 +250,8 @@
: PROMPT_FENCE
const norm = (s: string, lang: string) => {
s = (s || "").replace(/\n*$/, "")
-if (s && lineNumbers) s = addLineNumbers(s, { language: lang })
+if (s && lineNumbers && !prediction)
+    s = addLineNumbers(s, { language: lang })
if (s) s += "\n"
return s
}
@@ -552,6 +558,7 @@ export interface PromptNodeRender {
chatParticipants: ChatParticipant[] // Chat participants
messages: ChatCompletionMessageParam[] // Messages for chat completion
fileOutputs: FileOutput[] // File outputs
prediction: PromptPrediction // predicted output for the prompt
}

/**
@@ -1054,6 +1061,7 @@ export async function renderPromptNode(
const outputProcessors: PromptOutputProcessorHandler[] = []
const chatParticipants: ChatParticipant[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

await visitNode(node, {
error: (n) => {
@@ -1065,7 +1073,17 @@
},
def: async (n) => {
const value = n.resolved
-if (value !== undefined) appendUser(renderDefNode(n))
+if (value !== undefined) {
+    appendUser(renderDefNode(n))
+    if (n.prediction) {
+        if (prediction) n.error = "duplicate prediction"
+        else
+            prediction = {
+                type: "content",
+                content: extractRange(value.content, n),
+            }
+    }
+}
},
assistant: async (n) => {
const value = await n.resolved
@@ -1173,6 +1191,7 @@ ${trimNewlines(schemaText)}
errors,
messages,
fileOutputs,
prediction,
})
return res
}
2 changes: 2 additions & 0 deletions packages/core/src/promptrunner.ts
@@ -138,6 +138,7 @@ export async function runTemplate(
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
status,
statusText,
temperature,
@@ -221,6 +222,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
5 changes: 5 additions & 0 deletions packages/core/src/runpromptcontext.ts
@@ -20,6 +20,7 @@ import {
createSystemNode,
finalizeMessages,
PromptImage,
PromptPrediction,
} from "./promptdom"
import { MarkdownTrace } from "./trace"
import { GenerationOptions } from "./generation"
@@ -626,6 +627,7 @@ export function createChatGenerationContext(
const fileMerges: FileMergeHandler[] = []
const outputProcessors: PromptOutputProcessorHandler[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

// expand template
const { provider } = parseModelIdentifier(genOptions.model)
@@ -644,6 +646,7 @@
outputProcessors: ops,
fileOutputs: fos,
images: imgs,
prediction: pred,
} = await renderPromptNode(genOptions.model, node, {
flexTokens: genOptions.flexTokens,
trace: runTrace,
@@ -657,6 +660,7 @@
outputProcessors.push(...ops)
fileOutputs.push(...fos)
images.push(...imgs)
prediction = pred

if (errors?.length) {
logError(errors.map((err) => errorMessage(err)).join("\n"))
@@ -773,6 +777,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
6 changes: 6 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
@@ -930,6 +930,12 @@ interface DefOptions
* By default, throws an error if the value in def is empty.
*/
ignoreEmpty?: boolean

/**
* The content of the def is a predicted output.
* This setting disables line numbers.
*/
prediction?: boolean
}

/**
23 changes: 22 additions & 1 deletion packages/core/src/usage.ts
@@ -129,6 +129,8 @@ export class GenerationStats {
completion_tokens_details: {
audio_tokens: 0,
reasoning_tokens: 0,
accepted_prediction_tokens: 0,
rejected_prediction_tokens: 0,
},
prompt_tokens_details: {
audio_tokens: 0,
@@ -169,6 +171,12 @@
res.completion_tokens += childUsage.completion_tokens
res.prompt_tokens += childUsage.prompt_tokens
res.total_tokens += childUsage.total_tokens
res.completion_tokens_details.accepted_prediction_tokens +=
childUsage.completion_tokens_details
.accepted_prediction_tokens ?? 0
res.completion_tokens_details.rejected_prediction_tokens +=
childUsage.completion_tokens_details
.rejected_prediction_tokens ?? 0
res.completion_tokens_details.audio_tokens +=
childUsage.completion_tokens_details.audio_tokens
res.completion_tokens_details.reasoning_tokens +=
@@ -232,7 +240,16 @@
"reasoning tokens",
this.usage.completion_tokens_details.reasoning_tokens
)

if (this.usage.completion_tokens_details?.accepted_prediction_tokens)
trace.itemValue(
"accepted prediction tokens",
this.usage.completion_tokens_details.accepted_prediction_tokens
)
if (this.usage.completion_tokens_details?.rejected_prediction_tokens)
trace.itemValue(
"rejected prediction tokens",
this.usage.completion_tokens_details.rejected_prediction_tokens
)
if (this.chatTurns.length > 1) {
trace.startDetails("chat turns")
try {
@@ -319,6 +336,10 @@ export class GenerationStats {
usage.prompt_tokens_details?.audio_tokens ?? 0
this.usage.completion_tokens_details.reasoning_tokens +=
usage.prompt_tokens_details?.cached_tokens ?? 0
this.usage.completion_tokens_details.accepted_prediction_tokens +=
usage.completion_tokens_details?.accepted_prediction_tokens ?? 0
this.usage.completion_tokens_details.rejected_prediction_tokens +=
usage.completion_tokens_details?.rejected_prediction_tokens ?? 0

const { provider } = parseModelIdentifier(this.model)
const chatTurn = {
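These counters surface OpenAI's prediction telemetry. A small sketch of how the two new fields on `completion_tokens_details` can be read from a raw response usage object (field names per the API; the sample numbers are made up):

```js
// Summarize prediction effectiveness from a chat completion's usage object.
function predictionStats(usage) {
    const d = usage?.completion_tokens_details ?? {}
    return {
        accepted: d.accepted_prediction_tokens ?? 0,
        rejected: d.rejected_prediction_tokens ?? 0,
    }
}

// Example shape, mirroring what the aggregation above consumes.
console.log(
    predictionStats({
        completion_tokens: 120,
        completion_tokens_details: {
            accepted_prediction_tokens: 95,
            rejected_prediction_tokens: 8,
        },
    })
) // -> { accepted: 95, rejected: 8 }
```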
11 changes: 11 additions & 0 deletions packages/sample/genaisrc/prediction.genai.mjs
@@ -0,0 +1,11 @@
script({
model: "openai:gpt-4o",
files: "src/greeter.ts",
tests: {
files: "src/greeter.ts",
},
})

def("FILE", env.files[0], { prediction: true })

$`Update FILE with a top-level file comment that summarizes the content.`
