Skip to content

Commit

Permalink
Add authentication (#19)
Browse files Browse the repository at this point in the history
* Add authentication

* Minor tweaks
  • Loading branch information
homanp authored Jan 18, 2024
1 parent 485936f commit 6b9abc0
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
API_BASE_URL=https://rag.superagent.sh
COHERE_API_KEY=
HUGGINGFACE_API_KEY=
HUGGINGFACE_API_KEY=
JWT_SECRET=
5 changes: 3 additions & 2 deletions api/delete.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from models.delete import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService
from auth.user import get_current_api_user

router = APIRouter()


@router.post("/delete", response_model=ResponsePayload)
async def delete(payload: RequestPayload):
async def delete(payload: RequestPayload, _api_user=Depends(get_current_api_user)):
vector_service: VectorService = get_vector_service(
index_name=payload.index_name, credentials=payload.vector_database
)
Expand Down
7 changes: 5 additions & 2 deletions api/ingest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Dict
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from models.ingest import RequestPayload
from service.embedding import EmbeddingService
from auth.user import get_current_api_user

router = APIRouter()


@router.post("/ingest")
async def ingest(payload: RequestPayload) -> Dict:
async def ingest(
payload: RequestPayload, _api_user=Depends(get_current_api_user)
) -> Dict:
embedding_service = EmbeddingService(
files=payload.files,
index_name=payload.index_name,
Expand Down
5 changes: 3 additions & 2 deletions api/query.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from models.query import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService
from auth.user import get_current_api_user

router = APIRouter()


@router.post("/query", response_model=ResponsePayload)
async def query(payload: RequestPayload):
async def query(payload: RequestPayload, _api_user=Depends(get_current_api_user)):
vector_service: VectorService = get_vector_service(
index_name=payload.index_name, credentials=payload.vector_database
)
Expand Down
Empty file added auth/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions auth/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging
import jwt

from decouple import config
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from superagent.client import AsyncSuperagent

logger = logging.getLogger(__name__)
security = HTTPBearer()


def generate_jwt(data: dict):
token = jwt.encode({**data}, config("JWT_SECRET"), algorithm="HS256")
return token


def decode_jwt(token: str):
return jwt.decode(token, config("JWT_SECRET"), algorithms=["HS256"])


async def get_current_api_user(
authorization: HTTPAuthorizationCredentials = Security(security),
):
token = authorization.credentials
decoded_token = decode_jwt(token)
superagent = AsyncSuperagent(
base_url="https://api.beta.superagent.sh", token=decoded_token
)
api_user = superagent.api_user.get()
if not api_user:
raise HTTPException(status_code=401, detail="Invalid token or expired token")
return api_user
16 changes: 13 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ joblib==1.3.2
litellm==1.17.5
llama-index==0.9.30
loguru==0.7.2
lxml==5.1.0
MarkupSafe==2.1.3
marshmallow==3.20.2
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
nest-asyncio==1.5.8
Expand All @@ -57,34 +59,41 @@ openai==1.7.2
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
pillow==10.2.0
pinecone-client==3.0.0
platformdirs==4.1.0
portalocker==2.8.2
protobuf==4.25.2
pycparser==2.21
pydantic==2.5.3
pydantic_core==2.14.6
pydantic==2.4.2
pydantic_core==2.10.1
PyJWT==2.8.0
pypdf==3.17.4
python-dateutil==2.8.2
python-decouple==3.8
python-dotenv==1.0.0
python-pptx==0.6.23
pytz==2023.3.post1
PyYAML==6.0.1
qdrant-client==1.7.0
regex==2023.12.25
requests==2.31.0
ruff==0.1.13
setuptools==69.0.3
safetensors==0.4.1
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
SQLAlchemy==2.0.25
starlette==0.35.1
superagent-py==0.1.55
sympy==1.12
tenacity==8.2.3
tiktoken==0.5.2
tokenizers==0.15.0
toml==0.10.2
torch==2.1.2
tqdm==4.66.1
transformers==4.36.2
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2023.4
Expand All @@ -97,5 +106,6 @@ watchfiles==0.21.0
weaviate-client==3.26.0
websockets==12.0
wrapt==1.16.0
XlsxWriter==3.1.9
yarl==1.9.4
zipp==3.17.0

0 comments on commit 6b9abc0

Please sign in to comment.