Add prediction feature for model outputs #879

Merged: 2 commits, Nov 20, 2024
17 changes: 17 additions & 0 deletions docs/src/content/docs/reference/scripts/context.md
@@ -207,6 +207,23 @@ You can schedule a check for prompt injection/jailbreak with your configured [co
def("FILE", env.files, { detectPromptInjection: true })
```

### Predicted output

Some models, such as OpenAI's gpt-4o and gpt-4o-mini, support specifying a [predicted output](https://platform.openai.com/docs/guides/predicted-outputs) (with some [limitations](https://platform.openai.com/docs/guides/predicted-outputs#limitations)). Providing a prediction reduces response latency when much of the response is known ahead of time, which is typically the case when asking the LLM to edit specific files.

Set the `prediction: true` flag on a `def` call to enable it. Note that only a single file can be predicted per prompt.

```js
def("FILE", env.files[0], { prediction: true })
```

:::note

Enabling prediction disables line number insertion for the predicted file.

:::

## Data definition (`defData`)

The `defData` function offers additional formatting options for converting a data object into a textual representation. It supports rendering objects as YAML, JSON, or CSV (formatted as a Markdown table).
7 changes: 6 additions & 1 deletion packages/core/src/chat.ts
@@ -1,6 +1,6 @@
// cspell: disable
import { MarkdownTrace } from "./trace"
import { PromptImage, renderPromptNode } from "./promptdom"
import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
import { LanguageModelConfiguration, host } from "./host"
import { GenerationOptions } from "./generation"
import {
@@ -801,6 +801,7 @@ export async function executeChatSession(
fileOutputs: FileOutput[],
outputProcessors: PromptOutputProcessorHandler[],
fileMerges: FileMergeHandler[],
prediction: PromptPrediction,
completer: ChatCompletionHandler,
chatParticipants: ChatParticipant[],
genOptions: GenerationOptions
@@ -878,6 +879,10 @@
top_logprobs,
messages,
tools: fallbackTools ? undefined : tools,
// https://platform.openai.com/docs/guides/predicted-outputs
prediction: prediction?.content
? prediction
: undefined,
response_format:
responseType === "json_object"
? { type: responseType }
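For orientation, here is a minimal sketch of the equivalent raw OpenAI call that the `prediction` field above feeds into. This is not code from this PR; it assumes the official `openai` Node package and a model that supports predicted outputs, and the file path and prompt wording are placeholders.

```js
import OpenAI from "openai"
import { readFile } from "node:fs/promises"

const client = new OpenAI()

// The content the model is expected to mostly reproduce, with small edits.
const content = await readFile("src/greeter.ts", "utf8")

const res = await client.chat.completions.create({
    model: "gpt-4o",
    messages: [
        {
            role: "user",
            content: `Add a top-level comment to this file:\n${content}`,
        },
    ],
    // Same shape as the PromptPrediction built by renderPromptNode:
    // unchanged spans of `content` are matched cheaply, edits are generated.
    prediction: { type: "content", content },
})

console.log(res.choices[0].message.content)
console.log(res.usage?.completion_tokens_details)
```

Passing `undefined` when no prediction content was collected, as the hunk above does, keeps the request identical to the non-predicted case.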
15 changes: 13 additions & 2 deletions packages/core/src/expander.ts
@@ -8,7 +8,12 @@ import {
MODEL_PROVIDER_AICI,
PROMPTY_REGEX,
} from "./constants"
import { finalizeMessages, PromptImage, renderPromptNode } from "./promptdom"
import {
finalizeMessages,
PromptImage,
PromptPrediction,
renderPromptNode,
} from "./promptdom"
import { createPromptContext } from "./promptcontext"
import { evalPrompt } from "./evalprompt"
import { renderAICI } from "./aici"
@@ -48,6 +53,7 @@ export async function callExpander(
let outputProcessors: PromptOutputProcessorHandler[] = []
let chatParticipants: ChatParticipant[] = []
let fileOutputs: FileOutput[] = []
let prediction: PromptPrediction
let aici: AICIRequest

const logCb = (msg: any) => {
@@ -79,6 +85,7 @@
outputProcessors: ops,
chatParticipants: cps,
fileOutputs: fos,
prediction: pred,
} = await renderPromptNode(model, node, {
flexTokens: options.flexTokens,
trace,
@@ -91,6 +98,7 @@
outputProcessors = ops
chatParticipants = cps
fileOutputs = fos
prediction = pred
if (errors?.length) {
for (const error of errors) trace.error(``, error)
status = "error"
@@ -127,6 +135,7 @@
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
aici,
})
}
@@ -232,7 +241,8 @@ export async function expandTemplate(
const outputProcessors = prompt.outputProcessors.slice(0)
const chatParticipants = prompt.chatParticipants.slice(0)
const fileOutputs = prompt.fileOutputs.slice(0)

const prediction = prompt.prediction

if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs)
if (prompt.aici) trace.fence(prompt.aici, "yaml")
trace.endDetails()
@@ -380,6 +390,7 @@
${schemaTs}
responseType,
responseSchema,
fileMerges,
prediction,
outputProcessors,
chatParticipants,
fileOutputs,
25 changes: 22 additions & 3 deletions packages/core/src/promptdom.ts
@@ -94,6 +94,11 @@ export interface PromptDefNode extends PromptNode, DefOptions {
resolved?: WorkspaceFile // Resolved file content
}

export interface PromptPrediction {
type: "content"
content: string
}

// Interface for an assistant node.
export interface PromptAssistantNode extends PromptNode {
type: "assistant"
@@ -235,7 +240,7 @@ export function createDefDiff(
// Function to render a definition node to a string.
function renderDefNode(def: PromptDefNode): string {
const { name, resolved: file } = def
const { language, lineNumbers, schema } = def || {}
const { language, lineNumbers, schema, prediction } = def || {}

file.content = extractRange(file.content, def)

@@ -245,7 +250,8 @@ function renderDefNode(def: PromptDefNode): string {
: PROMPT_FENCE
const norm = (s: string, lang: string) => {
s = (s || "").replace(/\n*$/, "")
if (s && lineNumbers) s = addLineNumbers(s, { language: lang })
if (s && lineNumbers && !prediction)
s = addLineNumbers(s, { language: lang })
if (s) s += "\n"
return s
}
@@ -552,6 +558,7 @@ export interface PromptNodeRender {
chatParticipants: ChatParticipant[] // Chat participants
messages: ChatCompletionMessageParam[] // Messages for chat completion
fileOutputs: FileOutput[] // File outputs
prediction: PromptPrediction // predicted output for the prompt
}

/**
@@ -1054,6 +1061,7 @@ export async function renderPromptNode(
const outputProcessors: PromptOutputProcessorHandler[] = []
const chatParticipants: ChatParticipant[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

await visitNode(node, {
error: (n) => {
@@ -1065,7 +1073,17 @@
},
def: async (n) => {
const value = n.resolved
if (value !== undefined) appendUser(renderDefNode(n))
if (value !== undefined) {
appendUser(renderDefNode(n))
if (n.prediction) {
if (prediction) n.error = "duplicate prediction"
else
prediction = {
type: "content",
content: extractRange(value.content, n),
}
}
}
},
assistant: async (n) => {
const value = await n.resolved
@@ -1173,6 +1191,7 @@
${trimNewlines(schemaText)}
errors,
messages,
fileOutputs,
prediction,
})
return res
}
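The `def` visitor above keeps a single `PromptPrediction` and marks any second predicted def with a `duplicate prediction` error. A hypothetical script (file indices are placeholders) showing the constraint:

```js
// Only the first predicted def contributes the prediction content.
def("FILE", env.files[0], { prediction: true })

// A second prediction in the same prompt is rejected during rendering
// with a "duplicate prediction" error.
def("OTHER", env.files[1], { prediction: true })
```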
2 changes: 2 additions & 0 deletions packages/core/src/promptrunner.ts
@@ -138,6 +138,7 @@ export async function runTemplate(
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
status,
statusText,
temperature,
@@ -221,6 +222,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
5 changes: 5 additions & 0 deletions packages/core/src/runpromptcontext.ts
@@ -20,6 +20,7 @@ import {
createSystemNode,
finalizeMessages,
PromptImage,
PromptPrediction,
} from "./promptdom"
import { MarkdownTrace } from "./trace"
import { GenerationOptions } from "./generation"
@@ -626,6 +627,7 @@ export function createChatGenerationContext(
const fileMerges: FileMergeHandler[] = []
const outputProcessors: PromptOutputProcessorHandler[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

// expand template
const { provider } = parseModelIdentifier(genOptions.model)
@@ -644,6 +646,7 @@
outputProcessors: ops,
fileOutputs: fos,
images: imgs,
prediction: pred,
} = await renderPromptNode(genOptions.model, node, {
flexTokens: genOptions.flexTokens,
trace: runTrace,
@@ -657,6 +660,7 @@
outputProcessors.push(...ops)
fileOutputs.push(...fos)
images.push(...imgs)
prediction = pred

if (errors?.length) {
logError(errors.map((err) => errorMessage(err)).join("\n"))
@@ -773,6 +777,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
6 changes: 6 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
@@ -930,6 +930,12 @@ interface DefOptions
* By default, throws an error if the value in def is empty.
*/
ignoreEmpty?: boolean

/**
* The content of the def is a predicted output.
* This setting disables line numbers.
*/
prediction?: boolean
}

/**
Expand Down
23 changes: 22 additions & 1 deletion packages/core/src/usage.ts
@@ -129,6 +129,8 @@ export class GenerationStats {
completion_tokens_details: {
audio_tokens: 0,
reasoning_tokens: 0,
accepted_prediction_tokens: 0,
rejected_prediction_tokens: 0,
},
prompt_tokens_details: {
audio_tokens: 0,
@@ -169,6 +171,12 @@
res.completion_tokens += childUsage.completion_tokens
res.prompt_tokens += childUsage.prompt_tokens
res.total_tokens += childUsage.total_tokens
res.completion_tokens_details.accepted_prediction_tokens +=
childUsage.completion_tokens_details
.accepted_prediction_tokens ?? 0
res.completion_tokens_details.rejected_prediction_tokens +=
childUsage.completion_tokens_details
.rejected_prediction_tokens ?? 0
res.completion_tokens_details.audio_tokens +=
childUsage.completion_tokens_details.audio_tokens
res.completion_tokens_details.reasoning_tokens +=
@@ -232,7 +240,16 @@
"reasoning tokens",
this.usage.completion_tokens_details.reasoning_tokens
)

if (this.usage.completion_tokens_details?.accepted_prediction_tokens)
trace.itemValue(
"accepted prediction tokens",
this.usage.completion_tokens_details.accepted_prediction_tokens
)
if (this.usage.completion_tokens_details?.rejected_prediction_tokens)
trace.itemValue(
"rejected prediction tokens",
this.usage.completion_tokens_details.rejected_prediction_tokens
)
if (this.chatTurns.length > 1) {
trace.startDetails("chat turns")
try {
@@ -319,6 +336,10 @@ export class GenerationStats {
usage.prompt_tokens_details?.audio_tokens ?? 0
this.usage.completion_tokens_details.reasoning_tokens +=
usage.prompt_tokens_details?.cached_tokens ?? 0
this.usage.completion_tokens_details.accepted_prediction_tokens +=
usage.completion_tokens_details?.accepted_prediction_tokens ?? 0
this.usage.completion_tokens_details.rejected_prediction_tokens +=
usage.completion_tokens_details?.rejected_prediction_tokens ?? 0

const { provider } = parseModelIdentifier(this.model)
const chatTurn = {
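The two new counters mirror the `completion_tokens_details` block that the API returns when a prediction is supplied. A hedged sketch of the shape being accumulated above, with illustrative numbers only:

```js
// Example usage payload for a predicted-output response (values are made up).
const usage = {
    prompt_tokens: 640,
    completion_tokens: 120,
    total_tokens: 760,
    completion_tokens_details: {
        accepted_prediction_tokens: 80, // predicted tokens the model kept
        rejected_prediction_tokens: 15, // predicted tokens it regenerated
        audio_tokens: 0,
        reasoning_tokens: 0,
    },
}

// Accumulation follows the same pattern as GenerationStats above.
const totals = { accepted: 0, rejected: 0 }
totals.accepted +=
    usage.completion_tokens_details?.accepted_prediction_tokens ?? 0
totals.rejected +=
    usage.completion_tokens_details?.rejected_prediction_tokens ?? 0
```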
11 changes: 11 additions & 0 deletions packages/sample/genaisrc/prediction.genai.mjs
@@ -0,0 +1,11 @@
script({
model: "openai:gpt-4o",
files: "src/greeter.ts",
tests: {
files: "src/greeter.ts",
},
})

def("FILE", env.files[0], { prediction: true })

$`Update FILE with a top-level file comment that summarizes the content.`