(null);
+
+ useEffect(() => {
+ const fetchFiles = async () => {
+ try {
+ const baseUrl = import.meta.env.VITE_API_BASE_URL;
+ const response = await axios.get(`${baseUrl}/v1/api/uploadFile`, {
+ headers: {
+ Authorization: `JWT ${localStorage.getItem("access")}`, // Assuming JWT is used for auth
+ },
+ });
+ console.log("Response data:", response.data);
+ if (Array.isArray(response.data)) {
+ setFiles(response.data);
+ } else {
+ // setError("Unexpected response format");
+ }
+ } catch (error) {
+ console.error("Error fetching files", error);
+ // setError("Error fetching files");
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ fetchFiles();
+ }, []);
+
+ if (isLoading) {
+ return Loading...
;
+ }
+ return (
+
+
+
+
+ {files.map((file) => (
+ -
+
+ File Name: {file.file_name}
+
+
+ Date of Upload:{" "}
+ {new Date(file.date_of_upload).toLocaleString()}
+
+
+ Size: {file.size} bytes
+
+
+ Page Count: {file.page_count}
+
+
+ File Type: {file.file_type}
+
+
+ Uploaded By: {file.uploaded_by_email}
+
+
+ ))}
+
+
+
+
+ );
+}
+
+export default ListOfFiles;
diff --git a/frontend/src/pages/Layout/Layout.tsx b/frontend/src/pages/Layout/Layout.tsx
index a7713cc8..1ebf72a0 100644
--- a/frontend/src/pages/Layout/Layout.tsx
+++ b/frontend/src/pages/Layout/Layout.tsx
@@ -24,22 +24,17 @@ export const Layout = ({
const location = useLocation();
useEffect(() => {
- // console.log(isAuthenticated);
if (!isAuthenticated) {
- setShowLoginMenu(true);
- }
- if (location.pathname === "/login" && !isAuthenticated) {
- setShowLoginMenu(false);
- } else if (location.pathname === "/resetpassword" && !isAuthenticated) {
- setShowLoginMenu(false);
- } else if (
- (location.pathname.includes("password") ||
- location.pathname.includes("reset")) &&
- !isAuthenticated
- ) {
- setShowLoginMenu(false);
- } else if (!isAuthenticated) {
- setShowLoginMenu(true);
+ if (
+ location.pathname === "/login" ||
+ location.pathname === "/resetpassword" ||
+ location.pathname.includes("password") ||
+ location.pathname.includes("reset")
+ ) {
+ setShowLoginMenu(false);
+ } else {
+ setShowLoginMenu(true);
+ }
}
}, [isAuthenticated, location.pathname]);
@@ -54,7 +49,7 @@ export const Layout = ({
- {!isAuthenticated && (
+ {!isAuthenticated && showLoginMenu && (
,
errorElement: ,
},
+ {
+ path: "listoffiles",
+ element: ,
+ errorElement: ,
+ },
+ {
+ path: "uploadfile",
+ element: ,
+ },
{
path: "drugSummary",
element: ,
@@ -72,10 +82,6 @@ const routes = [
path: "medications",
element: ,
},
- {
- path: "uploadfile",
- element: ,
- },
];
export default routes;
diff --git a/server/api/admin.py b/server/api/admin.py
index be1266ae..37614d2c 100644
--- a/server/api/admin.py
+++ b/server/api/admin.py
@@ -6,9 +6,15 @@
from .views.ai_promptStorage.models import AI_PromptStorage
from .views.ai_settings.models import AI_Settings
from .views.ai_promptStorage.models import AI_PromptStorage
+from .models.model_embeddings import Embeddings
from .views.jira.models import Feedback
+@admin.register(Embeddings)
+class EmbeddingsAdmin(admin.ModelAdmin):
+    # Named after the model it manages: a second class called MedicationAdmin
+    # in this module would be shadowed by the Medication admin defined later.
+    list_display = ['guid']
+
+
@admin.register(Medication)
class MedicationAdmin(admin.ModelAdmin):
list_display = ['name', 'benefits', 'risks']
diff --git a/server/api/migrations/0004_uploadfile_uploaded_by_email_and_more.py b/server/api/migrations/0004_uploadfile_uploaded_by_email_and_more.py
new file mode 100644
index 00000000..8a8a74f4
--- /dev/null
+++ b/server/api/migrations/0004_uploadfile_uploaded_by_email_and_more.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.3 on 2024-07-30 10:19
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+# Adds uploadfile.uploaded_by_email and changes uploadfile.uploaded_by into a
+# foreign key to the configured user model.
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api', '0003_alter_uploadfile_date_of_upload_and_more'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='uploadfile',
+ name='uploaded_by_email',
+ field=models.CharField(blank=True, max_length=255),
+ ),
+ # NOTE(review): uploaded_by was a CharField before this change; existing
+ # rows presumably hold text that cannot be converted to a user key.
+ # TODO confirm the table is empty or add a data migration first.
+ migrations.AlterField(
+ model_name='uploadfile',
+ name='uploaded_by',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
+ ),
+ ]
diff --git a/server/api/migrations/0005_embeddings.py b/server/api/migrations/0005_embeddings.py
new file mode 100644
index 00000000..4141495e
--- /dev/null
+++ b/server/api/migrations/0005_embeddings.py
@@ -0,0 +1,30 @@
+# Generated by Django 4.2.3 on 2024-07-30 10:21
+
+from django.db import migrations, models
+import django.db.models.deletion
+import pgvector.django.vector
+import uuid
+
+
+# Creates the Embeddings table: one row per text chunk of an uploaded file,
+# with a 384-dimension pgvector column for the chunk's embedding.
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api', '0004_uploadfile_uploaded_by_email_and_more'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='Embeddings',
+ fields=[
+ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('guid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
+ ('name', models.CharField(max_length=255)),
+ ('text', models.TextField()),
+ ('page_num', models.IntegerField(default=1)),
+ ('chunk_number', models.IntegerField()),
+ ('embedding_sentence_transformers', pgvector.django.vector.VectorField(dimensions=384, null=True)),
+ ('date_of_upload', models.DateTimeField(auto_now_add=True)),
+ # Deleting the parent upload file removes its embedding rows (CASCADE).
+ ('upload_file', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='embeddings', to='api.uploadfile')),
+ ],
+ ),
+ ]
diff --git a/server/api/models/model_embeddings.py b/server/api/models/model_embeddings.py
new file mode 100644
index 00000000..97337088
--- /dev/null
+++ b/server/api/models/model_embeddings.py
@@ -0,0 +1,22 @@
+from django.db import models
+from django.conf import settings
+from pgvector.django import VectorField
+import uuid
+from ..views.uploadFile.models import UploadFile
+
+
+# One text chunk of an uploaded file together with its embedding vector.
+class Embeddings(models.Model):
+ # Parent upload; rows are removed with it (CASCADE) and are reachable from
+ # the file through the "embeddings" reverse relation.
+ upload_file = models.ForeignKey(
+ UploadFile, related_name='embeddings', on_delete=models.CASCADE)
+ # This is a new unique GUID for each Embedding
+ guid = models.UUIDField(unique=True, default=uuid.uuid4, editable=False)
+ name = models.CharField(max_length=255)
+ # Chunk text, the page it came from, and the chunk's index.
+ text = models.TextField()
+ page_num = models.IntegerField(default=1)
+ chunk_number = models.IntegerField()
+ # 384-dimension vector; nullable, so a row may exist without an embedding.
+ embedding_sentence_transformers = VectorField(
+ dimensions=384, null=True)
+ # Set once when the row is created (auto_now_add).
+ date_of_upload = models.DateTimeField(auto_now_add=True, blank=True)
+
+ # Display the row by its name field.
+ def __str__(self):
+ return self.name
diff --git a/server/api/services/conversions_services.py b/server/api/services/conversions_services.py
new file mode 100644
index 00000000..d134ff49
--- /dev/null
+++ b/server/api/services/conversions_services.py
@@ -0,0 +1,12 @@
+import uuid
+
+
+def convert_uuids(data):
+    """Return ``data`` with every ``uuid.UUID`` value replaced by its string form.
+
+    Dicts and lists are rebuilt recursively (dict keys are left untouched);
+    any other value, including tuples, is returned unchanged.
+    """
+    if isinstance(data, uuid.UUID):
+        return str(data)
+    if isinstance(data, list):
+        return [convert_uuids(element) for element in data]
+    if isinstance(data, dict):
+        converted = {}
+        for field, value in data.items():
+            converted[field] = convert_uuids(value)
+        return converted
+    return data
diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py
new file mode 100644
index 00000000..5aacab38
--- /dev/null
+++ b/server/api/services/embedding_services.py
@@ -0,0 +1,44 @@
+# services/embedding_services.py
+from .sentencetTransformer_model import TransformerModel
+# Adjust import path as needed
+from ..models.model_embeddings import Embeddings
+from pgvector.django import L2Distance
+
+
+def get_closest_embeddings(user, message_data, document_name=None, guid=None, num_results=10):
+    """Return the stored chunks closest to ``message_data`` for one user.
+
+    ``message_data`` is encoded with the shared SentenceTransformer model and
+    compared by L2 distance against ``embedding_sentence_transformers`` of the
+    ``Embeddings`` rows whose upload file was uploaded by ``user``.
+
+    Args:
+        user: owner of the upload files to search.
+        message_data: text to encode and search for.
+        document_name: optional ``Embeddings.name`` filter; ignored when
+            ``guid`` is given.
+        guid: optional upload-file GUID filter; takes precedence over
+            ``document_name``.
+        num_results: maximum number of rows returned, nearest first.
+
+    Returns:
+        A list of dicts with keys ``name``, ``text``, ``page_number``,
+        ``chunk_number``, ``distance`` and ``file_id`` (the upload file GUID).
+    """
+    # Local import keeps this function self-contained.
+    from django.db.models import F
+
+    transformerModel = TransformerModel.get_instance().model
+    embedding_message = transformerModel.encode(message_data)
+
+    # Start building the query based on the message's embedding
+    closest_embeddings_query = Embeddings.objects.filter(
+        upload_file__uploaded_by=user
+    ).annotate(
+        distance=L2Distance(
+            'embedding_sentence_transformers', embedding_message),
+        # Fetch the parent GUID in the same query. Reading
+        # obj.upload_file.guid in the loop below would issue one extra query
+        # per result row and load the whole parent row each time.
+        file_guid=F('upload_file__guid'),
+    ).order_by('distance')
+
+    # Filter by GUID if provided, otherwise filter by document name if provided
+    if guid:
+        closest_embeddings_query = closest_embeddings_query.filter(
+            upload_file__guid=guid)
+    elif document_name:
+        closest_embeddings_query = closest_embeddings_query.filter(
+            name=document_name)
+
+    # Slice the results to limit to num_results
+    closest_embeddings_query = closest_embeddings_query[:num_results]
+
+    # Format the results to be returned
+    return [
+        {
+            "name": obj.name,
+            "text": obj.text,
+            "page_number": obj.page_num,
+            "chunk_number": obj.chunk_number,
+            "distance": obj.distance,
+            "file_id": obj.file_guid,
+        }
+        for obj in closest_embeddings_query
+    ]
diff --git a/server/api/services/sentencetTransformer_model.py b/server/api/services/sentencetTransformer_model.py
new file mode 100644
index 00000000..1ba9b9e5
--- /dev/null
+++ b/server/api/services/sentencetTransformer_model.py
@@ -0,0 +1,22 @@
+from sentence_transformers import SentenceTransformer
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TransformerModel:
+    """Holder for a single shared SentenceTransformer instance.
+
+    ``TransformerModel()`` and ``TransformerModel.get_instance()`` both return
+    the same object; the model is loaded on first use and exposed as
+    ``.model``.
+    """
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            logger.info("Loading SentenceTransformer model")
+            instance = super().__new__(cls)
+            # Load before publishing: if loading raises, _instance stays None
+            # and the next call retries, instead of handing out an object
+            # that has no ``model`` attribute.
+            instance.model = SentenceTransformer(
+                'paraphrase-MiniLM-L6-v2')
+            cls._instance = instance
+        return cls._instance
+
+    @classmethod
+    def get_instance(cls):
+        """Return the shared instance, creating it on the first call."""
+        return cls()
diff --git a/server/api/views/embeddings/embeddingsView.py b/server/api/views/embeddings/embeddingsView.py
new file mode 100644
index 00000000..54b1ef41
--- /dev/null
+++ b/server/api/views/embeddings/embeddingsView.py
@@ -0,0 +1,89 @@
+from rest_framework.views import APIView
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
+from rest_framework import status
+import openai
+# from ..embeddings_manage.models import Embeddings
+import os
+from ...services.embedding_services import get_closest_embeddings
+import logging
+import json
+import uuid
+from django.utils.decorators import method_decorator
+from django.views.decorators.csrf import csrf_exempt
+from ...services.conversions_services import convert_uuids
+
+
+# Answers a user question with an OpenAI chat completion whose system prompt
+# contains the stored chunks closest to the question.
+@method_decorator(csrf_exempt, name='dispatch')
+class AskEmbeddingsAPIView(APIView):
+
+ permission_classes = [IsAuthenticated]
+
+ # POST body: {"message": <question>}. Responds 400 when "message" is missing
+ # or empty, 200 with the answer on success, 500 on any exception.
+ def post(self, request, *args, **kwargs):
+ try:
+ user = request.user
+
+ print("AskEmbeddingsAPIView")
+ request_data = request.data.get('message', None)
+ if not request_data:
+ return Response({"error": "Message data is required."}, status=status.HTTP_400_BAD_REQUEST)
+ # Wrapping in a list and taking element 0 yields request_data unchanged.
+ message = [request_data][0]
+
+ # No document_name/guid/num_results passed, so the helper's defaults apply.
+ embeddings_results = get_closest_embeddings(
+ user=user,
+ message_data=message
+ )
+
+ # Replace UUID objects with strings in the results.
+ embeddings_results = convert_uuids(embeddings_results)
+
+ print("AskEmbeddingsAPIView1")
+ # One bracketed "INFO n" segment per retrieved chunk, numbered from 1.
+ prompt_texts = [
+ f"[Start of INFO {i+1} === GUID: {obj['file_id']}, Page Number: {obj['page_number']}, Chunk Number: {obj['chunk_number']}, Text: {obj['text']} === End of INFO {i+1} ]" for i, obj in enumerate(embeddings_results)]
+
+ listOfEmbeddings = " ".join(prompt_texts)
+
+ # System prompt: fixed instructions and an example, followed by the chunk
+ # segments. The doubled braces in the citation format are literal braces.
+ prompt_text = (
+ f"""You are an AI assistant tasked with providing detailed, well-structured responses based on the information provided in [PROVIDED-INFO]. Follow these guidelines strictly:
+ 1. Content: Use information contained within [PROVIDED-INFO] to answer the question.
+ 2. Organization: Structure your response with clear sections and paragraphs.
+ 3. Citations: After EACH sentence that uses information from [PROVIDED-INFO], include a citation in this exact format:***[{{file_id}}], Page {{page_number}}, Chunk {{chunk_number}}*** . Only use citations that correspond to the information you're presenting.
+ 4. Clarity: Ensure your answer is well-structured and easy to follow.
+ 5. Direct Response: Answer the user's question directly without unnecessary introductions or filler phrases.
+ Here's an example of the required response format:
+ ________________________________________
+ See's Candy in the context of sales during a specific event. The candy counters rang up 2,690 individual sales on a Friday, and an additional 3,931 transactions on a Saturday ***[16s848as-vcc1-85sd-r196-7f820a4s9de1, Page 5, Chunk 26]***.
+ People like the consumption of fudge and peanut brittle the most ***[130714d7-b9c1-4sdf-b146-fdsf854cad4f, Page 9, Chunk 19]***.
+ Here is the history of See's Candy: the company was purchased in 1972, and its products have not been materially altered in 101 years ***[895sdsae-b7v5-416f-c84v-7f9784dc01e1, Page 2, Chunk 13]***.
+ Bipolar disorder treatment often involves mood stabilizers. Lithium is a commonly prescribed mood stabilizer effective in reducing manic episodes ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 29, Chunk 122]***. For acute hypomania or mild to moderate mania, initial treatment with risperidone or olanzapine monotherapy is suggested ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 24, Chunk 101]***.
+ ________________________________________
+ Please provide your response to the user's question following these guidelines precisely.
+ [PROVIDED-INFO] = {listOfEmbeddings}"""
+ )
+
+ # message = f"{message}\n"
+ model_used = "gpt-4o-mini"
+ # model_used = "gpt-3.5-turbo-0125"
+
+ # The API key is read from the environment on every request.
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ response = openai.ChatCompletion.create(
+ model=model_used,
+ temperature=0.2,
+ messages=[
+ {"role": "system",
+ "content": prompt_text},
+ {"role": "user", "content": message}
+ ]
+ )
+
+ answer = response["choices"][0]["message"]["content"]
+
+ # NOTE(review): the full system prompt is echoed back under "sent to LLM".
+ return Response({
+ "question": message,
+ "llm_response": answer,
+ "embeddings_info": embeddings_results,
+ "sent to LLM": prompt_text,
+ }, status=status.HTTP_200_OK)
+
+ except Exception as e:
+ # NOTE(review): the exception text is returned to the client as-is.
+ print(f"An error occurred: {e}")
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
diff --git a/server/api/views/embeddings/urls.py b/server/api/views/embeddings/urls.py
new file mode 100644
index 00000000..32bf4ebe
--- /dev/null
+++ b/server/api/views/embeddings/urls.py
@@ -0,0 +1,10 @@
+from django.urls import path
+from .embeddingsView import AskEmbeddingsAPIView
+
+
+# Single route: the embeddings question endpoint, handled by AskEmbeddingsAPIView.
+urlpatterns = [
+
+ path('v1/api/embeddings/ask_embeddings', AskEmbeddingsAPIView.as_view(),
+ name='ask_embeddings'),
+
+]
diff --git a/server/api/views/uploadFile/models.py b/server/api/views/uploadFile/models.py
index ecd64fd6..9554aee1 100644
--- a/server/api/views/uploadFile/models.py
+++ b/server/api/views/uploadFile/models.py
@@ -1,4 +1,5 @@
from django.db import models
+from django.conf import settings
import uuid
@@ -12,7 +13,10 @@ class UploadFile(models.Model):
size = models.BigIntegerField()
page_count = models.IntegerField()
file_type = models.CharField(max_length=50)
- uploaded_by = models.CharField(max_length=255, blank=True)
+ uploaded_by = models.ForeignKey(
+ settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
+ uploaded_by_email = models.CharField(
+ max_length=255, blank=True)
source_url = models.CharField(max_length=255, blank=True, null=True)
analyzed = models.DateTimeField(blank=True, null=True)
approved = models.DateTimeField(blank=True, null=True)
diff --git a/server/api/views/uploadFile/serializers.py b/server/api/views/uploadFile/serializers.py
new file mode 100644
index 00000000..94ab8bec
--- /dev/null
+++ b/server/api/views/uploadFile/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+from .models import UploadFile
+
+
+# Serializes every UploadFile field except the stored "file" field itself.
+class UploadFileSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UploadFile
+ exclude = ['file']
diff --git a/server/api/views/uploadFile/views.py b/server/api/views/uploadFile/views.py
index 8d9b41f7..4ad27e21 100644
--- a/server/api/views/uploadFile/views.py
+++ b/server/api/views/uploadFile/views.py
@@ -8,20 +8,39 @@
from .models import UploadFile # Import your UploadFile model
from django.core.files.base import ContentFile
import os
+from .serializers import UploadFileSerializer
+from django.http import HttpResponse
+from ...services.sentencetTransformer_model import TransformerModel
+from ...models.model_embeddings import Embeddings
+import fitz
+from django.db import transaction
@method_decorator(csrf_exempt, name='dispatch')
class UploadFileView(APIView):
- # permission_classes = [IsAuthenticated]
+ permission_classes = [IsAuthenticated]
+
+    def get(self, request, format=None):
+        """List the requesting user's uploads, most recent upload date first."""
+        print("UploadFileView, get list")
+
+        owner_id = request.user.id
+        # The 'file' column is deferred here and also excluded by the
+        # serializer, so the stored file contents are not part of the listing.
+        user_files = (
+            UploadFile.objects
+            .filter(uploaded_by=owner_id)
+            .defer('file')
+            .order_by('-date_of_upload')
+        )
+        return Response(UploadFileSerializer(user_files, many=True).data)
def post(self, request, format=None):
print(request.auth)
print(f"UploadFileView post called. Path: {request.path}")
- if not request.user.is_superuser:
- return Response(
- {"message": "Error, user is not a superuser."},
- status=status.HTTP_401_UNAUTHORIZED,
- )
+ # if not request.user.is_superuser:
+ # return Response(
+ # {"message": "Error, user is not a superuser."},
+ # status=status.HTTP_401_UNAUTHORIZED,
+ # )
uploaded_file = request.FILES.get('file')
if uploaded_file is None:
@@ -46,7 +65,7 @@ def post(self, request, format=None):
# Read the entire PDF to store in the BinaryField
uploaded_file.seek(0)
pdf_binary = uploaded_file.read()
-
+ with transaction.atomic():
# Create a new UploadFile instance and populate it
new_file = UploadFile(
file_name=uploaded_file.name,
@@ -54,18 +73,97 @@ def post(self, request, format=None):
size=size,
page_count=page_count,
file_type=file_type,
- # Assuming you want to capture who uploaded the file
- uploaded_by=request.user.email
+ uploaded_by=request.user, # Set to the user instance
+ uploaded_by_email=request.user.email # Also store the email separately
)
new_file.save()
- return Response(
- {"message": "File uploaded successfully.",
- "file_id": new_file.id},
- status=status.HTTP_201_CREATED,
- )
- except Exception as e:
+ if new_file.id is None:
+ return Response({"message": "Failed to save the upload file."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+ with fitz.open(stream=pdf_binary, filetype="pdf") as doc:
+ text = ""
+ page_number = 1 # Initialize page_number
+ page_texts = [] # List to hold text for each page with page number
+
+ for page in doc:
+ page_text = page.get_text()
+ text += page_text
+ page_texts.append((page_number, page_text))
+ page_number += 1
+
+ chunks_with_page = []
+
+ # Create chunks along with their corresponding page number
+ for page_num, page_text in page_texts:
+ words = page_text.split()
+ chunks = [' '.join(words[i:i+100])
+ for i in range(0, len(words), 100)]
+ for chunk in chunks:
+ chunks_with_page.append((page_num, chunk))
+
+ model = TransformerModel.get_instance().model
+ # Encode each chunk and save embeddings
+ embeddings = model.encode(
+ [chunk for _, chunk in chunks_with_page])
+
+ for i, ((page_num, chunk), embedding) in enumerate(zip(chunks_with_page, embeddings)):
+ Embeddings.objects.create(
+ upload_file=new_file,
+ name=new_file.file_name, # You may adjust the naming convention
+ text=chunk,
+ chunk_number=i,
+ page_num=page_num, # Store the page number here
+ embedding_sentence_transformers=embedding.tolist()
+ )
return Response(
- {"message": f"Error processing PDF: {str(e)}"},
- status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ {"message": "File uploaded successfully.",
+ "file_id": new_file.id},
+ status=status.HTTP_201_CREATED,
)
+ except Exception as e:
+ # Handle potential errors
+ return Response({"message": f"Error processing file and embeddings: {str(e)}"},
+ status=status.HTTP_400_BAD_REQUEST)
+
+    def delete(self, request, format=None):
+        """Delete one of the requesting user's uploads together with its embeddings.
+
+        Expects ``guid`` in the request body. Responds 400 when it is missing
+        or malformed, 404 when no such file exists, 403 when the file belongs
+        to another user, 200 on success and 500 on any other failure.
+        """
+        # Local import: only this handler needs it.
+        from django.core.exceptions import ValidationError
+
+        guid = request.data.get('guid')
+        if not guid:
+            return Response({"message": "No file ID provided."}, status=status.HTTP_400_BAD_REQUEST)
+
+        try:
+            with transaction.atomic():
+                # Fetch the file to delete
+                upload_file = UploadFile.objects.get(guid=guid)
+
+                # Check if the user has permission to delete this file
+                if upload_file.uploaded_by != request.user:
+                    return Response({"message": "You do not have permission to delete this file."}, status=status.HTTP_403_FORBIDDEN)
+
+                # Delete related embeddings
+                Embeddings.objects.filter(upload_file=upload_file).delete()
+
+                # Delete the file
+                upload_file.delete()
+
+            return Response({"message": "File and related embeddings deleted successfully."}, status=status.HTTP_200_OK)
+        except UploadFile.DoesNotExist:
+            return Response({"message": "File not found."}, status=status.HTTP_404_NOT_FOUND)
+        except ValidationError:
+            # Assumes UploadFile.guid is a UUIDField (TODO confirm): a value
+            # that cannot be parsed as a UUID is a client error, not a 500.
+            return Response({"message": "Invalid file ID."}, status=status.HTTP_400_BAD_REQUEST)
+        except Exception as e:
+            return Response({"message": f"Error deleting file and embeddings: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@method_decorator(csrf_exempt, name='dispatch')
+class RetrieveUploadFileView(APIView):
+    """Download endpoint for a single upload owned by the requesting user."""
+    permission_classes = [IsAuthenticated]
+
+    def get(self, request, guid, format=None):
+        """Return the stored file as a PDF attachment.
+
+        Responds 404 when no upload with ``guid`` exists for the requesting
+        user, which also covers files owned by someone else.
+        """
+        # Available from Django 4.2; the migrations in this change set were
+        # generated with 4.2.3.
+        from django.utils.http import content_disposition_header
+
+        try:
+            upload = UploadFile.objects.get(
+                guid=guid, uploaded_by=request.user.id)
+        except UploadFile.DoesNotExist:
+            return Response({"message": "No file found or access denied."}, status=status.HTTP_404_NOT_FOUND)
+
+        response = HttpResponse(upload.file, content_type='application/pdf')
+        # The file name was supplied by the uploader; let Django quote or
+        # encode it rather than interpolating it into the header by hand.
+        response['Content-Disposition'] = content_disposition_header(
+            True, upload.file_name)
+        return response
diff --git a/server/balancer_backend/urls.py b/server/balancer_backend/urls.py
index 203125bc..298a1fee 100644
--- a/server/balancer_backend/urls.py
+++ b/server/balancer_backend/urls.py
@@ -19,7 +19,7 @@
# List of application names for which URL patterns will be dynamically added
urls = ['chatgpt', 'jira', 'listDrugs', 'listMeds', 'risk',
- 'uploadFile', 'ai_promptStorage', 'ai_settings']
+ 'uploadFile', 'ai_promptStorage', 'ai_settings', 'embeddings']
# Loop through each application name and dynamically import and add its URL patterns
for url in urls:
diff --git a/server/requirements.txt b/server/requirements.txt
index eec27d3c..ad44e8c6 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -12,4 +12,9 @@ django-registration-redux==2.13
django-cors-headers>=3.10.0
djangorestframework-simplejwt
djoser
-pdfplumber
\ No newline at end of file
+pdfplumber
+pgvector
+sentence_transformers
+PyMuPDF
+Pillow
+pytesseract
\ No newline at end of file