
Commit

✨ feat: add Jinja chat message rendering and Prompty integration
pelikhan committed Oct 18, 2024
1 parent ffa1176 commit aaef184
Showing 3 changed files with 65 additions and 14 deletions.
19 changes: 19 additions & 0 deletions packages/core/src/jinja.ts
@@ -1,5 +1,6 @@
// Import the Template class from the @huggingface/jinja package
import { Template } from "@huggingface/jinja"
import { ChatCompletionMessageParam } from "./chattypes"

/**
* Renders a string template using the Jinja templating engine.
@@ -25,3 +26,21 @@ export function jinjaRender(
// Return the rendered string
return res
}

export function jinjaRenderChatMessage(
msg: ChatCompletionMessageParam,
args: Record<string, any>
) {
const { content } = msg
let template: string[] = []
if (typeof content === "string") template.push(content)
else
for (const part of content) {
if (part.type === "text") template.push(part.text)
else if (part.type === "image_url")
template.push(`![](${part.image_url})`)
else if (part.type === "refusal")
template.push(`refusal: ${part.refusal}`)
}
return jinjaRender(template.join("\n"), args)
}
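
A minimal usage sketch of the new helper (the message literal and argument values are illustrative, not taken from this commit):

import { jinjaRenderChatMessage } from "./jinja"
import type { ChatCompletionMessageParam } from "./chattypes"

// A user message whose content is a Jinja template string.
const msg: ChatCompletionMessageParam = {
    role: "user",
    content: "What is the capital of {{ country }}?",
}

// Render the template with the supplied arguments.
const rendered = jinjaRenderChatMessage(msg, { country: "France" })
// rendered === "What is the capital of France?"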
49 changes: 35 additions & 14 deletions packages/core/src/promptdom.ts
@@ -10,6 +10,7 @@ import { YAMLStringify } from "./yaml"
import {
MARKDOWN_PROMPT_FENCE,
PROMPT_FENCE,
PROMPTY_REGEX,
TEMPLATE_ARG_DATA_SLICE_SAMPLE,
TEMPLATE_ARG_FILE_MAX_TOKENS,
} from "./constants"
@@ -27,6 +28,8 @@ import { resolveTokenEncoder } from "./encoders"
import { expandFiles } from "./fs"
import { interpolateVariables } from "./mustache"
import { createDiff } from "./diff"
import { promptyParse } from "./prompty"
import { jinjaRenderChatMessage } from "./jinja"

// Definition of the PromptNode interface which is an essential part of the code structure.
export interface PromptNode extends ContextExpansionOptions {
@@ -97,7 +100,6 @@ export interface PromptImportTemplate extends PromptNode {
files: string | string[] // Files to import
args?: Record<string, string | number | boolean> // Arguments for the template
options?: ImportTemplateOptions // Additional options
resolved?: Record<string, string> // Resolved content from files
}

// Interface representing a prompt image.
@@ -244,9 +246,7 @@ function renderDefNode(def: PromptDefNode): string {
dfence += "`"
}
const diffFormat =
body.length > 500
? " preferred_output_format=CHANGELOG "
: ""
body.length > 500 ? " preferred_output_format=CHANGELOG " : ""
const res =
(name ? name + ":\n" : "") +
dfence +
@@ -667,23 +667,28 @@ async function resolvePromptNode(
},
importTemplate: async (n) => {
try {
n.resolved = {}
const { files, args, options } = n
n.children = []
n.preview = ""
const fs = await (
await expandFiles(arrayify(files))
).map((filename) => <WorkspaceFile>{ filename })
if (fs.length === 0)
throw new Error(`No files found for import: ${files}`)
for (const f of fs) {
await resolveFileContent(f, options)
n.resolved[f.filename] = await interpolateVariables(
f.content,
args
)
if (PROMPTY_REGEX.test(f.filename))
await resolveImportPrompty(n, f, args, options)
else {
const rendered = await interpolateVariables(
f.content,
args
)
n.children.push(createTextNode(rendered))
n.preview += rendered + "\n"
}
}
n.preview = inspect(n.resolved, { maxDepth: 3 })
n.tokens = estimateTokens(
Object.values(n.resolved).join("\n"),
encoder
)
n.tokens = estimateTokens(n.preview, encoder)
} catch (e) {
n.error = e
}
@@ -701,6 +706,22 @@
return { errors: err }
}

async function resolveImportPrompty(
n: PromptImportTemplate,
f: WorkspaceFile,
args: Record<string, string | number | boolean>,
options: ImportTemplateOptions
) {
const { meta, messages } = promptyParse(f.content)
for (const message of messages) {
const txt = jinjaRenderChatMessage(message, args)
if (message.role === "assistant")
n.children.push(createAssistantNode(txt))
else n.children.push(createTextNode(txt))
n.preview += txt + "\n"
}
}
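
A rough end-to-end sketch of what resolveImportPrompty does with a .prompty file, assuming a minimal source of YAML frontmatter followed by role-prefixed Jinja messages (the inline content and the topic argument are hypothetical):

import { promptyParse } from "./prompty"
import { jinjaRenderChatMessage } from "./jinja"

// Hypothetical .prompty content: frontmatter, then system and user messages.
const source = [
    "---",
    "name: example",
    "---",
    "system:",
    "You are a terse assistant.",
    "",
    "user:",
    "Summarize {{topic}} in one sentence.",
].join("\n")

const { messages } = promptyParse(source)
for (const message of messages) {
    // resolveImportPrompty appends each rendered message as an assistant or
    // text node; here the result is just printed.
    const txt = jinjaRenderChatMessage(message, { topic: "Jinja templates" })
    console.log(message.role, txt)
}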

// Function to handle truncation of prompt nodes based on token limits.
async function truncatePromptNode(
model: string,
11 changes: 11 additions & 0 deletions packages/sample/genaisrc/import-prompty.genai.mjs
@@ -0,0 +1,11 @@
script({
model: "small",
tests: {
keywords: "paris",
},
})

importTemplate("src/basic.prompty", {
question: "what is the capital of france?",
hint: "starts with p",
})
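
The referenced src/basic.prompty is not included in this commit. A hypothetical file compatible with the question and hint arguments above might look like this (an assumption about its content, shown only to make the sample self-contained):

---
name: basic
description: minimal question-answering prompt
---
system:
You are a helpful assistant. Answer briefly.

user:
{{question}} ({{hint}})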
