Generate Embeddings of Text/Links/PDF #403

Open · wants to merge 11 commits into base: rag
49 changes: 49 additions & 0 deletions apps/workers/inference.ts
@@ -1,5 +1,6 @@
import { Ollama } from "ollama";
import OpenAI from "openai";
import { encoding_for_model, TiktokenModel } from "tiktoken";

import serverConfig from "@hoarder/shared/config";
import logger from "@hoarder/shared/logger";
@@ -9,13 +10,18 @@ export interface InferenceResponse {
totalTokens: number | undefined;
}

export interface EmbeddingResponse {
embeddings: number[][];
}

export interface InferenceClient {
inferFromText(prompt: string): Promise<InferenceResponse>;
inferFromImage(
prompt: string,
contentType: string,
image: string,
): Promise<InferenceResponse>;
generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse>;
}

export class InferenceClientFactory {
@@ -87,6 +93,30 @@ class OpenAIInferenceClient implements InferenceClient {
}
return { response, totalTokens: chatCompletion.usage?.total_tokens };
}

truncateTextTokens(text: string, maxTokens: number, model: string) {
const encoding = encoding_for_model(model as TiktokenModel);
const encoded = encoding.encode(text);
if (encoded.length <= maxTokens) {
encoding.free();
return text;
}

const truncated = new TextDecoder().decode(
encoding.decode(encoded.slice(0, maxTokens)),
);
// Free the WASM-backed encoder to avoid leaking memory.
encoding.free();
return truncated;
}

async generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse> {
const model = serverConfig.embedding.textModel;
const embedResponse = await this.openAI.embeddings.create({
model: model,
input: [this.truncateTextTokens(prompt, 2000, model)],
});
const embedding2D: number[][] = embedResponse.data.map(
(embedding: OpenAI.Embedding) => embedding.embedding,
);
return { embeddings: embedding2D };
}
}

class OllamaInferenceClient implements InferenceClient {
@@ -134,6 +164,17 @@ class OllamaInferenceClient implements InferenceClient {
return { response, totalTokens };
}

async runEmbeddingModel(model: string, prompt: string) {
const embedding = await this.ollama.embed({
model: model,
input: prompt,
// Truncate the input to fit the model's max token limit.
// In the future we want to split the input into multiple parts instead.
truncate: true,
});
return { response: embedding };
}

async inferFromText(prompt: string): Promise<InferenceResponse> {
return await this.runModel(serverConfig.inference.textModel, prompt);
}
@@ -149,4 +190,12 @@
image,
);
}

async generateEmbeddingFromText(prompt: string): Promise<EmbeddingResponse> {
const embedResponse = await this.runEmbeddingModel(
serverConfig.embedding.textModel,
prompt,
);
return { embeddings: embedResponse.response.embeddings };
}
}
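
As a usage sketch, the new method can be called through the same factory entry point the tagging worker uses. This assumes InferenceClientFactory.build() picks the OpenAI or Ollama client from server config and returns null when neither is configured; it is an illustration, not part of the diff:

import { InferenceClientFactory } from "./inference";

async function embedSnippet(text: string): Promise<number[][] | null> {
  // Sketch only: build() is assumed to return null when no inference
  // backend (OpenAI key or Ollama address) is configured.
  const client = InferenceClientFactory.build();
  if (!client) {
    return null;
  }
  const { embeddings } = await client.generateEmbeddingFromText(text);
  // One vector per input; a single prompt yields a single row.
  return embeddings;
}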
146 changes: 122 additions & 24 deletions apps/workers/openaiWorker.ts
@@ -23,6 +23,8 @@ import type { InferenceClient } from "./inference";
import { InferenceClientFactory } from "./inference";
import { readPDFText, truncateContent } from "./utils";

type BookmarkType = "link" | "text" | "image" | "pdf" | "unsupported";

const openAIResponseSchema = z.object({
tags: z.array(z.string()),
});
@@ -60,7 +62,7 @@ async function attemptMarkTaggingStatus(

export class OpenAiWorker {
static build() {
logger.info("Starting inference worker ...");
logger.info("Starting AI worker ...");
const worker = new Runner<ZOpenAIRequest>(
OpenAIQueue,
{
@@ -114,17 +116,10 @@ Aim for 3-5 tags. If there are no good tags, leave the array empty.
function buildPrompt(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
) {
if (bookmark.link) {
if (!bookmark.link.description && !bookmark.link.content) {
throw new Error(
`No content found for link "${bookmark.id}". Skipping ...`,
);
}
const content = extractTextFromBookmark(bookmark);
const bType = bookmarkType(bookmark);

let content = bookmark.link.content;
if (content) {
content = truncateContent(content);
}
if (bType === "link") {
return `
${TEXT_PROMPT_BASE}
URL: ${bookmark.link.url}
@@ -134,7 +129,7 @@ Content: ${content ?? ""}
${TEXT_PROMPT_INSTRUCTIONS}`;
}

if (bookmark.text) {
if (bType === "text") {
const content = truncateContent(bookmark.text.text ?? "");
// TODO: Ensure that the content doesn't exceed the context length of openai
return `
@@ -224,25 +219,39 @@ async function inferTagsFromText(
return await inferenceClient.inferFromText(buildPrompt(bookmark));
}

function bookmarkType(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
): BookmarkType {
if (bookmark.link) {
return "link";
} else if (bookmark.text) {
return "text";
}
switch (bookmark.asset.assetType) {
case "image":
return "image";
case "pdf":
return "pdf";
default:
return "unsupported";
}
}

async function inferTags(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
) {
let response;
if (bookmark.link || bookmark.text) {
const bType = bookmarkType(bookmark);
if (bType === "text" || bType === "link") {
response = await inferTagsFromText(bookmark, inferenceClient);
} else if (bookmark.asset) {
switch (bookmark.asset.assetType) {
case "image":
response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
break;
case "pdf":
response = await inferTagsFromPDF(jobId, bookmark, inferenceClient);
break;
default:
throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
}
} else if (bType === "image") {
response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
} else if (bType === "pdf") {
response = await inferTagsFromPDF(jobId, bookmark, inferenceClient);
} else {
throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
}
@@ -362,6 +371,93 @@ async function connectTags(
});
}

// TODO: Make this function accept max tokens as an argument.
// TODO: The truncation logic needs to be refactored so that the max tokens
// are tied to the model being used rather than applied once per bookmark.
function extractTextFromBookmark(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
): string {
if (bookmark.link) {
if (!bookmark.link.description && !bookmark.link.content) {
throw new Error(
`No content found for link "${bookmark.id}". Skipping ...`,
);
}

let content = bookmark.link.content;
if (content) {
content = truncateContent(content);
}
return content ?? "";
}

if (!bookmark.text) {
logger.error(
`[extractTextFromBookmark] Unsupported bookmark type, skipping ...`,
);
return "";
}
const content = truncateContent(bookmark.text.text ?? "");
if (!content) {
throw new Error(
`[inference] [UNEXPECTED] truncateContent returned empty content for bookmark "${bookmark.id}". Skipping ...`,
);
}
return content;
}

async function extractTextFromPDFBookmark(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
jobId: string,
) {
const { asset } = await readAsset({
userId: bookmark.userId,
assetId: bookmark.asset.assetId,
});
if (!asset) {
throw new Error(
`[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
);
}
const pdfParse = await readPDFText(asset);
if (!pdfParse?.text) {
throw new Error(
`[inference][${jobId}] PDF text is empty. Please make sure that the PDF includes text and not just images.`,
);
}
const content = truncateContent(pdfParse.text);
if (!content) {
throw new Error(
`[inference][${jobId}] [UNEXPECTED] truncateContent returned empty content for PDF "${bookmark.id}". Skipping ...`,
);
}
return content;
}

async function embedBookmark(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
) {
logger.info(`[embedding][${jobId}] Embedding bookmark ${bookmark.id}`);
const bType = bookmarkType(bookmark);
logger.info(`[embedding][${jobId}] Bookmark type: ${bType}`);
if (bType === "text" || bType === "link") {
const embedding = await inferenceClient.generateEmbeddingFromText(
extractTextFromBookmark(bookmark),
);
logger.info(
`[embedding][${jobId}] Embedding generated successfully: ${embedding.embeddings}`,
);
} else if (bType === "pdf") {
const content = await extractTextFromPDFBookmark(bookmark, jobId);
const embedding = await inferenceClient.generateEmbeddingFromText(content);
logger.info(
`[embedding][${jobId}] Embedding generated successfully: ${embedding.embeddings}`,
);
}
}

async function runOpenAI(job: DequeuedJob<ZOpenAIRequest>) {
const jobId = job.id ?? "unknown";

@@ -398,4 +494,6 @@ async function runOpenAI(job: DequeuedJob<ZOpenAIRequest>) {

// Update the search index
await triggerSearchReindex(bookmarkId);

await embedBookmark(jobId, bookmark, inferenceClient);
}
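
The truncation comments above leave input splitting as future work. A rough sketch of what a token-window splitter could look like, reusing the tiktoken dependency this PR adds; chunkTextByTokens and the 2000-token budget are illustrative, not part of this change:

import { encoding_for_model, TiktokenModel } from "tiktoken";

// Hypothetical helper: split text into windows of at most maxTokens tokens
// so each window fits the embedding model's context limit. Boundary handling
// is naive; a multi-byte character split across windows would be mangled.
function chunkTextByTokens(
  text: string,
  model: string,
  maxTokens = 2000,
): string[] {
  const encoding = encoding_for_model(model as TiktokenModel);
  const tokens = encoding.encode(text);
  const decoder = new TextDecoder();
  const chunks: string[] = [];
  for (let i = 0; i < tokens.length; i += maxTokens) {
    // decode() returns UTF-8 bytes; TextDecoder turns them back into a string.
    chunks.push(decoder.decode(encoding.decode(tokens.slice(i, i + maxTokens))));
  }
  encoding.free();
  return chunks;
}

Each chunk could then be passed to generateEmbeddingFromText individually, yielding one vector per chunk.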
1 change: 1 addition & 0 deletions apps/workers/package.json
@@ -34,6 +34,7 @@
"puppeteer-extra": "^3.3.6",
"puppeteer-extra-plugin-adblocker": "^2.13.6",
"puppeteer-extra-plugin-stealth": "^2.11.2",
"tiktoken": "^1.0.16",
"tsx": "^4.7.1",
"typescript": "^5.3.3",
"zod": "^3.22.4"
4 changes: 4 additions & 0 deletions packages/shared/config.ts
@@ -23,6 +23,7 @@ const allEnv = z.object({
INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
CRAWLER_HEADLESS_BROWSER: stringBool("true"),
BROWSER_WEB_URL: z.string().url().optional(),
BROWSER_WEBSOCKET_URL: z.string().url().optional(),
@@ -73,6 +74,9 @@ const serverConfigSchema = allEnv.transform((val) => {
imageModel: val.INFERENCE_IMAGE_MODEL,
inferredTagLang: val.INFERENCE_LANG,
},
embedding: {
textModel: val.EMBEDDING_TEXT_MODEL,
},
crawler: {
numWorkers: val.CRAWLER_NUM_WORKERS,
headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
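
With this change the embedding model can be overridden through the environment. For example (the first value is the default; the Ollama model name is illustrative only):

EMBEDDING_TEXT_MODEL=text-embedding-3-small
# e.g. when running against Ollama, a local embedding model such as:
EMBEDDING_TEXT_MODEL=mxbai-embed-large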