Skip to content

Commit

Permalink
feat: pushing ingestion tasks to kafka
Browse files Browse the repository at this point in the history
  • Loading branch information
elisalimli committed Mar 11, 2024
1 parent 717f5bb commit 5f243d1
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 39 deletions.
144 changes: 105 additions & 39 deletions api/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,125 @@
import aiohttp
from fastapi import APIRouter

from models.ingest import RequestPayload
from models.ingest import RequestPayload, TaskStatus
from models.api import ApiError
from service.embedding import EmbeddingService
from service.ingest import handle_google_drive, handle_urls
from utils.summarise import SUMMARY_SUFFIX

from service.redis.ingest_task_manager import (
IngestTaskManager,
CreateTaskDto,
UpdateTaskDto,
)
from service.redis.client import redis_client

from fastapi.responses import JSONResponse
from fastapi import status
from service.kafka.config import kafka_bootstrap_servers, ingest_topic
from service.kafka.producer import kafka_producer

router = APIRouter()


class IngestPayload(RequestPayload):
    """RequestPayload extended with the Redis task id, so the Kafka consumer
    can report progress/failure back against the correct task record."""

    # Stringified id returned by IngestTaskManager.create().
    task_id: str


@router.post("/ingest")
async def ingest(payload: RequestPayload) -> Dict:
encoder = payload.document_processor.encoder.get_encoder()
embedding_service = EmbeddingService(
encoder=encoder,
index_name=payload.index_name,
vector_credentials=payload.vector_database,
dimensions=payload.document_processor.encoder.dimensions,
async def add_ingest_queue(payload: RequestPayload):
    """Create a PENDING task in Redis and publish the ingestion job to Kafka.

    Returns a dict with the new task id on success. On failure, returns an
    explicit 500 JSONResponse — previously the except branch only printed the
    error and fell through, so FastAPI answered 200 with a `null` body.
    """
    try:
        task_manager = IngestTaskManager(redis_client)
        # Persist the task first so clients can poll its status immediately.
        task_id = task_manager.create(CreateTaskDto(status=TaskStatus.PENDING))

        # Carry the task id inside the message so the consumer can update it.
        message = IngestPayload(**payload.model_dump(), task_id=str(task_id))
        kafka_producer.send(ingest_topic, message.model_dump_json().encode())
        # flush() blocks until the broker has acknowledged the send.
        kafka_producer.flush()

        return {"success": True, "task_id": task_id}

    except Exception as err:
        print(f"error: {err}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"success": False, "error": {"message": str(err)}},
        )


@router.get("/ingest/tasks/{task_id}")
async def get_task(task_id: str):
    """Return the stored status of an ingestion task, or a 404 if unknown."""
    task_manager = IngestTaskManager(redis_client)

    task = task_manager.get(task_id)

    if task:
        return task

    return JSONResponse(
        status_code=status.HTTP_404_NOT_FOUND,
        # Fixed response-body typo: key was "sucess".
        content={"success": False, "error": {"message": "Task not found"}},
    )
chunks = []
summary_documents = []
if payload.files:
chunks, summary_documents = await handle_urls(
embedding_service=embedding_service,
files=payload.files,
config=payload.document_processor,
)

elif payload.google_drive:
chunks, summary_documents = await handle_google_drive(
embedding_service, payload.google_drive
) # type: ignore TODO: Fix typing

tasks = [
embedding_service.embed_and_upsert(
chunks=chunks, encoder=encoder, index_name=payload.index_name
),
]
async def ingest(payload: IngestPayload, task_manager: IngestTaskManager) -> Dict:
    """Run one ingestion job end-to-end and record its outcome in Redis.

    Embeds the documents described by *payload* (files or Google Drive),
    upserts them into the vector index — plus a summary index when summary
    documents were produced — optionally notifies a webhook, then marks the
    task DONE. Any exception marks the task FAILED with the error message.

    Returns a small status dict (the annotation said Dict but the body
    previously returned None).
    """
    try:
        encoder = payload.document_processor.encoder.get_encoder()
        embedding_service = EmbeddingService(
            encoder=encoder,
            index_name=payload.index_name,
            vector_credentials=payload.vector_database,
            dimensions=payload.document_processor.encoder.dimensions,
        )
        chunks = []
        summary_documents = []
        if payload.files:
            chunks, summary_documents = await handle_urls(
                embedding_service=embedding_service,
                files=payload.files,
                config=payload.document_processor,
            )

        elif payload.google_drive:
            chunks, summary_documents = await handle_google_drive(
                embedding_service, payload.google_drive
            )  # type: ignore TODO: Fix typing

        tasks = [
            embedding_service.embed_and_upsert(
                chunks=chunks, encoder=encoder, index_name=payload.index_name
            ),
        ]

        # Only upsert summaries when every summary document was produced.
        if summary_documents and all(item is not None for item in summary_documents):
            tasks.append(
                embedding_service.embed_and_upsert(
                    chunks=summary_documents,
                    encoder=encoder,
                    index_name=f"{payload.index_name}{SUMMARY_SUFFIX}",
                )
            )

        await asyncio.gather(*tasks)

        if payload.webhook_url:
            async with aiohttp.ClientSession() as session:
                await session.post(
                    url=payload.webhook_url,
                    json={"index_name": payload.index_name, "status": "completed"},
                )

        task_manager.update(
            task_id=payload.task_id,
            task=UpdateTaskDto(
                status=TaskStatus.DONE,
            ),
        )
        return {"success": True, "index_name": payload.index_name}
    except Exception as e:
        print("Marking task as failed...", e)
        task_manager.update(
            task_id=payload.task_id,
            task=UpdateTaskDto(
                status=TaskStatus.FAILED,
                error=ApiError(message=str(e)),
            ),
        )
        return {"success": False, "index_name": payload.index_name}
6 changes: 6 additions & 0 deletions models/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel
from typing import Optional


class ApiError(BaseModel):
    """Serializable error detail attached to a failed ingestion task."""

    # Default to None so the field is genuinely optional: under pydantic v2
    # (this codebase uses model_dump_json) Optional[str] without a default is
    # a *required* nullable field. Matches the Optional[...] = None
    # convention used by IngestTaskResponse.error.
    message: Optional[str] = None
12 changes: 12 additions & 0 deletions models/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from models.file import File
from models.google_drive import GoogleDrive
from models.vector_database import VectorDatabase
from models.api import ApiError


class EncoderProvider(str, Enum):
Expand Down Expand Up @@ -90,3 +91,14 @@ class RequestPayload(BaseModel):
files: Optional[List[File]] = None
google_drive: Optional[GoogleDrive] = None
webhook_url: Optional[str] = None


class TaskStatus(str, Enum):
    """Lifecycle states of an ingestion task as stored in Redis."""

    DONE = "DONE"
    PENDING = "PENDING"
    FAILED = "FAILED"


class IngestTaskResponse(BaseModel):
    """Status record for an ingestion task (the JSON persisted in Redis)."""

    status: TaskStatus
    # Populated only when status is FAILED.
    error: Optional[ApiError] = None
9 changes: 9 additions & 0 deletions service/kafka/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from decouple import config


# Kafka topic that ingestion jobs are published to and consumed from.
ingest_topic = config("KAFKA_TOPIC_INGEST", default="ingestion")


# Broker address(es) for producer/consumer connections,
# e.g. "host1:9092,host2:9092".
kafka_bootstrap_servers: str = config(
    "KAFKA_BOOTSTRAP_SERVERS", default="localhost:9092"
)
41 changes: 41 additions & 0 deletions service/kafka/consume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import asyncio
from api.ingest import ingest as _ingest, IngestPayload


from service.redis.client import redis_client
from service.redis.ingest_task_manager import IngestTaskManager
from service.kafka.config import ingest_topic
from service.kafka.consumer import get_kafka_consumer

from kafka.consumer.fetcher import ConsumerRecord


async def ingest(msg: ConsumerRecord):
    """Handle one Kafka ingestion message by delegating to api.ingest.ingest.

    NOTE(review): `IngestPayload(**msg.value)` assumes msg.value is already a
    dict — the producer sends JSON-encoded bytes, so this relies on
    get_kafka_consumer configuring a JSON value deserializer; confirm.
    """
    payload = IngestPayload(**msg.value)
    task_manager = IngestTaskManager(redis_client)
    await _ingest(payload, task_manager)


# Dispatch table: Kafka topic name -> coroutine that handles its messages.
kafka_actions = {
    ingest_topic: ingest,
}


async def process_msg(msg: ConsumerRecord, topic: str, consumer):
    """Run the handler registered for *topic* on *msg*, then commit the
    consumer offset so the message is not redelivered."""
    handler = kafka_actions[topic]
    await handler(msg)
    consumer.commit()


async def consume():
    """Poll the ingest topic forever, dispatching each message in order.

    The consumer is closed in a finally block so the group coordinator is
    notified and partitions rebalance promptly on shutdown or error —
    previously the consumer was never closed.
    """
    consumer = get_kafka_consumer(ingest_topic)

    try:
        while True:
            # Response format is {TopicPartition('topic1', 1): [msg1, msg2]}
            msg_pack = consumer.poll(timeout_ms=3000)

            for tp, messages in msg_pack.items():
                for message in messages:
                    await process_msg(message, tp.topic, consumer)
    finally:
        consumer.close()


asyncio.run(consume())
30 changes: 30 additions & 0 deletions service/redis/ingest_task_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from redis import Redis
import json
from models.ingest import IngestTaskResponse


class CreateTaskDto(IngestTaskResponse):
    """Payload for IngestTaskManager.create (same shape as the stored record)."""

    pass


class UpdateTaskDto(IngestTaskResponse):
    """Payload for IngestTaskManager.update (same shape as the stored record)."""

    pass


class IngestTaskManager:
    """Thin CRUD wrapper over Redis for ingestion task records.

    Task ids are allocated from the Redis "task_id" counter; each task is
    stored under its numeric id as JSON-serialized IngestTaskResponse.
    """

    def __init__(self, redis_client: Redis):
        self.redis_client = redis_client

    def create(self, task: CreateTaskDto):
        """Store a new task record and return its freshly allocated id."""
        task_id = self.redis_client.incr("task_id")
        self.redis_client.set(task_id, task.model_dump_json())
        return task_id

    def get(self, task_id):
        """Return the task as IngestTaskResponse, or None if it doesn't exist.

        Previously a missing id reached json.loads(None) and raised
        TypeError, breaking the API's `if task:` 404 path in get_task.
        """
        raw = self.redis_client.get(task_id)
        if raw is None:
            return None
        return IngestTaskResponse(**json.loads(raw))

    def update(self, task_id, task: UpdateTaskDto):
        """Overwrite the stored record for *task_id*."""
        self.redis_client.set(task_id, task.model_dump_json())

    def delete(self, task_id):
        """Remove the record for *task_id* (no-op if absent)."""
        self.redis_client.delete(task_id)

0 comments on commit 5f243d1

Please sign in to comment.