Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#342] RAG - fix PDF format in vector database #551

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
76 changes: 76 additions & 0 deletions llama_stack/providers/tests/memory/test_vector_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os
from pathlib import Path

import pytest

from llama_stack.apis.memory.memory import MemoryBankDocument, URL
from llama_stack.providers.utils.memory.vector_store import content_from_doc

DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"


def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()


def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()

base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)

data_url = f"data:{mime_type};base64,{base64_content}"

return data_url


class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"

@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
Copy link
Contributor Author

@aidando73 aidando73 Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did consider mocking but that would have introduced another dep that we don't have + make the test less realistic. So decided to do this.

But lmk if you prefer mocking here

doc = MemoryBankDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"

@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
18 changes: 14 additions & 4 deletions llama_stack/providers/utils/memory/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
return loaded_model


def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])


def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
Expand Down Expand Up @@ -88,10 +95,7 @@ def content_from_data(data_url: str) -> str:
return data.decode(encoding)

elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string)
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
return parse_pdf(data)

else:
log.error("Could not extract content from data_url properly.")
Expand All @@ -105,6 +109,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text

pattern = re.compile("^(https?://|file://|data:)")
Expand All @@ -114,6 +121,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text

return interleaved_text_media_as_str(doc.content)
Expand Down
Loading