From d25028b1124168384df2246446dacc89582e40b5 Mon Sep 17 00:00:00 2001 From: Shreyas Mocherla Date: Mon, 5 Aug 2024 13:03:45 +0530 Subject: [PATCH] Implemented Auth --- multi_tenant_rag.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index ee491de..d4492d5 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -1,4 +1,6 @@ import streamlit as st +import streamlit_authenticator as stauth +from streamlit_authenticator.utilities import RegisterError, LoginError import os from langchain_community.vectorstores.chroma import Chroma from app import setup_chroma_client, setup_chroma_embedding_function @@ -8,6 +10,40 @@ import tempfile from langchain_community.document_loaders import PyPDFLoader +import yaml +from yaml.loader import SafeLoader + +def configure_authenticator(): + with open('.streamlit/config.yaml') as file: + config = yaml.load(file, Loader=SafeLoader) + + authenticator = stauth.Authenticate( + config['credentials'], + config['cookie']['name'], + config['cookie']['key'], + config['cookie']['expiry_days'], + config['pre-authorized'] + ) + return authenticator + +def authenticate(op): + authenticator = configure_authenticator() + + if op == "login": + name, authentication_status, username = authenticator.login() + st.session_state['authentication_status'] = authentication_status + st.session_state['username'] = username + elif op == "register": + try: + (email_of_registered_user, + username_of_registered_user, + name_of_registered_user) = authenticator.register_user(pre_authorization=False) + if email_of_registered_user: + st.success('User registered successfully') + except RegisterError as e: + st.error(e) + return authenticator + class MultiTenantRAG(RAG): def __init__(self, user_id, llm, embeddings, collection_name, db_client): self.user_id = user_id @@ -22,13 +58,13 @@ def load_documents(self, doc): return documents def main(): - llm = setup_huggingface_endpoint() + llm = setup_huggingface_endpoint(model_id="qwen/Qwen2-7B-Instruct") embeddings = setup_huggingface_embeddings() chroma_embeddings = setup_chroma_embedding_function() - user_id = st.text_input("Enter user ID") + user_id = st.session_state['username'] client = setup_chroma_client() @@ -65,4 +101,7 @@ def main(): st.chat_message("assistant").markdown(answer) if __name__ == "__main__": - main() \ No newline at end of file + authenticator = authenticate("login") + if st.session_state['authentication_status']: + authenticator.logout() + main() \ No newline at end of file