Commit

add prediction
pelikhan committed Nov 20, 2024
1 parent 18db390 commit f666534
Showing 9 changed files with 90 additions and 5 deletions.
11 changes: 11 additions & 0 deletions docs/src/content/docs/reference/scripts/context.md
@@ -207,6 +207,17 @@ You can schedule a check for prompt injection/jailbreak with your configured [co
def("FILE", env.files, { detectPromptInjection: true })
```
### Predicted output
Some models, such as OpenAI's gpt-4o and gpt-4o-mini, support specifying a [predicted output](https://platform.openai.com/docs/guides/predicted-outputs). This helps reduce latency for responses where much of the output is known ahead of time.
This can be helpful when asking the LLM to edit specific files.
Set the `prediction: true` flag on a `def` call to enable it. Note that only a single file can be used as the prediction.
```js
def("FILE", env.files[0], { prediction: true })
```
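For reference, the flag ultimately maps onto the `prediction` field of the chat completion request, using OpenAI's predicted-outputs shape (`{ type: "content", content }`). A minimal sketch of the resulting request body, with illustrative model, message, and file values:
```ts
// Sketch of a chat completion request carrying a predicted output.
// All values below are illustrative, not produced by GenAIScript itself.
const body = {
    model: "gpt-4o-mini",
    messages: [
        { role: "user", content: "Add comments to every line of FILE." },
    ],
    // The predicted output is the current file content; tokens the model
    // reuses verbatim come back faster and are reported as accepted
    // prediction tokens in the usage details.
    prediction: {
        type: "content" as const,
        content: "export function greeter(name: string) { /* ... */ }",
    },
}
```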
## Data definition (`defData`)
The `defData` function offers additional formatting options for converting a data object into a textual representation. It supports rendering objects as YAML, JSON, or CSV (formatted as a Markdown table).
7 changes: 6 additions & 1 deletion packages/core/src/chat.ts
@@ -1,6 +1,6 @@
// cspell: disable
import { MarkdownTrace } from "./trace"
import { PromptImage, renderPromptNode } from "./promptdom"
import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom"
import { LanguageModelConfiguration, host } from "./host"
import { GenerationOptions } from "./generation"
import {
@@ -801,6 +801,7 @@ export async function executeChatSession(
fileOutputs: FileOutput[],
outputProcessors: PromptOutputProcessorHandler[],
fileMerges: FileMergeHandler[],
prediction: PromptPrediction,
completer: ChatCompletionHandler,
chatParticipants: ChatParticipant[],
genOptions: GenerationOptions
@@ -878,6 +879,10 @@ export async function executeChatSession(
top_logprobs,
messages,
tools: fallbackTools ? undefined : tools,
// https://platform.openai.com/docs/guides/predicted-outputs
prediction: prediction?.content
? prediction
: undefined,
response_format:
responseType === "json_object"
? { type: responseType }
15 changes: 13 additions & 2 deletions packages/core/src/expander.ts
@@ -8,7 +8,12 @@ import {
MODEL_PROVIDER_AICI,
PROMPTY_REGEX,
} from "./constants"
import { finalizeMessages, PromptImage, renderPromptNode } from "./promptdom"
import {
finalizeMessages,
PromptImage,
PromptPrediction,
renderPromptNode,
} from "./promptdom"
import { createPromptContext } from "./promptcontext"
import { evalPrompt } from "./evalprompt"
import { renderAICI } from "./aici"
@@ -48,6 +53,7 @@ export async function callExpander(
let outputProcessors: PromptOutputProcessorHandler[] = []
let chatParticipants: ChatParticipant[] = []
let fileOutputs: FileOutput[] = []
let prediction: PromptPrediction
let aici: AICIRequest

const logCb = (msg: any) => {
@@ -79,6 +85,7 @@
outputProcessors: ops,
chatParticipants: cps,
fileOutputs: fos,
prediction: pred,
} = await renderPromptNode(model, node, {
flexTokens: options.flexTokens,
trace,
@@ -91,6 +98,7 @@
outputProcessors = ops
chatParticipants = cps
fileOutputs = fos
prediction = pred
if (errors?.length) {
for (const error of errors) trace.error(``, error)
status = "error"
@@ -127,6 +135,7 @@
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
aici,
})
}
@@ -232,7 +241,8 @@ export async function expandTemplate(
const outputProcessors = prompt.outputProcessors.slice(0)
const chatParticipants = prompt.chatParticipants.slice(0)
const fileOutputs = prompt.fileOutputs.slice(0)

const prediction = prompt.prediction

if (prompt.logs?.length) trace.details("📝 console.log", prompt.logs)
if (prompt.aici) trace.fence(prompt.aici, "yaml")
trace.endDetails()
@@ -380,6 +390,7 @@ ${schemaTs}
responseType,
responseSchema,
fileMerges,
prediction,
outputProcessors,
chatParticipants,
fileOutputs,
20 changes: 19 additions & 1 deletion packages/core/src/promptdom.ts
@@ -94,6 +94,11 @@ export interface PromptDefNode extends PromptNode, DefOptions {
resolved?: WorkspaceFile // Resolved file content
}

export interface PromptPrediction {
type: "content"
content: string
}

// Interface for an assistant node.
export interface PromptAssistantNode extends PromptNode {
type: "assistant"
@@ -552,6 +557,7 @@ export interface PromptNodeRender {
chatParticipants: ChatParticipant[] // Chat participants
messages: ChatCompletionMessageParam[] // Messages for chat completion
fileOutputs: FileOutput[] // File outputs
prediction: PromptPrediction // predicted output for the prompt
}
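`PromptPrediction` mirrors OpenAI's `prediction` request parameter. A hypothetical resolved value for a `def` marked `prediction: true` might look like:
```ts
// Hypothetical resolved prediction (file content invented for illustration);
// content holds the text of the single predicted file, after range extraction.
const prediction: PromptPrediction = {
    type: "content",
    content: "export function add(a: number, b: number) {\n    return a + b\n}\n",
}
```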

/**
@@ -1054,6 +1060,7 @@ export async function renderPromptNode(
const outputProcessors: PromptOutputProcessorHandler[] = []
const chatParticipants: ChatParticipant[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

await visitNode(node, {
error: (n) => {
@@ -1065,7 +1072,17 @@
},
def: async (n) => {
const value = n.resolved
if (value !== undefined) appendUser(renderDefNode(n))
if (value !== undefined) {
appendUser(renderDefNode(n))
if (n.prediction) {
if (prediction) n.error = "duplicate prediction"
else
prediction = {
type: "content",
content: extractRange(value.content, n),
}
}
}
},
assistant: async (n) => {
const value = await n.resolved
@@ -1173,6 +1190,7 @@ ${trimNewlines(schemaText)}
errors,
messages,
fileOutputs,
prediction,
})
return res
}
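Since the chat completion request accepts only one prediction, the `def` visitor above records the first predicted node and flags any later one with a `"duplicate prediction"` error rather than overwriting it. A sketch of script code that would trip the check (file indices are hypothetical):
```ts
// Only the first def marked as a prediction is kept; the second node errors.
def("FILE", env.files[0], { prediction: true })
def("OTHER", env.files[1], { prediction: true }) // -> "duplicate prediction"
```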
2 changes: 2 additions & 0 deletions packages/core/src/promptrunner.ts
Expand Up @@ -138,6 +138,7 @@ export async function runTemplate(
outputProcessors,
chatParticipants,
fileOutputs,
prediction,
status,
statusText,
temperature,
@@ -221,6 +222,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
5 changes: 5 additions & 0 deletions packages/core/src/runpromptcontext.ts
Expand Up @@ -20,6 +20,7 @@ import {
createSystemNode,
finalizeMessages,
PromptImage,
PromptPrediction,
} from "./promptdom"
import { MarkdownTrace } from "./trace"
import { GenerationOptions } from "./generation"
@@ -626,6 +627,7 @@ export function createChatGenerationContext(
const fileMerges: FileMergeHandler[] = []
const outputProcessors: PromptOutputProcessorHandler[] = []
const fileOutputs: FileOutput[] = []
let prediction: PromptPrediction

// expand template
const { provider } = parseModelIdentifier(genOptions.model)
@@ -644,6 +646,7 @@
outputProcessors: ops,
fileOutputs: fos,
images: imgs,
prediction: pred,
} = await renderPromptNode(genOptions.model, node, {
flexTokens: genOptions.flexTokens,
trace: runTrace,
@@ -657,6 +660,7 @@
outputProcessors.push(...ops)
fileOutputs.push(...fos)
images.push(...imgs)
prediction = pred

if (errors?.length) {
logError(errors.map((err) => errorMessage(err)).join("\n"))
@@ -773,6 +777,7 @@
fileOutputs,
outputProcessors,
fileMerges,
prediction,
completer,
chatParticipants,
genOptions
6 changes: 6 additions & 0 deletions packages/core/src/types/prompt_template.d.ts
Expand Up @@ -930,6 +930,12 @@ interface DefOptions
* By default, throws an error if the value in def is empty.
*/
ignoreEmpty?: boolean

/**
* The content of the def is a predicted output.
* @see https://platform.openai.com/docs/guides/predicted-outputs
*/
prediction?: boolean
}

/**
19 changes: 18 additions & 1 deletion packages/core/src/usage.ts
Expand Up @@ -129,6 +129,8 @@ export class GenerationStats {
completion_tokens_details: {
audio_tokens: 0,
reasoning_tokens: 0,
accepted_prediction_tokens: 0,
rejected_prediction_tokens: 0,
},
prompt_tokens_details: {
audio_tokens: 0,
@@ -169,6 +171,12 @@
res.completion_tokens += childUsage.completion_tokens
res.prompt_tokens += childUsage.prompt_tokens
res.total_tokens += childUsage.total_tokens
res.completion_tokens_details.accepted_prediction_tokens +=
childUsage.completion_tokens_details
.accepted_prediction_tokens ?? 0
res.completion_tokens_details.rejected_prediction_tokens +=
childUsage.completion_tokens_details
.rejected_prediction_tokens ?? 0
res.completion_tokens_details.audio_tokens +=
childUsage.completion_tokens_details.audio_tokens
res.completion_tokens_details.reasoning_tokens +=
@@ -232,7 +240,16 @@
"reasoning tokens",
this.usage.completion_tokens_details.reasoning_tokens
)

if (this.usage.completion_tokens_details?.accepted_prediction_tokens)
trace.itemValue(
"accepted prediction tokens",
this.usage.completion_tokens_details.accepted_prediction_tokens
)
if (this.usage.completion_tokens_details?.rejected_prediction_tokens)
trace.itemValue(
"rejected prediction tokens",
this.usage.completion_tokens_details.rejected_prediction_tokens
)
if (this.chatTurns.length > 1) {
trace.startDetails("chat turns")
try {
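The two new counters follow OpenAI's usage schema: `accepted_prediction_tokens` counts predicted tokens the model kept, while `rejected_prediction_tokens` counts predicted tokens it replaced; rejected tokens are still billed as completion tokens. A sketch of reading them from a raw response, assuming the official `openai` Node SDK types:
```ts
import type { ChatCompletion } from "openai/resources/chat/completions"

// Sketch: log prediction counters from a chat completion response.
// Field availability varies by model and provider.
function logPredictionUsage(response: ChatCompletion) {
    const details = response.usage?.completion_tokens_details
    if (!details) return
    console.log(`accepted: ${details.accepted_prediction_tokens ?? 0}`)
    console.log(`rejected: ${details.rejected_prediction_tokens ?? 0}`)
}
```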
10 changes: 10 additions & 0 deletions packages/sample/genaisrc/prediction.genai.mjs
@@ -0,0 +1,10 @@
script({
files: "src/greeter.ts",
tests: {
files: "src/greeter.ts",
},
})

def("FILE", env.files[0], { prediction: true })

$`Add comments to every line of code. Respond only with code.`
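To try the sample, one would typically run it through the GenAIScript CLI (for example `npx genaiscript run prediction`, assuming the standard CLI entry point) against a model that supports predicted outputs, such as gpt-4o or gpt-4o-mini.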
