Add APIs for querying ingestion and chat response status
johnshaughnessy committed Apr 15, 2024
1 parent 00ff342 commit 5f4c03e
Showing 5 changed files with 69 additions and 5 deletions.
4 changes: 4 additions & 0 deletions memory_cache_hub/api/v1/depends.py
@@ -16,11 +16,15 @@
chroma = None
llamafile_manager = None
db = None
projects_ingesting_files = []
projects_waiting_for_chat = []

get_api_config = lambda: api_config
get_chroma = lambda: chroma
get_llamafile_manager = lambda: llamafile_manager
get_db = lambda: db
get_projects_ingesting_files = lambda: projects_ingesting_files
get_projects_waiting_for_chat = lambda: projects_waiting_for_chat

def set_api_config(config: ApiConfig):
global api_config
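Note: the new dependencies follow the pattern already used in depends.py — module-level lists exposed through lambda getters and injected with FastAPI's Depends. A minimal standalone sketch of that pattern (the app and route names below are invented for illustration, and it assumes a single-process server, so the lists are not shared across workers):

from fastapi import Depends, FastAPI

projects_ingesting_files = []  # shared in-process state, as in depends.py
get_projects_ingesting_files = lambda: projects_ingesting_files

app = FastAPI()

@app.post("/demo/mark_ingesting")  # hypothetical route, illustration only
def mark_ingesting(project_id: int, ingesting=Depends(get_projects_ingesting_files)):
    # Depends() hands every request the same list object, so mutations persist.
    ingesting.append(project_id)
    return {"status": "ok"}

@app.get("/demo/is_ingesting")  # hypothetical route, illustration only
def is_ingesting(project_id: int, ingesting=Depends(get_projects_ingesting_files)):
    return {"isIngesting": project_id in ingesting}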
20 changes: 19 additions & 1 deletion memory_cache_hub/api/v1/ingest.py
@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends
from memory_cache_hub.api.v1.depends import get_root_directory, get_chroma_client, get_embedding_function, get_db
from memory_cache_hub.api.v1.depends import get_root_directory, get_chroma_client, get_embedding_function, get_db, get_projects_ingesting_files
from memory_cache_hub.api.v1.types import IngestProjectFilesRequest, IngestProjectFilesResponse
from memory_cache_hub.core.files import get_project_uploads_directory, list_project_file_uploads
from memory_cache_hub.core.chromadb import chroma_collection_for_project
@@ -10,14 +10,27 @@

router = APIRouter()

@router.post("/check_ingestion_status", status_code=200, tags=["ingest"])
def check_ingestion_status(
project_id: int,
projects_ingesting_files = Depends(get_projects_ingesting_files)
):
if project_id in projects_ingesting_files:
return {"status": "ok", "isIngesting": True, "message": "Project is ingesting files"}
return {"status": "ok", "isIngesting": False, "message": "Project is not ingesting files"}

@router.post("/ingest_project_files", status_code=200, tags=["ingest"])
def ingest_project_files(
project_id: int,
root_directory: str = Depends(get_root_directory),
chroma_client = Depends(get_chroma_client),
chroma_embedding_function = Depends(get_embedding_function),
projects_ingesting_files = Depends(get_projects_ingesting_files),
db = Depends(get_db)
):
if project_id in projects_ingesting_files:
return {"status": "error", "message": "Project is already ingesting files"}
projects_ingesting_files.append(project_id)
project = db_get_project(db, project_id)
project_files = list_project_file_uploads(root_directory, project.name)
chroma_collection = chroma_collection_for_project(chroma_client, chroma_embedding_function, project.name)
@@ -39,6 +52,8 @@ def ingest_project_files(

fragments = fragments_from_files(file_paths, 1000, 200, chroma_embedding_function)
if len(fragments) == 0:
# Remove the project_id from the list of projects ingesting files
projects_ingesting_files.remove(project_id)
return {"status": "ok", "message": "No fragments found in the project files"}

# If we had multiple fragments with the same ID, remove the duplicates
@@ -51,9 +66,12 @@ def ingest_project_files(
metadatas=[asdict(fragment.fragment_metadata) for fragment in fragments],
documents=[fragment.fragment_text for fragment in fragments],
)
projects_ingesting_files.remove(project_id)
return IngestProjectFilesResponse(
num_files=len(file_paths),
num_fragments=len(fragments),
)
except Exception as e:
projects_ingesting_files.remove(project_id)
return {"status": "error", "message": str(e)}
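Note: ingest_project_files now removes the project id on each early return and in the exception handler. A try/finally block would guarantee that cleanup on every exit path with a single statement; a hypothetical sketch of that shape (the helper and the do_ingest callback are invented, not the project's code):

def ingest_with_guard(project_id, projects_ingesting_files, do_ingest):
    # Hypothetical helper: refuse concurrent ingestion, then always clean up.
    if project_id in projects_ingesting_files:
        return {"status": "error", "message": "Project is already ingesting files"}
    projects_ingesting_files.append(project_id)
    try:
        # do_ingest stands in for the fragment extraction and Chroma upsert work.
        return do_ingest(project_id)
    finally:
        projects_ingesting_files.remove(project_id)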
9 changes: 7 additions & 2 deletions memory_cache_hub/api/v1/llamafile_manager.py
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends
from memory_cache_hub.api.v1.depends import get_llamafile_manager
from memory_cache_hub.api.v1.types import DownloadLlamafileByNameRequest, DownloadLlamafileByNameResponse, LlamafileDownloadStatusResponse, StartLlamafileResponse, StopLlamafileResponse
from memory_cache_hub.llamafile.llamafile_manager import get_llamafile_info_by_filename, download_llamafile, start_llamafile, stop_llamafile, has_llamafile, delete_llamafile
from memory_cache_hub.llamafile.llamafile_manager import get_llamafile_info_by_filename, download_llamafile, start_llamafile, stop_llamafile, has_llamafile, delete_llamafile, running_llamafile_info
import os
import shutil

@@ -44,7 +44,7 @@ async def api_start_llamafile(
llamafile_info = get_llamafile_info_by_filename(llamafile_manager, llamafile_filename)
if llamafile_info is None:
return StartLlamafileResponse(status="error", message="Llamafile not found")
if start_llamafile(llamafile_manager, llamafile_info):
if await start_llamafile(llamafile_manager, llamafile_info):
return StartLlamafileResponse(status="success", message="Llamafile started")
else:
return StartLlamafileResponse(status="error", message="Llamafile not found")
@@ -76,3 +76,8 @@ async def api_delete_llamafile(
return {"status": "success", "message": "Llamafile deleted"}
else:
return {"status": "error", "message": "Llamafile not found"}

# Return the info of the currently running llamafile, if any, or None.
@router.get("/running_llamafile_info", status_code=200, tags=["llamafile"])
async def api_running_llamafile_info(llamafile_manager = Depends(get_llamafile_manager)):
return await running_llamafile_info(llamafile_manager)
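Note: a client can call the new route to discover which llamafile, if any, is currently serving. A hedged usage sketch — the base URL and /api/v1 prefix below are placeholders, not taken from the repository:

import requests

BASE = "http://localhost:8000/api/v1"  # assumed host and prefix, adjust as needed

def get_running_llamafile():
    # GET /running_llamafile_info returns the running llamafile's info, or null.
    resp = requests.get(f"{BASE}/running_llamafile_info")
    resp.raise_for_status()
    return resp.json()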
15 changes: 14 additions & 1 deletion memory_cache_hub/api/v1/rag.py
@@ -1,4 +1,4 @@
from memory_cache_hub.api.v1.depends import get_root_directory, get_chroma_client, get_embedding_function, get_completions_url, get_completions_model, get_db
from memory_cache_hub.api.v1.depends import get_root_directory, get_chroma_client, get_embedding_function, get_completions_url, get_completions_model, get_db, get_projects_waiting_for_chat
from memory_cache_hub.api.v1.types import RagAskRequest, RagAskResponse
from memory_cache_hub.core.types import Message
from memory_cache_hub.core.files import get_project_uploads_directory, list_project_file_uploads
@@ -21,12 +21,14 @@ def rag_ask(
complete_model: str = Depends(get_completions_model),
chroma_client = Depends(get_chroma_client),
chroma_embedding_function = Depends(get_embedding_function),
projects_waiting_for_chat = Depends(get_projects_waiting_for_chat),
db=Depends(get_db)
):
print("GOT RAG ASK REQUEST:")
print(body)
prompt = body.prompt
project = db_get_project(db, body.project_id)
projects_waiting_for_chat.append(body.project_id)
chroma_collection = chroma_collection_for_project(chroma_client, chroma_embedding_function, project.name)
query_results = chroma_collection.query(query_texts=[prompt])

@@ -63,6 +65,7 @@ def rag_ask(
reply = openai_compatible_completions(complete_url, complete_model, messages)

except Exception as e:
projects_waiting_for_chat.remove(body.project_id)
return RagAskResponse(
status="error",
message=str(e)
@@ -71,6 +74,7 @@ def rag_ask(
print(reply)
print("\n-------\n")

projects_waiting_for_chat.remove(body.project_id)
return RagAskResponse(
status="ok",
response=reply
@@ -102,3 +106,12 @@ def vector_db_query(
})

return response

@router.post("/check_waiting_for_chat_status", status_code=200, tags=["rag"])
def check_waiting_for_chat_status(
project_id: int,
projects_waiting_for_chat = Depends(get_projects_waiting_for_chat)
):
if project_id in projects_waiting_for_chat:
return {"status": "ok", "isWaiting": True, "message": "Project is waiting for chat"}
return {"status": "ok", "isWaiting": False, "message": "Project is not waiting for chat"}
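Note: with rag_ask marking the project while a completion is in flight, a client can poll check_waiting_for_chat_status until the reply is ready. A hypothetical client sketch — it assumes a placeholder base URL and that project_id is accepted as a query parameter (consistent with the endpoint signature above):

import time
import requests

BASE = "http://localhost:8000/api/v1"  # assumed host and prefix, adjust as needed

def wait_until_chat_idle(project_id: int, interval: float = 1.0, timeout: float = 120.0) -> bool:
    # Poll until the project is no longer waiting for a chat response, or time out.
    deadline = time.time() + timeout
    while time.time() < deadline:
        resp = requests.post(f"{BASE}/check_waiting_for_chat_status", params={"project_id": project_id})
        resp.raise_for_status()
        if not resp.json().get("isWaiting"):
            return True
        time.sleep(interval)
    return False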
26 changes: 25 additions & 1 deletion memory_cache_hub/llamafile/llamafile_manager.py
@@ -22,7 +22,8 @@ def on_complete(download_handle):
llamafile_manager.download_handles.append(download_handle)
return download_handle

def start_llamafile(llamafile_manager: LlamafileManager, llamafile_info: LlamafileInfo):
async def start_llamafile(llamafile_manager: LlamafileManager, llamafile_info: LlamafileInfo):
await stop_all_running_llamafiles(llamafile_manager)
run_handle = RunHandle(
llamafile_info=llamafile_info,
llamafile_store_path=llamafile_manager.llamafile_store_path,
@@ -43,6 +44,29 @@ async def stop_llamafile(llamafile_manager: LlamafileManager, llamafile_info: Ll
print(f"No running llamafile found for {llamafile_info.filename}.")
return False

async def is_running_llamafile(llamafile_manager: LlamafileManager, llamafile_info: LlamafileInfo):
for run_handle in llamafile_manager.run_handles:
if run_handle.llamafile_info.filename == llamafile_info.filename:
return run_handle.is_running()
return False

async def is_running_any_llamafile(llamafile_manager: LlamafileManager):
for run_handle in llamafile_manager.run_handles:
if run_handle.is_running():
return True
return False

async def running_llamafile_info(llamafile_manager: LlamafileManager):
for run_handle in llamafile_manager.run_handles:
if run_handle.is_running():
return run_handle.llamafile_info
return None

async def stop_all_running_llamafiles(llamafile_manager: LlamafileManager):
for run_handle in llamafile_manager.run_handles:
if run_handle.is_running():
await stop_llamafile(llamafile_manager, run_handle.llamafile_info)

def has_llamafile(llamafile_manager: LlamafileManager, llamafile_info: LlamafileInfo):
# Check if the file is already downloaded
file_path = os.path.join(llamafile_manager.llamafile_store_path, llamafile_info.filename)
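Note: start_llamafile now stops every running llamafile before launching a new one, so at most one llamafile serves at a time. A toy sketch of that invariant (ToyHandle/ToyManager are stand-ins, not the real RunHandle or LlamafileManager):

from dataclasses import dataclass, field

@dataclass
class ToyHandle:
    name: str
    running: bool = True
    def is_running(self) -> bool:
        return self.running
    def stop(self) -> None:
        self.running = False

@dataclass
class ToyManager:
    run_handles: list = field(default_factory=list)

def start(manager: ToyManager, name: str) -> ToyHandle:
    # Stop everything first, so the new handle is the only running one.
    for handle in manager.run_handles:
        if handle.is_running():  # is_running() is a method, so it must be called
            handle.stop()
    handle = ToyHandle(name)
    manager.run_handles.append(handle)
    return handle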
