diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index 071f4742..1f9fecdb 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -1,5 +1,6 @@
 import { Ollama } from "ollama";
 import OpenAI from "openai";
+import { encoding_for_model, TiktokenModel } from "tiktoken";
 
 import serverConfig from "@hoarder/shared/config";
 import logger from "@hoarder/shared/logger";
@@ -9,6 +10,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceClient {
   inferFromText(prompt: string): Promise<InferenceResponse>;
   inferFromImage(
@@ -16,6 +21,7 @@
     contentType: string,
     image: string,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -87,6 +93,35 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  truncateTextTokens(text: string, maxTokens: number, model: string) {
+    // Note: encoding_for_model throws if the model name is unknown to tiktoken.
+    const encoding = encoding_for_model(model as TiktokenModel);
+    try {
+      const encoded = encoding.encode(text);
+      if (encoded.length <= maxTokens) {
+        return text;
+      }
+      return new TextDecoder().decode(
+        encoding.decode(encoded.slice(0, maxTokens)),
+      );
+    } finally {
+      // The encoding is backed by WASM memory and must be freed explicitly.
+      encoding.free();
+    }
+  }
+
+  async generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: [this.truncateTextTokens(prompt, 2000, model)],
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -134,6 +169,17 @@
     return { response, totalTokens };
   }
 
+  async runEmbeddingModel(model: string, prompt: string) {
+    const embedding = await this.ollama.embed({
+      model: model,
+      input: prompt,
+      // Truncate the input to fit the model's max token limit. In the future,
+      // we want to add a way to split the input into multiple parts instead.
+      truncate: true,
+    });
+    return { response: embedding };
+  }
+
   async inferFromText(prompt: string): Promise<InferenceResponse> {
     return await this.runModel(serverConfig.inference.textModel, prompt);
   }
@@ -149,4 +195,12 @@
       image,
     );
   }
+
+  async generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse> {
+    const embedResponse = await this.runEmbeddingModel(
+      serverConfig.embedding.textModel,
+      prompt,
+    );
+    return { embeddings: embedResponse.response.embeddings };
+  }
 }
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 8bd2cf4a..f28b4d67 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -23,6 +23,8 @@ import type { InferenceClient } from "./inference";
 import { InferenceClientFactory } from "./inference";
 import { readPDFText, truncateContent } from "./utils";
 
+type BookmarkType = "link" | "text" | "image" | "pdf" | "unsupported";
+
 const openAIResponseSchema = z.object({
   tags: z.array(z.string()),
 });
@@ -60,7 +62,7 @@ async function attemptMarkTaggingStatus(
 
 export class OpenAiWorker {
   static build() {
-    logger.info("Starting inference worker ...");
+    logger.info("Starting AI worker ...");
     const worker = new Runner<ZOpenAIRequest>(
       OpenAIQueue,
       {
@@ -114,17 +116,10 @@ Aim for 3-5 tags. If there are no good tags, leave the array empty.
 function buildPrompt(
   bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
 ) {
-  if (bookmark.link) {
-    if (!bookmark.link.description && !bookmark.link.content) {
-      throw new Error(
-        `No content found for link "${bookmark.id}". Skipping ...`,
-      );
-    }
+  const content = extractTextFromBookmark(bookmark);
+  const bType = bookmarkType(bookmark);
 
-    let content = bookmark.link.content;
-    if (content) {
-      content = truncateContent(content);
-    }
+  if (bType === "link") {
     return `
 ${TEXT_PROMPT_BASE}
 URL: ${bookmark.link.url}
@@ -134,7 +129,7 @@ Content: ${content ?? ""}
 ${TEXT_PROMPT_INSTRUCTIONS}`;
   }
 
-  if (bookmark.text) {
+  if (bType === "text") {
     const content = truncateContent(bookmark.text.text ?? "");
     // TODO: Ensure that the content doesn't exceed the context length of openai
     return `
@@ -224,25 +219,40 @@ async function inferTagsFromText(
   return await inferenceClient.inferFromText(buildPrompt(bookmark));
 }
 
+function bookmarkType(
+  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+): BookmarkType {
+  if (bookmark.link) {
+    return "link";
+  } else if (bookmark.text) {
+    return "text";
+  }
+  if (!bookmark.asset) {
+    return "unsupported";
+  }
+  switch (bookmark.asset.assetType) {
+    case "image":
+      return "image";
+    case "pdf":
+      return "pdf";
+    default:
+      return "unsupported";
+  }
+}
+
 async function inferTags(
   jobId: string,
   bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
   inferenceClient: InferenceClient,
 ) {
   let response;
-  if (bookmark.link || bookmark.text) {
+  const bType = bookmarkType(bookmark);
+  if (bType === "text" || bType === "link") {
     response = await inferTagsFromText(bookmark, inferenceClient);
-  } else if (bookmark.asset) {
-    switch (bookmark.asset.assetType) {
-      case "image":
-        response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
-        break;
-      case "pdf":
-        response = await inferTagsFromPDF(jobId, bookmark, inferenceClient);
-        break;
-      default:
-        throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
-    }
+  } else if (bType === "image") {
+    response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
+  } else if (bType === "pdf") {
+    response = await inferTagsFromPDF(jobId, bookmark, inferenceClient);
   } else {
     throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
   }
@@ -362,6 +372,93 @@ async function connectTags(
   });
 }
 
+// TODO: Make this function accept max tokens as an argument.
+// TODO: Refactor the truncation logic so that the max token count is tied to
+// the model being used instead of being applied once per bookmark.
+function extractTextFromBookmark(
+  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+): string {
+  if (bookmark.link) {
+    if (!bookmark.link.description && !bookmark.link.content) {
+      throw new Error(
+        `No content found for link "${bookmark.id}". Skipping ...`,
+      );
+    }
+
+    let content = bookmark.link.content;
+    if (content) {
+      content = truncateContent(content);
+    }
+    return content ?? "";
+  }
+
+  if (!bookmark.text) {
+    logger.error(
+      `[extractTextFromBookmark] Unsupported bookmark type, skipping ...`,
+    );
+    return "";
+  }
+  const content = truncateContent(bookmark.text.text ?? "");
+  if (!content) {
+    throw new Error(
+      `[inference] [UNEXPECTED] TruncateContent returned empty content for bookmark "${bookmark.id}". Skipping ...`,
+    );
+  }
+  return content;
+}
+
+async function extractTextFromPDFBookmark(
+  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+  jobId: string,
+) {
+  const { asset } = await readAsset({
+    userId: bookmark.userId,
+    assetId: bookmark.asset.assetId,
+  });
+  if (!asset) {
+    throw new Error(
+      `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
+    );
+  }
+  const pdfParse = await readPDFText(asset);
+  if (!pdfParse?.text) {
+    throw new Error(
+      `[inference][${jobId}] PDF text is empty. Please make sure that the PDF includes text and not just images.`,
+    );
+  }
+  const content = truncateContent(pdfParse.text);
+  if (!content) {
+    throw new Error(
+      `[inference][${jobId}] [UNEXPECTED] TruncateContent returned empty content for PDF "${bookmark.id}". Skipping ...`,
+    );
+  }
+  return content;
+}
+
+async function embedBookmark(
+  jobId: string,
+  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+  inferenceClient: InferenceClient,
+) {
+  logger.info(`[embedding][${jobId}] Embedding bookmark ${bookmark.id}`);
+  const bType = bookmarkType(bookmark);
+  logger.info(`[embedding][${jobId}] Bookmark type: ${bType}`);
+  if (bType === "text") {
+    const embedding = await inferenceClient.generateEmbeddingFromText(
+      extractTextFromBookmark(bookmark),
+    );
+    logger.info(
+      `[embedding][${jobId}] Generated ${embedding.embeddings.length} embedding(s) successfully`,
+    );
+  } else if (bType === "pdf") {
+    const content = await extractTextFromPDFBookmark(bookmark, jobId);
+    const embedding = await inferenceClient.generateEmbeddingFromText(content);
+    logger.info(
+      `[embedding][${jobId}] Generated ${embedding.embeddings.length} embedding(s) successfully`,
+    );
+  }
+}
+
 async function runOpenAI(job: DequeuedJob<ZOpenAIRequest>) {
   const jobId = job.id ?? "unknown";
 
@@ -398,4 +495,6 @@
 
   // Update the search index
   await triggerSearchReindex(bookmarkId);
+
+  await embedBookmark(jobId, bookmark, inferenceClient);
 }
diff --git a/apps/workers/package.json b/apps/workers/package.json
index bbd5b17d..0532ec66 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -34,6 +34,7 @@
     "puppeteer-extra": "^3.3.6",
     "puppeteer-extra-plugin-adblocker": "^2.13.6",
     "puppeteer-extra-plugin-stealth": "^2.11.2",
+    "tiktoken": "^1.0.16",
     "tsx": "^4.7.1",
     "typescript": "^5.3.3",
     "zod": "^3.22.4"
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index 21cdb1c8..107cb3a7 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -23,6 +23,7 @@ const allEnv = z.object({
   INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
   INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
   INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
+  EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
   CRAWLER_HEADLESS_BROWSER: stringBool("true"),
   BROWSER_WEB_URL: z.string().url().optional(),
   BROWSER_WEBSOCKET_URL: z.string().url().optional(),
@@ -73,6 +74,9 @@ const serverConfigSchema = allEnv.transform((val) => {
     imageModel: val.INFERENCE_IMAGE_MODEL,
     inferredTagLang: val.INFERENCE_LANG,
   },
+  embedding: {
+    textModel: val.EMBEDDING_TEXT_MODEL,
+  },
   crawler: {
     numWorkers: val.CRAWLER_NUM_WORKERS,
     headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
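
Reviewer note: below is a minimal sketch (not part of the diff) of how the new embedding path is meant to be exercised end to end. It assumes an environment where `InferenceClientFactory.build()` returns a non-null client, i.e. an OpenAI key or the Ollama settings are configured; the sample prompt and the `demo` wrapper are placeholders.

```ts
// Sketch: calling the new generateEmbeddingFromText API directly.
// The model comes from EMBEDDING_TEXT_MODEL (default "text-embedding-3-small").
// Oversized input is truncated: to 2000 tokens via tiktoken for OpenAI, and via
// `truncate: true` on the server side for Ollama.
import { InferenceClientFactory } from "./inference";

async function demo() {
  const client = InferenceClientFactory.build();
  if (!client) {
    throw new Error("No inference provider is configured");
  }

  const { embeddings } = await client.generateEmbeddingFromText(
    "Hoarder is a self-hostable bookmark-everything app.",
  );

  // One vector per input; the dimension depends on the embedding model.
  console.log(
    `Got ${embeddings.length} vector(s) of dimension ${embeddings[0]?.length ?? 0}`,
  );
}

void demo();
```

Note that `embedBookmark` currently only logs the generated embeddings; persisting them (e.g. for semantic search) is presumably left to a follow-up change.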