From af2ec9a21875d64586958dcc7e28b4a678501570 Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Thu, 31 Oct 2024 18:09:57 +0000 Subject: [PATCH 1/5] adds pagination to items endpoint --- aim/digifeeds/database/crud.py | 36 +++++++++++++++++++---- aim/digifeeds/database/main.py | 13 +++++++-- aim/digifeeds/database/schemas.py | 8 ++++++ tests/digifeeds/database/test_crud.py | 41 ++++++++++++++------------- tests/digifeeds/database/test_main.py | 28 ++++++++++++------ 5 files changed, 89 insertions(+), 37 deletions(-) diff --git a/aim/digifeeds/database/crud.py b/aim/digifeeds/database/crud.py index 887b799..a5a8b3d 100644 --- a/aim/digifeeds/database/crud.py +++ b/aim/digifeeds/database/crud.py @@ -3,6 +3,7 @@ Operations that act on the digifeeds database """ + from sqlalchemy.orm import Session from aim.digifeeds.database import schemas from aim.digifeeds.database import models @@ -23,7 +24,28 @@ def get_item(db: Session, barcode: str): return db.query(models.Item).filter(models.Item.barcode == barcode).first() -def get_items(db: Session, in_zephir: bool | None): +def get_item_total(db: Session, in_zephir: bool | None): + if in_zephir is True: + return ( + db.query(models.Item) + .filter( + models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") + ) + .count() + ) + elif in_zephir is False: + return ( + db.query(models.Item) + .filter( + ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") + ) + .count() + ) + + return db.query(models.Item).count() + + +def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): """ Get Digifeed items from the database @@ -38,22 +60,24 @@ def get_items(db: Session, in_zephir: bool | None): return ( db.query(models.Item) .filter( - models.Item.statuses.any( - models.ItemStatus.status_name == "in_zephir") + models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) + .offset(offset) + .limit(limit) .all() ) elif in_zephir is False: return ( db.query(models.Item) .filter( - ~models.Item.statuses.any( - models.ItemStatus.status_name == "in_zephir") + ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) + .offset(offset) + .limit(limit) .all() ) - return db.query(models.Item).all() + return db.query(models.Item).offset(offset).limit(limit).all() def add_item(db: Session, item: schemas.ItemCreate): diff --git a/aim/digifeeds/database/main.py b/aim/digifeeds/database/main.py index 1122c4b..86280e6 100644 --- a/aim/digifeeds/database/main.py +++ b/aim/digifeeds/database/main.py @@ -41,11 +41,13 @@ def get_db(): # pragma: no cover @app.get("/items/", response_model_by_alias=False, tags=["Digifeeds Database"]) def get_items( + offset: int = Query(0, ge=0, description="Requested offset from the list of pages"), + limit: int = Query(50, ge=1, description="Requested number of items per page"), in_zephir: bool | None = Query( None, description="Filter for items that do or do not have metadata in Zephir" ), db: Session = Depends(get_db), -) -> list[schemas.Item]: +) -> schemas.PageOfItems: # list[schemas.Item]: """ Get the digifeeds items. @@ -53,8 +55,13 @@ def get_items( all of them can be fetched. """ - db_items = crud.get_items(in_zephir=in_zephir, db=db) - return db_items + db_items = crud.get_items(in_zephir=in_zephir, db=db, offset=offset, limit=limit) + return { + "limit": limit, + "offset": offset, + "total": crud.get_item_total(in_zephir=in_zephir, db=db), + "items": db_items, + } @app.get( diff --git a/aim/digifeeds/database/schemas.py b/aim/digifeeds/database/schemas.py index c2a43fb..af4e76b 100644 --- a/aim/digifeeds/database/schemas.py +++ b/aim/digifeeds/database/schemas.py @@ -1,4 +1,5 @@ """Digifeeds Pydantic Models""" + from pydantic import BaseModel, Field, ConfigDict from datetime import datetime @@ -39,6 +40,12 @@ class Item(ItemBase): ) +class PageOfItems(BaseModel): + items: list[Item] + limit: int = 10 + offset: int = 0 + total: int = 15 + class ItemCreate(ItemBase): pass @@ -67,6 +74,7 @@ class Response400(Response): } ) + class Response404(Response): model_config = ConfigDict( json_schema_extra={ diff --git a/tests/digifeeds/database/test_crud.py b/tests/digifeeds/database/test_crud.py index ae9de03..d6f6484 100644 --- a/tests/digifeeds/database/test_crud.py +++ b/tests/digifeeds/database/test_crud.py @@ -4,63 +4,64 @@ add_item, get_status, get_statuses, - add_item_status + add_item_status, ) from aim.digifeeds.database.schemas import ItemCreate + class TestCrud: def test_get_item(self, db_session): item = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) barcode = item.barcode item_in_db = get_item(barcode=barcode, db=db_session) - assert(item_in_db.barcode) == "valid_barcode" + assert (item_in_db.barcode) == "valid_barcode" def test_get_item_that_does_not_exist(self, db_session): item_in_db = get_item(barcode="does not exist", db=db_session) - assert(item_in_db) is None + assert (item_in_db) is None def test_get_items_all(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") - add_item_status(db=db_session,item=item1, status=status) - items = get_items(db=db_session, in_zephir=None) + add_item_status(db=db_session, item=item1, status=status) + items = get_items(db=db_session, in_zephir=None, limit=2, offset=0) db_session.refresh(item1) db_session.refresh(item2) - assert(items[0]) == item1 - assert(items[1]) == item2 - + assert (items[0]) == item1 + assert (items[1]) == item2 + def test_get_items_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") - add_item_status(db=db_session,item=item1, status=status) - items = get_items(db=db_session, in_zephir=True) + add_item_status(db=db_session, item=item1, status=status) + items = get_items(db=db_session, in_zephir=True, limit=2, offset=0) db_session.refresh(item1) db_session.refresh(item2) - assert(len(items)) == 1 - assert(items[0]) == item1 + assert (len(items)) == 1 + assert (items[0]) == item1 def test_get_items_not_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") - add_item_status(db=db_session,item=item1, status=status) - items = get_items(db=db_session, in_zephir=False) + add_item_status(db=db_session, item=item1, status=status) + items = get_items(db=db_session, in_zephir=False, limit=2, offset=0) db_session.refresh(item1) db_session.refresh(item2) - assert(len(items)) == 1 - assert(items[0]) == item2 + assert (len(items)) == 1 + assert (items[0]) == item2 def test_get_status_that_exists(self, db_session): status = get_status(db=db_session, name="in_zephir") - assert(status.name) == "in_zephir" + assert (status.name) == "in_zephir" def test_get_status_that_does_not_exist(self, db_session): status = get_status(db=db_session, name="does_not_exist") - assert(status) is None + assert (status) is None def test_get_statuses(self, db_session): statuses = get_statuses(db=db_session) - assert(len(statuses)) > 1 - assert(statuses[0].name) == "in_zephir" + assert (len(statuses)) > 1 + assert (statuses[0].name) == "in_zephir" diff --git a/tests/digifeeds/database/test_main.py b/tests/digifeeds/database/test_main.py index 3996f6d..4e36a05 100644 --- a/tests/digifeeds/database/test_main.py +++ b/tests/digifeeds/database/test_main.py @@ -2,10 +2,12 @@ from aim.digifeeds.database.schemas import ItemCreate import pytest + @pytest.fixture() def valid_item(db_session): return crud.add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) + @pytest.fixture() def valid_in_zephir_item(db_session): item = crud.add_item(db=db_session, item=ItemCreate(barcode="in_zephir_item")) @@ -14,32 +16,36 @@ def valid_in_zephir_item(db_session): db_session.refresh(item) return item + def test_get_statuses(client): response = client.get("/statuses") assert response.status_code == 200, response.text + def test_get_items(client, valid_item, valid_in_zephir_item, db_session): valid_item valid_in_zephir_item response = client.get("/items") assert response.status_code == 200, response.text - assert len(response.json()) == 2 + assert len(response.json()["items"]) == 2 + def test_get_items_with_in_zephir_true(client, valid_item, valid_in_zephir_item): valid_item valid_in_zephir_item - response = client.get("/items", params={"in_zephir":True}) + response = client.get("/items", params={"in_zephir": True}) assert response.status_code == 200, response.text - assert len(response.json()) == 1 - assert response.json()[0]["barcode"] == valid_in_zephir_item.barcode + assert len(response.json()["items"]) == 1 + assert response.json()["items"][0]["barcode"] == valid_in_zephir_item.barcode + def test_get_items_with_in_zephir_false(client, valid_item, valid_in_zephir_item): valid_item valid_in_zephir_item - response = client.get("/items", params={"in_zephir":False}) + response = client.get("/items", params={"in_zephir": False}) assert response.status_code == 200, response.text - assert len(response.json()) == 1 - assert response.json()[0]["barcode"] == valid_item.barcode + assert len(response.json()["items"]) == 1 + assert response.json()["items"][0]["barcode"] == valid_item.barcode def test_get_item(client, valid_item, valid_in_zephir_item): @@ -48,30 +54,36 @@ def test_get_item(client, valid_item, valid_in_zephir_item): response = client.get(f"/items/{valid_item.barcode}") assert response.status_code == 200, response.text + def test_get_item_not_found(client): response = client.get("/items/some_barcode_that_does_not_exist") assert response.status_code == 404 assert response.json() == {"detail": "Item not found"} + def test_create_item(client): response = client.post("items/new_barcode") assert response.status_code == 200, response.text + def test_create_existing_item(client, valid_item): response = client.post(f"items/{valid_item.barcode}") assert response.status_code == 400 assert response.json() == {"detail": "Item already exists"} + def test_update_item_success(client, valid_item): response = client.put(f"items/{valid_item.barcode}/status/in_zephir") assert response.status_code == 200, response.text + def test_update_nonexisting_item(client): response = client.put("/items/some_barcode_that_does_not_exist/status/in_zephir") assert response.status_code == 404 assert response.json() == {"detail": "Item not found"} + def test_update_existing_item_with_nonexistent_status(client, valid_item): response = client.put(f"/items/{valid_item.barcode}/status/non_existent_status") assert response.status_code == 404 - assert response.json() == {"detail": "Status not found"} \ No newline at end of file + assert response.json() == {"detail": "Status not found"} From 1cfba87096eb3e937e5aefd443b500746ad7f9e9 Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Thu, 31 Oct 2024 18:36:28 +0000 Subject: [PATCH 2/5] adds test for count; refactors filter in_zephir --- aim/digifeeds/database/crud.py | 51 ++++++++------------------- aim/digifeeds/database/main.py | 2 +- tests/digifeeds/database/test_crud.py | 13 +++++-- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/aim/digifeeds/database/crud.py b/aim/digifeeds/database/crud.py index a5a8b3d..f2979fb 100644 --- a/aim/digifeeds/database/crud.py +++ b/aim/digifeeds/database/crud.py @@ -24,25 +24,9 @@ def get_item(db: Session, barcode: str): return db.query(models.Item).filter(models.Item.barcode == barcode).first() -def get_item_total(db: Session, in_zephir: bool | None): - if in_zephir is True: - return ( - db.query(models.Item) - .filter( - models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .count() - ) - elif in_zephir is False: - return ( - db.query(models.Item) - .filter( - ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .count() - ) - - return db.query(models.Item).count() +def get_items_total(db: Session, in_zephir: bool | None): + query = get_items_query(db=db, in_zephir=in_zephir) + return query.count() def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): @@ -56,28 +40,21 @@ def get_items(db: Session, in_zephir: bool | None, limit: int, offset: int): Returns: aim.digifeeds.database.models.Item: Item object """ + query = get_items_query(db=db, in_zephir=in_zephir) + return query.offset(offset).limit(limit).all() + + +def get_items_query(db: Session, in_zephir: bool | None): + query = db.query(models.Item) if in_zephir is True: - return ( - db.query(models.Item) - .filter( - models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .offset(offset) - .limit(limit) - .all() + query = query.filter( + models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) elif in_zephir is False: - return ( - db.query(models.Item) - .filter( - ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") - ) - .offset(offset) - .limit(limit) - .all() + query = query.filter( + ~models.Item.statuses.any(models.ItemStatus.status_name == "in_zephir") ) - - return db.query(models.Item).offset(offset).limit(limit).all() + return query def add_item(db: Session, item: schemas.ItemCreate): diff --git a/aim/digifeeds/database/main.py b/aim/digifeeds/database/main.py index 86280e6..c9f3d02 100644 --- a/aim/digifeeds/database/main.py +++ b/aim/digifeeds/database/main.py @@ -59,7 +59,7 @@ def get_items( return { "limit": limit, "offset": offset, - "total": crud.get_item_total(in_zephir=in_zephir, db=db), + "total": crud.get_items_total(in_zephir=in_zephir, db=db), "items": db_items, } diff --git a/tests/digifeeds/database/test_crud.py b/tests/digifeeds/database/test_crud.py index d6f6484..c9b06bd 100644 --- a/tests/digifeeds/database/test_crud.py +++ b/tests/digifeeds/database/test_crud.py @@ -5,6 +5,7 @@ get_status, get_statuses, add_item_status, + get_items_total, ) from aim.digifeeds.database.schemas import ItemCreate @@ -20,38 +21,44 @@ def test_get_item_that_does_not_exist(self, db_session): item_in_db = get_item(barcode="does not exist", db=db_session) assert (item_in_db) is None - def test_get_items_all(self, db_session): + def test_get_items_and_total_any(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=None, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=None) db_session.refresh(item1) db_session.refresh(item2) assert (items[0]) == item1 assert (items[1]) == item2 + assert (count) == 2 - def test_get_items_in_zephir(self, db_session): + def test_get_items_and_total_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=True, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=True) db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 assert (items[0]) == item1 + assert count == 1 - def test_get_items_not_in_zephir(self, db_session): + def test_get_items_and_total_not_in_zephir(self, db_session): item1 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode")) item2 = add_item(db=db_session, item=ItemCreate(barcode="valid_barcode2")) status = get_status(db=db_session, name="in_zephir") add_item_status(db=db_session, item=item1, status=status) items = get_items(db=db_session, in_zephir=False, limit=2, offset=0) + count = get_items_total(db=db_session, in_zephir=False) db_session.refresh(item1) db_session.refresh(item2) assert (len(items)) == 1 assert (items[0]) == item2 + assert count == 1 def test_get_status_that_exists(self, db_session): status = get_status(db=db_session, name="in_zephir") From 12ab4e4d3def16125c96da1b80d73551fff6fc24 Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Thu, 31 Oct 2024 20:11:27 +0000 Subject: [PATCH 3/5] DBClient has a get_items method that pages through results from api --- aim/digifeeds/db_client.py | 29 +++++++++ tests/digifeeds/test_db_client.py | 83 +++++++++++++++++++++++++ tests/fixtures/digifeeds/item_list.json | 18 ++++++ 3 files changed, 130 insertions(+) create mode 100644 tests/fixtures/digifeeds/item_list.json diff --git a/aim/digifeeds/db_client.py b/aim/digifeeds/db_client.py index 6c6a3ed..d6f4538 100644 --- a/aim/digifeeds/db_client.py +++ b/aim/digifeeds/db_client.py @@ -68,5 +68,34 @@ def add_item_status(self, barcode: str, status: str): response.raise_for_status() return response.json() + def get_items(self, limit: int = 50, in_zephir: bool | None = None): + items = [] + url = self._url(f"items") + params = { + "limit": limit, + "offset": 0, + } + if in_zephir != None: + params["in_zephir"] = in_zephir + + response = requests.get(url, params=params) + if response.status_code != 200: + response.raise_for_status() + + first_page = response.json() + total = first_page["total"] + for item in first_page["items"]: + items.append(item) + + for offset in list(range(limit, total, limit)): + params["offset"] = offset + response = requests.get(url, params=params) + if response.status_code != 200: + response.raise_for_status() + for item in response.json()["items"]: + items.append(item) + + return items + def _url(self, path) -> str: return f"{self.base_url}/{path}" diff --git a/tests/digifeeds/test_db_client.py b/tests/digifeeds/test_db_client.py index df9b180..083b415 100644 --- a/tests/digifeeds/test_db_client.py +++ b/tests/digifeeds/test_db_client.py @@ -1,8 +1,18 @@ import responses +from responses import matchers import pytest from aim.services import S from aim.digifeeds.db_client import DBClient from requests.exceptions import HTTPError +import json +import copy + + +@pytest.fixture +def item_list(): + with open("tests/fixtures/digifeeds/item_list.json") as f: + output = json.load(f) + return output @responses.activate @@ -79,3 +89,76 @@ def test_add_item_status_failure(): with pytest.raises(Exception) as exc_info: DBClient().add_item_status(barcode="my_barcode", status="in_zephir") assert exc_info.type is HTTPError + + +@responses.activate +def test_get_items_multiple_pages(item_list): + page_2 = copy.copy(item_list) + page_2["offset"] = 1 + page_2["items"][0]["barcode"] = "some_other_barcode" + url = f"{S.digifeeds_api_url}/items" + responses.get( + url=url, + match=[matchers.query_param_matcher({"limit": 1, "offset": 0})], + json=item_list, + ) + responses.get( + url=url, + match=[matchers.query_param_matcher({"limit": 1, "offset": 1})], + json=page_2, + ) + + items = DBClient().get_items(limit=1) + assert (len(items)) == 2 + + +@responses.activate +def test_get_items_in_zephir_value(item_list): + item_list["total"] = 1 + url = f"{S.digifeeds_api_url}/items" + responses.get( + url=url, + match=[ + matchers.query_param_matcher({"limit": 1, "offset": 0, "in_zephir": False}) + ], + json=item_list, + ) + items = DBClient().get_items(limit=1, in_zephir=False) + assert (len(items)) == 1 + + +@responses.activate +def test_get_items_fail_first_page(): + url = f"{S.digifeeds_api_url}/items" + responses.get( + url=url, + status=500, + match=[matchers.query_param_matcher({"limit": 1, "offset": 0})], + json={}, + ) + + with pytest.raises(Exception) as exc_info: + DBClient().get_items(limit=1) + + assert exc_info.type is HTTPError + + +@responses.activate +def test_get_items_fail_later_page(item_list): + url = f"{S.digifeeds_api_url}/items" + responses.get( + url=url, + match=[matchers.query_param_matcher({"limit": 1, "offset": 0})], + json=item_list, + ) + responses.get( + url=url, + status=500, + match=[matchers.query_param_matcher({"limit": 1, "offset": 1})], + json={}, + ) + + with pytest.raises(Exception) as exc_info: + DBClient().get_items(limit=1) + + assert exc_info.type is HTTPError diff --git a/tests/fixtures/digifeeds/item_list.json b/tests/fixtures/digifeeds/item_list.json new file mode 100644 index 0000000..b9c3502 --- /dev/null +++ b/tests/fixtures/digifeeds/item_list.json @@ -0,0 +1,18 @@ +{ + "limit": 1, + "offset": 0, + "total": 2, + "items": [ + { + "barcode": "some_barcode", + "created_at": "2024-09-25T17:12:39", + "statuses": [ + { + "name": "added_to_digifeeds_set", + "description": "Item has been added to the digifeeds set", + "created_at": "2024-09-25T17:13:28" + } + ] + } + ] +} \ No newline at end of file From 2194efb20e098d7d3bcbfca19d377837f638da22 Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Mon, 18 Nov 2024 19:13:39 +0000 Subject: [PATCH 4/5] linting --- aim/cli/main.py | 1 + aim/digifeeds/db_client.py | 4 ++-- docs/conf.py | 13 ++++++++++--- tests/conftest.py | 14 +++++++++----- tests/digifeeds/database/test_models.py | 14 +++++--------- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/aim/cli/main.py b/aim/cli/main.py index 4124886..8c6df0a 100644 --- a/aim/cli/main.py +++ b/aim/cli/main.py @@ -3,6 +3,7 @@ This hooks up the AIM CLI application. Nothing exciting happening here. """ + import typer import aim.cli.digifeeds as digifeeds diff --git a/aim/digifeeds/db_client.py b/aim/digifeeds/db_client.py index d6f4538..f08b0cd 100644 --- a/aim/digifeeds/db_client.py +++ b/aim/digifeeds/db_client.py @@ -70,12 +70,12 @@ def add_item_status(self, barcode: str, status: str): def get_items(self, limit: int = 50, in_zephir: bool | None = None): items = [] - url = self._url(f"items") + url = self._url("items") params = { "limit": limit, "offset": 0, } - if in_zephir != None: + if in_zephir is not None: params["in_zephir"] = in_zephir response = requests.get(url, params=params) diff --git a/docs/conf.py b/docs/conf.py index fe52415..739c0a9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,8 +14,15 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ["sphinx.ext.napoleon", "sphinx.ext.viewcode", "sphinx.ext.autosummary", - "sphinx.ext.autodoc", 'myst_parser', 'sphinxcontrib.mermaid', "sphinx_toolbox.more_autodoc.autonamedtuple"] +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.autodoc", + "myst_parser", + "sphinxcontrib.mermaid", + "sphinx_toolbox.more_autodoc.autonamedtuple", +] autosummary_generate = True mermaid_d3_zoom = True @@ -33,5 +40,5 @@ html_theme_options = { "navigation_depth": 5, "collapse_navigation": False, - "titles_only": True + "titles_only": True, } diff --git a/tests/conftest.py b/tests/conftest.py index 8e71b08..99204f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,10 @@ engine = create_engine( S.test_database, - connect_args={ "check_same_thread": False,}, - poolclass=StaticPool + connect_args={ + "check_same_thread": False, + }, + poolclass=StaticPool, ) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -25,10 +27,11 @@ session.close() connection.close() + # From: https://stackoverflow.com/questions/67255653/how-to-set-up-and-tear-down-a-database-between-tests-in-fastapi # These two event listeners are only needed for sqlite for proper # SAVEPOINT / nested transaction support. Other databases like postgres -# don't need them. +# don't need them. # From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl @sa.event.listens_for(engine, "connect") def do_connect(dbapi_connection, connection_record): @@ -42,6 +45,7 @@ def do_begin(conn): # emit our own BEGIN conn.exec_driver_sql("BEGIN") + # Handles rolling back the db after every test @pytest.fixture() def db_session(scope="module"): @@ -49,7 +53,6 @@ def db_session(scope="module"): transaction = connection.begin() session = TestingSessionLocal(bind=connection) - # Begin a nested transaction (using SAVEPOINT). nested = connection.begin_nested() @@ -68,6 +71,7 @@ def end_savepoint(session, transaction): transaction.rollback() connection.close() + # A fixture for the fastapi test client which depends on the # previous session fixture. Instead of creating a new session in the # dependency override as before, it uses the one provided by the @@ -79,4 +83,4 @@ def override_get_db(): app.dependency_overrides[get_db] = override_get_db yield TestClient(app) - del app.dependency_overrides[get_db] \ No newline at end of file + del app.dependency_overrides[get_db] diff --git a/tests/digifeeds/database/test_models.py b/tests/digifeeds/database/test_models.py index e6d183d..23769cc 100644 --- a/tests/digifeeds/database/test_models.py +++ b/tests/digifeeds/database/test_models.py @@ -1,5 +1,6 @@ from aim.digifeeds.database.models import Item, Status, ItemStatus + class TestItem: def test_item_valid(self, db_session): valid_item = Item(barcode="valid_barcode") @@ -15,19 +16,14 @@ def test_item_statuses(self, db_session): db_session.commit() status = db_session.query(Status).filter_by(name="in_zephir").first() db_session.refresh(item) - assert(len(item.statuses)) == 0 + assert (len(item.statuses)) == 0 - item_status = ItemStatus(item=item,status=status) + item_status = ItemStatus(item=item, status=status) db_session.add(item_status) db_session.commit() db_session.refresh(item) assert item.barcode == "valid_barcode" - assert(len(item.statuses)) == 1 - assert(item.statuses[0].created_at) + assert (len(item.statuses)) == 1 + assert item.statuses[0].created_at assert item.statuses[0].status_name == "in_zephir" - - - - - \ No newline at end of file From 7407a683f7256cb3ef56df856bf501c033bb45fb Mon Sep 17 00:00:00 2001 From: Monique Rio Date: Mon, 18 Nov 2024 19:33:16 +0000 Subject: [PATCH 5/5] docs: make digifeeds pagination example make sense --- aim/digifeeds/database/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aim/digifeeds/database/schemas.py b/aim/digifeeds/database/schemas.py index af4e76b..657b83f 100644 --- a/aim/digifeeds/database/schemas.py +++ b/aim/digifeeds/database/schemas.py @@ -44,7 +44,7 @@ class PageOfItems(BaseModel): items: list[Item] limit: int = 10 offset: int = 0 - total: int = 15 + total: int = 1 class ItemCreate(ItemBase):