From f940054a9afe97a6dcec47776c1cc6e82f3602b2 Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Mon, 18 Nov 2024 21:03:31 +0000 Subject: [PATCH] reworked in_zephir to be a more generic filter --- aim/digifeeds/database/crud.py | 22 ++++++++++++++-------- aim/digifeeds/database/main.py | 13 +++++++------ aim/digifeeds/database/schemas.py | 9 +++++++++ tests/digifeeds/database/test_crud.py | 12 ++++++------ tests/digifeeds/database/test_main.py | 4 ++-- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/aim/digifeeds/database/crud.py b/aim/digifeeds/database/crud.py index f2979fb..24e414b 100644 --- a/aim/digifeeds/database/crud.py +++ b/aim/digifeeds/database/crud.py @@ -24,33 +24,39 @@ def get_item(db: Session, barcode: str): return db.query(models.Item).filter(models.Item.barcode == barcode).first() -def get_items_total(db: Session, in_zephir: bool | None): - query = get_items_query(db=db, in_zephir=in_zephir) +def get_items_total(db: Session, filter: schemas.ItemFilters = None): + query = get_items_query(db=db, filter=filter) return query.count() -def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): +def get_items( + db: Session, + limit: int, + offset: int, + filter: schemas.ItemFilters = None, +): """ Get Digifeed items from the database Args: db (sqlalchemy.orm.Session): Digifeeds database session - in_zephir (bool | None): Whether or not the items are in zephir + filter (schemas.ItemFilters | None): filter to apply to the list of items. Returns: aim.digifeeds.database.models.Item: Item object """ - query = get_items_query(db=db, in_zephir=in_zephir) + query = get_items_query(db=db, filter=filter) return query.offset(offset).limit(limit).all() -def get_items_query(db: Session, in_zephir: bool | None): +def get_items_query(db: Session, filter: schemas.ItemFilters = None): query = db.query(models.Item) - if in_zephir is True: + + if filter == "in_zephir": query = query.filter( models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) - elif in_zephir is False: + elif filter == "not_in_zephir": query = query.filter( ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) diff --git a/aim/digifeeds/database/main.py b/aim/digifeeds/database/main.py index c9f3d02..22acb21 100644 --- a/aim/digifeeds/database/main.py +++ b/aim/digifeeds/database/main.py @@ -43,23 +43,24 @@ def get_db(): # pragma: no cover def get_items( offset: int = Query(0, ge=0, description="Requested offset from the list of pages"), limit: int = Query(50, ge=1, description="Requested number of items per page"), - in_zephir: bool | None = Query( - None, description="Filter for items that do or do not have metadata in Zephir" + filter: schemas.ItemFilters = Query( + None, description="Filters on the items in the database" ), db: Session = Depends(get_db), ) -> schemas.PageOfItems: # list[schemas.Item]: """ Get the digifeeds items. - These items can be filtered by whether or not their metadata is in Zephir or - all of them can be fetched. + These items can be filtered by whether or not their metadata is in Zephir, + whether or not they are pending deletion, if they are not in alma, or all of + them can be fetched. """ - db_items = crud.get_items(in_zephir=in_zephir, db=db, offset=offset, limit=limit) + db_items = crud.get_items(filter=filter, db=db, offset=offset, limit=limit) return { "limit": limit, "offset": offset, - "total": crud.get_items_total(in_zephir=in_zephir, db=db), + "total": crud.get_items_total(filter=filter, db=db), "items": db_items, } diff --git a/aim/digifeeds/database/schemas.py b/aim/digifeeds/database/schemas.py index 657b83f..9678879 100644 --- a/aim/digifeeds/database/schemas.py +++ b/aim/digifeeds/database/schemas.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, Field, ConfigDict from datetime import datetime +from enum import Enum class ItemStatus(BaseModel): @@ -47,6 +48,14 @@ class PageOfItems(BaseModel): total: int = 1 +class ItemFilters(str, Enum): + in_zephir = "in_zephir" + not_in_zephir = "not_in_zephir" + pending_deletion = "pending_deletion" + not_pending_deletion = "not_pending_deletion" + not_in_alma = "not_in_alma" + + class ItemCreate(ItemBase): pass diff --git a/tests/digifeeds/database/test_crud.py b/tests/digifeeds/database/test_crud.py index c9b06bd..a2399b3 100644 --- a/tests/digifeeds/database/test_crud.py +++ b/tests/digifeeds/database/test_crud.py @@ -26,8 +26,8 @@ def test_get_items_and_total_any(self, db_session): item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) - items = get_items(db=db_session, in_zephir=None, limit=2, offset=0) - count = get_items_total(db=db_session, in_zephir=None) + items = get_items(db=db_session, filter=None, limit=2, offset=0) + count = get_items_total(db=db_session, filter=None) db_session.refresh(item1) db_session.refresh(item2) assert (items[0]) == item1 @@ -39,8 +39,8 @@ def test_get_items_and_total_in_zephir(self, db_session): item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) - items = get_items(db=db_session, in_zephir=True, limit=2, offset=0) - count = get_items_total(db=db_session, in_zephir=True) + items = get_items(db=db_session, filter="in_zephir", limit=2, offset=0) + count = get_items_total(db=db_session, filter="in_zephir") db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 @@ -52,8 +52,8 @@ def test_get_items_and_total_not_in_zephir(self, db_session): item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) - items = get_items(db=db_session, in_zephir=False, limit=2, offset=0) - count = get_items_total(db=db_session, in_zephir=False) + items = get_items(db=db_session, filter="not_in_zephir", limit=2, offset=0) + count = get_items_total(db=db_session, filter="not_in_zephir") db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 diff --git a/tests/digifeeds/database/test_main.py b/tests/digifeeds/database/test_main.py index 4e36a05..5a7cdf6 100644 --- a/tests/digifeeds/database/test_main.py +++ b/tests/digifeeds/database/test_main.py @@ -33,7 +33,7 @@ def test_get_items(client, valid_item, valid_in_zephir_item, db_session): def test_get_items_with_in_zephir_true(client, valid_item, valid_in_zephir_item): valid_item valid_in_zephir_item - response = client.get("/items", params={"in_zephir": True}) + response = client.get("/items", params={"filter": "in_zephir"}) assert response.status_code == 200, response.text assert len(response.json()["items"]) == 1 assert response.json()["items"][0]["barcode"] == valid_in_zephir_item.barcode @@ -42,7 +42,7 @@ def test_get_items_with_in_zephir_true(client, valid_item, valid_in_zephir_item) def test_get_items_with_in_zephir_false(client, valid_item, valid_in_zephir_item): valid_item valid_in_zephir_item - response = client.get("/items", params={"in_zephir": False}) + response = client.get("/items", params={"filter": "not_in_zephir"}) assert response.status_code == 200, response.text assert len(response.json()["items"]) == 1 assert response.json()["items"][0]["barcode"] == valid_item.barcode