Skip to content

Commit

Permalink
Refactor model endpoint tests to remove bdd (#4084)
Browse files Browse the repository at this point in the history
  • Loading branch information
conbrad authored Nov 13, 2024
1 parent 6ef7b1c commit b55828a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 72 deletions.

This file was deleted.

102 changes: 45 additions & 57 deletions api/app/tests/weather_models/endpoints/test_models_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
""" Functional testing for /models/* endpoints.
"""
import os
import json
import importlib
import logging
import pytest
from aiohttp import ClientSession
from pytest_bdd import scenario, given, then, when, parsers
from fastapi.testclient import TestClient
import app.main
from app.tests import load_json_file, load_sqlalchemy_response_from_json
from app.tests.common import default_mock_client_get
from app.tests.utils.mock_jwt_decode_role import MockJWTDecodeWithRole


logger = logging.getLogger(__name__)


@pytest.mark.usefixtures("mock_jwt_decode")
@scenario("test_models_endpoints.feature", "Generic model endpoint testing")
def test_model_predictions_summaries_scenario():
""" BDD Scenario for prediction summaries """


def _patch_function(monkeypatch, module_name: str, function_name: str, json_filename: str):
""" Patch module_name.function_name to return de-serialized json_filename """
"""Patch module_name.function_name to return de-serialized json_filename"""

def mock_get_data(*_):
dirname = os.path.dirname(os.path.realpath(__file__))
filename = os.path.join(dirname, json_filename)
Expand All @@ -33,50 +20,51 @@ def mock_get_data(*_):
monkeypatch.setattr(importlib.import_module(module_name), function_name, mock_get_data)


@given(parsers.parse("some explanatory {notes}"), converters={'notes': str})
def given_some_notes(notes: str):
""" Send notes to the logger. """
logger.info(notes)


@given(parsers.parse("A weather model crud mapping {crud_mapping}"), target_fixture='database', converters={'crud_mapping': load_json_file(__file__)})
def given_a_database(monkeypatch, crud_mapping: dict):
""" Mock the sql response """

for item in crud_mapping:
_patch_function(monkeypatch, item['module'], item['function'], item['json'])

return {}


@when(parsers.parse("I call {endpoint} with {codes}"), converters={'endpoint': str, 'codes': json.loads})
def when_prediction(database: dict, codes: str, endpoint: str, monkeypatch: pytest.MonkeyPatch):
""" Make call to endpoint """

def mock_admin_role_function(*_, **__):
return MockJWTDecodeWithRole("morecast2_write_forecast")
@pytest.mark.parametrize(
"codes, endpoint, crud_mapping, expected_status_code, expected_response_file",
[
([322], "/api/weather_models/GDPS/predictions/summaries/", "test_models_predictions_summaries_crud_mappings.json", 200, "test_models_predictions_summaries_response.json"),
(
[322, 838],
"/api/weather_models/GDPS/predictions/summaries/",
"test_models_predictions_summaries_multiple_crud_mappings.json",
200,
"test_models_predictions_summaries_response_multiple.json",
),
(
[838],
"/api/weather_models/GDPS/predictions/most_recent/",
"test_models_predictions_most_recent_GDPS_[838]_crud_mappings.json",
200,
"test_models_predictions_most_recent_GDPS_[838]_response.json",
),
(
[838, 209],
"/api/weather_models/GDPS/predictions/most_recent/",
"test_models_predictions_most_recent_RDPS_[838, 209]_crud_mappings.json",
200,
"test_models_predictions_most_recent_RDPS_[838, 209]_response.json",
),
(
[956],
"/api/weather_models/GDPS/predictions/most_recent/",
"test_models_predictions_most_recent_GDPS_[956]_crud_mappings.json",
200,
"test_models_predictions_most_recent_GDPS_[956]_response.json",
),
],
)
@pytest.mark.usefixtures("mock_jwt_decode")
def test_successful_model_endpoint_calls(codes, endpoint, crud_mapping, expected_status_code, expected_response_file, monkeypatch):
with open(os.path.join(os.path.dirname(__file__), crud_mapping), "r", encoding="utf-8") as tmp:
for item in json.load(tmp):
_patch_function(monkeypatch, item["module"], item["function"], item["json"])

decode_fn = "jwt.decode"
monkeypatch.setattr(decode_fn, mock_admin_role_function)
monkeypatch.setattr(ClientSession, "get", default_mock_client_get)

client = TestClient(app.main.app)
response = client.post(
endpoint, headers={'Authorization': 'Bearer token'}, json={'stations': codes})
if response.status_code == 200:
database['response_json'] = response.json()
database['status_code'] = response.status_code


@then(parsers.parse('The status code = {expected_status_code}'), converters={'expected_status_code': int})
def assert_status_code(database: dict, expected_status_code: int):
""" Assert that the status code is as expected
"""
assert database['status_code'] == expected_status_code

response = client.post(endpoint, headers={"Authorization": "Bearer token"}, json={"stations": codes})

@then(parsers.parse('The response = {expected_response}'), converters={'expected_response': load_json_file(__file__)})
def assert_response(database: dict, expected_response: dict):
""" Assert that the response is as expected
"""
assert database['response_json'] == expected_response
assert response.status_code == expected_status_code
expected_response = load_json_file(__file__)(expected_response_file)
assert response.json() == expected_response

0 comments on commit b55828a

Please sign in to comment.