From aaef1845ed2c4d15f4b148885f1f7e26d1cbeaae Mon Sep 17 00:00:00 2001 From: Peli de Halleux Date: Fri, 18 Oct 2024 18:53:17 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20Jinja=20chat=20messag?= =?UTF-8?q?e=20rendering=20and=20Prompty=20integration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/core/src/jinja.ts | 19 +++++++ packages/core/src/promptdom.ts | 49 +++++++++++++------ .../sample/genaisrc/import-prompty.genai.mjs | 11 +++++ 3 files changed, 65 insertions(+), 14 deletions(-) create mode 100644 packages/sample/genaisrc/import-prompty.genai.mjs diff --git a/packages/core/src/jinja.ts b/packages/core/src/jinja.ts index 902acd25fb..5600e1ae6a 100644 --- a/packages/core/src/jinja.ts +++ b/packages/core/src/jinja.ts @@ -1,5 +1,6 @@ // Import the Template class from the @huggingface/jinja package import { Template } from "@huggingface/jinja" +import { ChatCompletionMessageParam } from "./chattypes" /** * Renders a string template using the Jinja templating engine. @@ -25,3 +26,21 @@ export function jinjaRender( // Return the rendered string return res } + +export function jinjaRenderChatMessage( + msg: ChatCompletionMessageParam, + args: Record +) { + const { content } = msg + let template: string[] = [] + if (typeof content === "string") template.push(content) + else + for (const part of content) { + if (part.type === "text") template.push(part.text) + else if (part.type === "image_url") + template.push(`![](${part.image_url})`) + else if (part.type === "refusal") + template.push(`refusal: ${part.refusal}`) + } + return jinjaRender(template.join("\n"), args) +} diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts index 90865ce296..61a79f6fb3 100644 --- a/packages/core/src/promptdom.ts +++ b/packages/core/src/promptdom.ts @@ -10,6 +10,7 @@ import { YAMLStringify } from "./yaml" import { MARKDOWN_PROMPT_FENCE, PROMPT_FENCE, + PROMPTY_REGEX, TEMPLATE_ARG_DATA_SLICE_SAMPLE, TEMPLATE_ARG_FILE_MAX_TOKENS, } from "./constants" @@ -27,6 +28,8 @@ import { resolveTokenEncoder } from "./encoders" import { expandFiles } from "./fs" import { interpolateVariables } from "./mustache" import { createDiff } from "./diff" +import { promptyParse } from "./prompty" +import { jinjaRenderChatMessage } from "./jinja" // Definition of the PromptNode interface which is an essential part of the code structure. export interface PromptNode extends ContextExpansionOptions { @@ -97,7 +100,6 @@ export interface PromptImportTemplate extends PromptNode { files: string | string[] // Files to import args?: Record // Arguments for the template options?: ImportTemplateOptions // Additional options - resolved?: Record // Resolved content from files } // Interface representing a prompt image. @@ -244,9 +246,7 @@ function renderDefNode(def: PromptDefNode): string { dfence += "`" } const diffFormat = - body.length > 500 - ? " preferred_output_format=CHANGELOG " - : "" + body.length > 500 ? " preferred_output_format=CHANGELOG " : "" const res = (name ? name + ":\n" : "") + dfence + @@ -667,23 +667,28 @@ async function resolvePromptNode( }, importTemplate: async (n) => { try { - n.resolved = {} const { files, args, options } = n + n.children = [] + n.preview = "" const fs = await ( await expandFiles(arrayify(files)) ).map((filename) => { filename }) + if (fs.length === 0) + throw new Error(`No files found for import: ${files}`) for (const f of fs) { await resolveFileContent(f, options) - n.resolved[f.filename] = await interpolateVariables( - f.content, - args - ) + if (PROMPTY_REGEX.test(f.filename)) + await resolveImportPrompty(n, f, args, options) + else { + const rendered = await interpolateVariables( + f.content, + args + ) + n.children.push(createTextNode(rendered)) + n.preview += rendered + "\n" + } } - n.preview = inspect(n.resolved, { maxDepth: 3 }) - n.tokens = estimateTokens( - Object.values(n.resolved).join("\n"), - encoder - ) + n.tokens = estimateTokens(n.preview, encoder) } catch (e) { n.error = e } @@ -701,6 +706,22 @@ async function resolvePromptNode( return { errors: err } } +async function resolveImportPrompty( + n: PromptImportTemplate, + f: WorkspaceFile, + args: Record, + options: ImportTemplateOptions +) { + const { meta, messages } = promptyParse(f.content) + for (const message of messages) { + const txt = jinjaRenderChatMessage(message, args) + if (message.role === "assistant") + n.children.push(createAssistantNode(txt)) + else n.children.push(createTextNode(txt)) + n.preview += txt + "\n" + } +} + // Function to handle truncation of prompt nodes based on token limits. async function truncatePromptNode( model: string, diff --git a/packages/sample/genaisrc/import-prompty.genai.mjs b/packages/sample/genaisrc/import-prompty.genai.mjs new file mode 100644 index 0000000000..3497fb22e8 --- /dev/null +++ b/packages/sample/genaisrc/import-prompty.genai.mjs @@ -0,0 +1,11 @@ +script({ + model: "small", + tests: { + keywords: "paris", + }, +}) + +importTemplate("src/basic.prompty", { + question: "what is the capital of france?", + hint: "starts with p", +})