Skip to content

Commit

Permalink
Allow bot to analyze / generate image (#2)
Browse files Browse the repository at this point in the history
Also:

* Fix incorrect message order if a run generates more than 1 message
* Include instructions, model and tools for each Assistant run
* Convert markdown to safe HTML for Telegram
* Use `zod` to ensure params schema
* Mask file URL for security
  • Loading branch information
daohoangson authored Nov 13, 2023
1 parent 97183d4 commit 6c84021
Show file tree
Hide file tree
Showing 15 changed files with 611 additions and 65 deletions.
6 changes: 6 additions & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
},
"devDependencies": {
"@types/node": "^20.9.0",
"@types/sanitize-html": "^2.9.4",
"sst": "^2.36.1",
"vitest": "^0.34.6"
},
"dependencies": {
"@anatine/zod-openapi": "^2.2.1",
"@aws-sdk/client-dynamodb": "^3.449.0",
"@aws-sdk/lib-dynamodb": "^3.449.0",
"marked": "^10.0.0",
"openai": "^4.17.4",
"openapi3-ts": "^4.1.2",
"sanitize-html": "^2.11.0",
"serialize-error": "^11.0.3",
"telegraf": "^4.15.0",
"zod": "^3.22.4"
}
Expand Down
23 changes: 21 additions & 2 deletions packages/core/src/3rdparty/openai/assistant_message.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import { Reply } from "../../abstracts/chat";
import { assistantId, threads } from "./openai";
import { analyzeImage, generateImage } from "./tools/image";
import { newThread } from "./tools/ops";

export async function* assistantGetNewMessages(
threadId: string,
runId: string
): AsyncGenerator<Reply> {
const list = await threads.messages.list(threadId);
for (const threadMessage of list.data) {
const threadMessages = [...list.data].reverse();
for (const threadMessage of threadMessages) {
if (threadMessage.run_id === runId) {
console.log(JSON.stringify(threadMessage, null, 2));
for (const messageContent of threadMessage.content) {
if (messageContent.type === "text") {
const markdown = messageContent.text.value;
yield { markdown };
yield { type: "markdown", markdown };
}
}
}
Expand All @@ -24,8 +27,24 @@ export async function assistantSendMessage(
content: string
): Promise<{ runId: string }> {
await threads.messages.create(threadId, { content, role: "user" });

const instructions = `Your name is Bubby.
You are a personal assistant bot. Ensure efficient and user-friendly interaction, focusing on simplicity and clarity in communication.
You provide concise and direct answers. Maintain a straightforward and easy-going conversation tone. Keep responses brief, typically in short sentences.
You can only reply to text or photo messages.`;
const run = await threads.runs.create(threadId, {
assistant_id: assistantId,
instructions,
model: "gpt-4-1106-preview",
tools: [
{ type: "code_interpreter" },
{ type: "retrieval" },
// vision
analyzeImage,
generateImage,
// ops
newThread,
],
});

return { runId: run.id };
Expand Down
46 changes: 46 additions & 0 deletions packages/core/src/3rdparty/openai/tools/image.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { generateSchema } from "@anatine/zod-openapi";
import { FunctionParameters } from "openai/resources";
import { RunCreateParams } from "openai/resources/beta/threads/runs/runs";
import { z } from "zod";

// Arguments accepted by the `analyze_image` tool call.
export const analyzeImageParameters = z.object({
  image_url: z.string({
    description: "The image URL.",
  }),
  prompt: z.string({
    description: "The prompt to ask the Vision AI model to analyze.",
  }),
  temperature: z
    .number({
      description:
        "What sampling temperature to use, between 0 and 2. " +
        "Higher values like 0.8 will make the output more random, " +
        "while lower values like 0.2 will make it more focused and deterministic.",
    })
    // Enforce the documented 0..2 range so an out-of-range tool call fails
    // fast at parse time instead of being forwarded to the OpenAI API.
    .min(0)
    .max(2),
});

/**
 * Tool definition for the `analyze_image` Assistant function.
 * The parameters JSON schema is derived from `analyzeImageParameters`.
 */
export const analyzeImage: RunCreateParams.AssistantToolsFunction = {
  type: "function",
  function: {
    name: "analyze_image",
    description: "Analyze an image.",
    parameters: generateSchema(analyzeImageParameters) as FunctionParameters,
  },
};

// Arguments accepted by the `generate_image` tool call.
export const generateImageParameters = z.object({
  prompt: z.string({
    description: "The prompt to ask the Vision AI model to generate.",
  }),
  // The three sizes used with the dall-e-3 model (see vision_preview.ts).
  size: z.enum(["1024x1024", "1792x1024", "1024x1792"], {
    description: "The size of the generated images.",
  }),
});

/**
 * Tool definition for the `generate_image` Assistant function.
 * The parameters JSON schema is derived from `generateImageParameters`.
 */
export const generateImage: RunCreateParams.AssistantToolsFunction = {
  type: "function",
  function: {
    name: "generate_image",
    description: "Generate an image.",
    parameters: generateSchema(generateImageParameters) as FunctionParameters,
  },
};
17 changes: 17 additions & 0 deletions packages/core/src/3rdparty/openai/tools/ops.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { generateSchema } from "@anatine/zod-openapi";
import { FunctionParameters } from "openai/resources";
import { RunCreateParams } from "openai/resources/beta/threads/runs/runs";
import { z } from "zod";

/**
 * Tool definition for the `new_thread` Assistant function.
 * It takes no arguments, hence the empty object schema.
 */
export const newThread: RunCreateParams.AssistantToolsFunction = {
  type: "function",
  function: {
    name: "new_thread",
    description: "Discard the recent messages and start a new thread.",
    parameters: { type: "object", properties: {}, required: [] },
  },
};
71 changes: 71 additions & 0 deletions packages/core/src/3rdparty/openai/vision_preview.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import {
ChatCompletionCreateParamsNonStreaming,
ImageGenerateParams,
} from "openai/resources";

import { openai } from "./openai";
import { AssistantThreadInput } from "./assistant_thread";

/**
 * Ask the GPT-4 vision model to analyze an image.
 *
 * The incoming `image_url` may be a masked URL; it is unmasked via the chat
 * abstraction before being sent to OpenAI. To keep the unmasked (potentially
 * pre-signed) URL out of the logs, only the masked variant is logged.
 *
 * @returns the model's text answer, or an empty string if none was produced
 */
export async function visionAnalyzeImage(
  input: AssistantThreadInput,
  {
    image_url,
    prompt,
    temperature,
  }: {
    image_url: string;
    prompt: string;
    temperature: number;
  }
): Promise<string> {
  const unmaskedUrl = await input.chat.unmaskFileUrl(image_url);
  const url = typeof unmaskedUrl === "string" ? unmaskedUrl : image_url;

  // Build the request body for a given URL so we can log the masked variant
  // while sending the unmasked one to the API.
  const makeBody = (u: string): ChatCompletionCreateParamsNonStreaming => ({
    messages: [
      {
        content: [
          { type: "text", text: prompt },
          { type: "image_url", image_url: { url: u } },
        ],
        role: "user",
      },
    ],
    model: "gpt-4-vision-preview",
    max_tokens: 1024,
    temperature,
  });
  console.log(JSON.stringify(makeBody(image_url), null, 2));
  const completion = await openai.chat.completions.create(makeBody(url));
  console.log(JSON.stringify(completion, null, 2));
  return completion.choices[0].message.content ?? "";
}

/**
 * Generate an image with DALL-E 3.
 *
 * @returns the image URL plus a caption (the revised prompt when available),
 *   or `undefined` if the API returned no usable image URL
 */
export async function visionGenerateImage({
  prompt,
  size,
}: {
  prompt: string;
  size: ImageGenerateParams["size"];
}): Promise<{ caption: string; url: string } | undefined> {
  const body: ImageGenerateParams = {
    prompt,
    model: "dall-e-3",
    response_format: "url",
    size,
  };
  console.log(JSON.stringify(body, null, 2));
  const completion = await openai.images.generate(body);
  console.log(JSON.stringify(completion, null, 2));
  const image = completion.data[0];
  // Guard explicitly instead of asserting `image.url!`: `url` is only set
  // when `response_format` is "url" and the call actually succeeded.
  if (typeof image === "object" && typeof image.url === "string") {
    return {
      caption: image.revised_prompt ?? prompt,
      url: image.url,
    };
  }
  return undefined;
}
97 changes: 91 additions & 6 deletions packages/core/src/3rdparty/openai/wait_for_assistant_run.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { APIError } from "openai";
import type {
RequiredActionFunctionToolCall,
Run,
RunSubmitToolOutputsParams,
} from "openai/resources/beta/threads/runs/runs";
import { serializeError } from "serialize-error";
import { z } from "zod";

import {
AssistantThreadInput,
Expand All @@ -10,6 +14,14 @@ import {
import { threads } from "./openai";
import { Reply } from "../../abstracts/chat";
import { AssistantError } from "../../abstracts/assistant";
import {
analyzeImage,
analyzeImageParameters,
generateImage,
generateImageParameters,
} from "./tools/image";
import { newThread } from "./tools/ops";
import { visionAnalyzeImage, visionGenerateImage } from "./vision_preview";

export type AsisstantWaitForRunInput = AssistantThreadInput & {
threadId: string;
Expand Down Expand Up @@ -55,15 +67,88 @@ async function* takeRequiredActions(
const sto = requiredAction.submit_tool_outputs;
const tool_outputs: RunSubmitToolOutputsParams.ToolOutput[] = [];
for (const toolCall of sto.tool_calls) {
if (toolCall.function.name === "new_thread") {
const newThreadId = await assistantThreadIdInsert(input);
tool_outputs.push({
tool_call_id: toolCall.id,
output: newThreadId,
});
switch (toolCall.function.name) {
// image
case analyzeImage.function.name:
const analyzedImage = yield* takeRequiredAction(
toolCall,
analyzeImageParameters,
async function* (params) {
yield { type: "plaintext", plaintext: "🚨 Analyzing image..." };
return await visionAnalyzeImage(input, params);
}
);
tool_outputs.push(analyzedImage);
break;
case generateImage.function.name:
const generatedImage = yield* takeRequiredAction(
toolCall,
generateImageParameters,
async function* (params) {
yield {
type: "plaintext",
plaintext: "🚨 Generating image...",
};
const image = await visionGenerateImage(params);
yield { type: "photo", ...image } as Reply;
return {
success: true,
description: `Image has been generated and sent to user successfully.`,
};
}
);
tool_outputs.push(generatedImage);
break;
// ops
case newThread.function.name:
const newThreadId = yield* takeRequiredAction(
toolCall,
z.object({}),
async function* () {
const inserted = await assistantThreadIdInsert(input);
yield { type: "plaintext", plaintext: "🚨 New thread" };
return inserted;
}
);
tool_outputs.push(newThreadId);
break;
}
}
await threads.runs.submitToolOutputs(threadId, runId, { tool_outputs });
break;
}
}

/**
 * Validate a tool call's arguments against a zod schema, then run the
 * callback, converting both validation failures and runtime errors into a
 * tool output instead of letting them crash the run.
 *
 * @param toolCall the function call requested by the Assistant run
 * @param parameters zod schema used to parse `toolCall.function.arguments`
 * @param callback generator doing the actual work; replies it yields are
 *   forwarded to the caller, and its return value becomes the tool output
 * @returns the tool output to submit back to the run
 */
async function* takeRequiredAction<T>(
  toolCall: RequiredActionFunctionToolCall,
  parameters: z.ZodType<T>,
  callback: (parameters: T) => AsyncGenerator<Reply, unknown>
): AsyncGenerator<Reply, RunSubmitToolOutputsParams.ToolOutput> {
  let params: T;
  try {
    const json = JSON.parse(toolCall.function.arguments);
    params = parameters.parse(json);
  } catch (paramsError) {
    // Bad arguments from the model: report them back as the tool output so
    // the run can recover instead of hanging on an unanswered tool call.
    const obj = { paramsError: serializeError(paramsError) };
    console.warn(obj);
    return { tool_call_id: toolCall.id, output: JSON.stringify(obj) };
  }

  try {
    const success = yield* callback(params);
    // A bare boolean is wrapped as `{ success }`; richer results are
    // serialized as-is.
    const output =
      typeof success === "boolean"
        ? JSON.stringify({ success })
        : JSON.stringify(success);
    return { tool_call_id: toolCall.id, output };
  } catch (error) {
    // Prefer the structured `error` payload of OpenAI APIErrors.
    const obj =
      error instanceof APIError
        ? { apiError: serializeError(error.error) }
        : { error: serializeError(error) };
    console.warn(obj);
    return { tool_call_id: toolCall.id, output: JSON.stringify(obj) };
  }
}
39 changes: 39 additions & 0 deletions packages/core/src/3rdparty/telegram/formatting.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { marked } from "marked";
import sanitizeHtml from "sanitize-html";

/**
 * Convert markdown to HTML that is safe to send to Telegram.
 *
 * Telegram only accepts the small set of tags documented at
 * https://core.telegram.org/bots/api#formatting-options, so everything else
 * is stripped by sanitize-html.
 */
export function convertMarkdownToSafeHtml(markdown: string): string {
  const html = marked.parse(markdown);
  return sanitizeHtml(html, {
    // NOTE: `allowedAttributes` REPLACES the sanitize-html defaults, so every
    // attribute Telegram needs must be listed here explicitly — otherwise
    // links lose their href and spoilers/custom emoji stop working.
    allowedAttributes: {
      // <a href="http://www.example.com/">, <a href="tg://user?id=123456789">
      a: ["href"],
      // <code class="language-python"> for pre-formatted code blocks
      code: ["class"],
      // <span class="tg-spoiler">spoiler</span>
      span: ["class"],
      // <tg-emoji emoji-id="5368324170671202286">👍</tg-emoji>
      "tg-emoji": ["emoji-id"],
    },
    allowedTags: [
      // https://core.telegram.org/bots/api#formatting-options
      // <b>bold</b>, <strong>bold</strong>
      "b",
      "strong",
      // <i>italic</i>, <em>italic</em>
      "i",
      "em",
      // <u>underline</u>, <ins>underline</ins>
      "u",
      "ins",
      // <s>strikethrough</s>, <strike>strikethrough</strike>, <del>strikethrough</del>
      "s",
      "strike",
      "del",
      // <span class="tg-spoiler">spoiler</span>, <tg-spoiler>spoiler</tg-spoiler>
      "span",
      "tg-spoiler",
      // <a href="http://www.example.com/">inline URL</a>
      // <a href="tg://user?id=123456789">inline mention of a user</a>
      "a",
      // <tg-emoji emoji-id="5368324170671202286">👍</tg-emoji>
      "tg-emoji",
      // <code>inline fixed-width code</code>
      "code",
      // <pre>pre-formatted fixed-width code block</pre>
      "pre",
    ],
  });
}
Loading

0 comments on commit 6c84021

Please sign in to comment.