Skip to content

Commit

Permalink
chore: restore processing job
Browse files Browse the repository at this point in the history
fix: don't connect to weaviate in module top level
  • Loading branch information
rmoesbergen committed Apr 23, 2024
1 parent cf3b709 commit fac05dd
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 53 deletions.
73 changes: 40 additions & 33 deletions api/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,22 @@ def get_jwks_client():
return jwt.PyJWKClient(uri=get_openid_configuration()["jwks_uri"])


# NOTE(review): this commit view renders removed and added lines without +/-
# markers; presumably the 24-hour decorator is the removed line and the
# `seconds=10, wait_first=True` decorator replaces it. Only one belongs in the file.
@repeat_every(seconds=24 * 60 * 60)
@repeat_every(seconds=10, wait_first=True)
async def update_smoelen():
"""Job scheduled via repeat_every: run obtain_images, then handle_unprocessed.

Both steps share a single database session taken from get_db().
"""
# NOTE(review): get_db() looks like a generator dependency; calling next() on
# it without ever closing the generator means any cleanup after its `yield`
# (e.g. closing the session) does not run deterministically — confirm, and
# consider closing the generator in a `finally` block.
db = next(get_db())
obtain_images(db)
handle_unprocessed(db)


@asynccontextmanager
async def startup(_app: FastAPI):
    """Lifespan handler for the FastAPI app (passed as ``FastAPI(lifespan=startup)``).

    Before the application starts serving requests it awaits
    ``update_smoelen()`` (decorated with ``repeat_every``, which presumably
    schedules the recurring processing job — confirm against ``repeat_every``).
    Nothing is done on shutdown.

    Args:
        _app: The FastAPI application instance (unused).
    """
    # No return annotation on purpose: this is an async generator (it yields)
    # wrapped by asynccontextmanager, so `-> None` was incorrect; the accurate
    # type would be `collections.abc.AsyncIterator[None]`.
    await update_smoelen()
    yield


# Create any missing tables for the ORM models at import time.
models.Base.metadata.create_all(bind=engine)
# NOTE(review): this commit view renders removed and added lines without +/-
# markers; presumably `app = FastAPI()` is the removed line and
# `app = FastAPI(lifespan=startup)`, which registers the `startup` lifespan
# handler, replaces it. Only one assignment belongs in the real file.
app = FastAPI()
app = FastAPI(lifespan=startup)

app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -68,15 +75,15 @@ def get_user(token: Annotated[HTTPAuthorizationCredentials, Depends(security)]):
)

is_admin = (
"begeleider" in decoded_jwt["account_type"]
or decoded_jwt["sub"] in settings.allowed_users
"begeleider" in decoded_jwt["account_type"]
or decoded_jwt["sub"] in settings.allowed_users
)
return User(id=decoded_jwt["sub"], admin=is_admin)


def verify_signature(path: str, signature: str):
if not crud.get_verifier().verify(
SHA256.new(path.encode("utf-8")), bytes.fromhex(signature)
SHA256.new(path.encode("utf-8")), bytes.fromhex(signature)
):
return False

Expand All @@ -85,9 +92,9 @@ def verify_signature(path: str, signature: str):

@app.post("/albums", response_model=schemas.AlbumList, operation_id="create_album")
async def create_album(
album: schemas.AlbumCreate,
db: Session = Depends(get_db),
user: User = Depends(get_user),
album: schemas.AlbumCreate,
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
if not user.admin:
raise HTTPException(
Expand All @@ -106,9 +113,9 @@ async def get_albums(db: Session = Depends(get_db), _user=Depends(get_user)):
"/albums", response_model=list[schemas.AlbumList], operation_id="order_albums"
)
async def order_albums(
albums: list[schemas.AlbumOrder],
db: Session = Depends(get_db),
user: User = Depends(get_user),
albums: list[schemas.AlbumOrder],
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
if not user.admin:
raise HTTPException(
Expand All @@ -120,7 +127,7 @@ async def order_albums(

@app.get("/albums/{album_id}", response_model=schemas.Album, operation_id="get_album")
async def get_album(
# NOTE(review): the two identical parameter lines below are presumably the
# removed and added sides of a whitespace-only change (indentation is lost in
# this commit view); the real file contains only one of them.
album_id: UUID, db: Session = Depends(get_db), _user=Depends(get_user)
album_id: UUID, db: Session = Depends(get_db), _user=Depends(get_user)
):
"""Return the album with the given id, as loaded by crud.get_album.

_user is resolved through get_user (so the request must carry a token that
get_user accepts) but is otherwise unused here.
"""
return crud.get_album(db, album_id)

Expand All @@ -129,10 +136,10 @@ async def get_album(
"/albums/{album_id}", response_model=schemas.Album, operation_id="update_album"
)
async def update_album(
album_id: UUID,
album: schemas.AlbumCreate,
db: Session = Depends(get_db),
user: User = Depends(get_user),
album_id: UUID,
album: schemas.AlbumCreate,
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
if not user.admin:
raise HTTPException(
Expand All @@ -146,11 +153,11 @@ async def update_album(
"/items/{album_id}", response_model=list[schemas.Item], operation_id="upload_items"
)
async def upload_items(
album_id: UUID,
items: list[UploadFile],
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
user: User = Depends(get_user),
album_id: UUID,
items: list[UploadFile],
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
items = await crud.create_items(db, user, items, album_id)
background_tasks.add_task(process_smoelen, db, items)
Expand All @@ -159,10 +166,10 @@ async def upload_items(

@app.get("/items/{item_id}/{expiry}/full", include_in_schema=False)
async def get_item(
item_id: UUID, signature: str, expiry: float, db: Session = Depends(get_db)
item_id: UUID, signature: str, expiry: float, db: Session = Depends(get_db)
):
if not verify_signature(
f"{settings.base_url}/items/{item_id}/{expiry}/full", signature
f"{settings.base_url}/items/{item_id}/{expiry}/full", signature
):
return None
if datetime.now().timestamp() > expiry:
Expand All @@ -173,10 +180,10 @@ async def get_item(

@app.get("/items/{item_id}/{expiry}/cover", include_in_schema=False)
async def get_cover(
item_id: UUID, signature: str, expiry: float, db: Session = Depends(get_db)
item_id: UUID, signature: str, expiry: float, db: Session = Depends(get_db)
):
if not verify_signature(
f"{settings.base_url}/items/{item_id}/{expiry}/cover", signature
f"{settings.base_url}/items/{item_id}/{expiry}/cover", signature
):
return None
if datetime.now().timestamp() > expiry:
Expand All @@ -191,10 +198,10 @@ async def get_cover(
operation_id="delete_items",
)
async def delete_items(
album_id: UUID,
items: list[UUID],
db: Session = Depends(get_db),
user: User = Depends(get_user),
album_id: UUID,
items: list[UUID],
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
return crud.delete_items(db, user, album_id, items)

Expand All @@ -205,10 +212,10 @@ async def delete_items(
operation_id="set_preview",
)
async def set_preview(
album_id: UUID,
item_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_user),
album_id: UUID,
item_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_user),
):
if not user.admin:
raise HTTPException(
Expand Down
41 changes: 21 additions & 20 deletions api/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,13 @@
import face_recognition
import requests
from PIL import Image
from numpy import asarray, ndarray
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from weaviate import connect_to_local

from app.conf import settings
from app.db import models
from app.db.crud import get_smoel, set_smoel, delete_items, create_item

# NOTE(review): presumably the removed side of this diff — per the commit
# message ("don't connect to weaviate in module top level") this import-time
# connection is replaced by a per-call get_weaviate_client().
# NOTE(review): str.lstrip("https://") strips any leading h/t/p/s/:// characters,
# not the "https://" prefix, so a hostname starting with one of those letters
# gets truncated; the port is also passed on as a string.
host, port = settings.weaviate_url.lstrip("https://").split(":")
client = connect_to_local(host, port)
known_faces = client.collections.get('known_faces')

from numpy import asarray, ndarray
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from weaviate import connect_to_local, WeaviateClient

NoArgsNoReturnFuncT = Callable[[], None]
NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]]
Expand All @@ -34,6 +28,11 @@
]


def get_weaviate_client() -> WeaviateClient:
    """Open a new Weaviate client for the instance named by ``settings.weaviate_url``.

    The URL may be given with a scheme (``http://host:8080``,
    ``https://host:8080``) or without one (``host:8080``). When no port is
    given, ``connect_to_local``'s own default port is used.

    The caller owns the returned client and is responsible for closing it,
    e.g. by using it in a ``with`` statement.

    Raises:
        ValueError: If no hostname can be parsed from the URL, or the port is
            not a valid integer.
    """
    from urllib.parse import urlsplit

    url = settings.weaviate_url
    # str.lstrip("https://") would strip a *set of characters*, not a prefix,
    # eating the leading h/t/p/s of the hostname ("http://host" -> "ost").
    # urlsplit parses the scheme properly; prefixing "//" makes a scheme-less
    # "host:port" parse as a network location instead of a path.
    parts = urlsplit(url if "://" in url else f"//{url}")
    if not parts.hostname:
        raise ValueError(f"Invalid weaviate_url: {url!r}")
    # .port is an int (raises ValueError if non-numeric), matching the type
    # connect_to_local expects, unlike the raw string the old split produced.
    if parts.port is None:
        return connect_to_local(parts.hostname)
    return connect_to_local(parts.hostname, parts.port)


def repeat_every(
*,
seconds: float,
Expand Down Expand Up @@ -222,16 +221,18 @@ def obtain_images(db: Session):
def find_people(item: models.Item):
"""Return the properties of the closest known face for each face in an item's image.

Face encodings come from create_encodings(item.path). Each encoding is
queried against the Weaviate 'known_faces' collection with near_vector
(distance=0.18, limit=1). Encodings with no match are skipped. The result is
the list of matched objects' properties, in encoding order.
"""
encodings = create_encodings(item.path)
smoelen = []
# NOTE(review): this commit view renders removed and added lines without +/-
# markers. The loop below that uses the module-level `known_faces` is
# presumably the removed (pre-commit) version; the `with` block after it is its
# replacement, which opens a Weaviate client per call instead of at import time.
for encoding in encodings:
smoel = known_faces.query.near_vector(
near_vector=ndarray.tolist(encoding),
distance=0.18,
limit=1,
)
if len(smoel.objects) == 0:
continue

smoelen.append(smoel.objects[0].properties)
# A client is opened (and closed by the `with`) on every call, even when
# `encodings` is empty.
with get_weaviate_client() as client:
known_faces = client.collections.get('known_faces')
for encoding in encodings:
smoel = known_faces.query.near_vector(
near_vector=ndarray.tolist(encoding),
# distance threshold for a match
distance=0.18,
# keep only the single nearest known face
limit=1,
)
# no known face is close enough to this encoding
if len(smoel.objects) == 0:
continue

smoelen.append(smoel.objects[0].properties)

return smoelen

Expand Down

0 comments on commit fac05dd

Please sign in to comment.