Replace BDD tests with parametrized tests for stations endpoints #4062

Merged: 2 commits, Nov 6, 2024
Changes from 1 commit
147 changes: 15 additions & 132 deletions api/app/tests/__init__.py
@@ -1,114 +1,51 @@
-""" Util & common files for tests
-"""
-from typing import IO, Any, Callable, Optional, Tuple
+"""Util & common files for tests"""
+
+from typing import Callable, Optional
 from dateutil import parser
 import os
 import datetime
 import json
 import importlib
 import jsonpickle
 from app.db.models.common import TZTimeStamp


 def get_complete_filename(module_path: str, filename: str):
-    """ Get the full path of a filename, given it's module path """
+    """Get the full path of a filename, given it's module path"""
     dirname = os.path.dirname(os.path.realpath(module_path))
     return os.path.join(dirname, filename)


 def _load_json_file(module_path: str, filename: str) -> Optional[dict]:
-    """ Load json file given a module path and a filename """
-    if filename == 'None':  # Not the best solution...
+    """Load json file given a module path and a filename"""
+    if filename == "None":  # Not the best solution...
         return None
     if filename:
         with open(get_complete_filename(module_path, filename), encoding="utf-8") as file_pointer:
             return json.load(file_pointer)
     return None


-def _load_json_file_with_name(module_path: str, filename: str) -> Tuple[Optional[dict], str]:
-    """ Load json file given a module path and a filename """
-    if filename == 'None':  # Not the best solution...
-        return None, filename
-    if filename:
-        with open(get_complete_filename(module_path, filename), encoding="utf-8") as file_pointer:
-            return json.load(file_pointer), filename
-    return None, filename
-
-
 def load_json_file(module_path: str) -> Callable[[str], dict]:
-    """ Return a function that can load json from a filename and return a dict """
+    """Return a function that can load json from a filename and return a dict"""
+
     def _json_loader(filename: str):
         return _load_json_file(module_path, filename)
     return _json_loader


-def load_json_file_with_name(module_path: str) -> Callable[[str], dict]:
-    """ Return a function that can load a json from a filename and return a dict, but also the filename """
-    def _json_loader(filename: str):
-        return _load_json_file_with_name(module_path, filename)
-    return _json_loader
-
-
-def json_converter(item: object):
-    """ Add datetime serialization """
-    if isinstance(item, datetime.datetime):
-        return item.isoformat()
-    return None
-
-
-def dump_sqlalchemy_row_data_to_json(response, target: IO[Any]):
-    """ Useful for dumping sqlalchemy responses to json in for unit tests. """
-    result = []
-    for response_row in response:
-        result.append(jsonpickle.encode(response_row))
-    target.write(jsonpickle.encode(result))
-
-
-def dump_sqlalchemy_mapped_object_response_to_json(response, target: IO[Any]):
-    """ Useful for dumping sqlalchemy responses to json in for unit tests.
-
-    e.g. if we want to store the response for GDPS predictions for two stations, we could write the
-    following code:
-    ```python
-    query = get_station_model_predictions_order_by_prediction_timestamp(
-        session, [322, 838], ModelEnum.GDPS, back_5_days, now)
-    with open('tmp.json', 'w') as tmp:
-        dump_sqlalchemy_response_to_json(query, tmp)
-    ```
-    """
-    result = []
-    for row in response:
-        result_row = []
-        for record in row:
-            # Copy the dict so we can safely change it.
-            data = dict(record.__dict__)
-            # Pop internal value
-            data.pop('_sa_instance_state')
-            result_row.append(
-                {
-                    'module': type(record).__module__,
-                    'class': type(record).__name__,
-                    'data': data
-                }
-            )
-        result.append(result_row)
-    json.dump(result, fp=target, default=json_converter, indent=3)
-
-
 def load_sqlalchemy_response_from_json(filename):
-    """ Load a sqlalchemy response from a json file """
-    with open(filename, 'r', encoding="utf-8") as tmp:
+    """Load a sqlalchemy response from a json file"""
+    with open(filename, "r", encoding="utf-8") as tmp:
         data = json.load(tmp)
     return load_sqlalchemy_response_from_object(data)


 def de_serialize_record(record):
-    """ De-serailize a single sqlalchemy record """
-    module = importlib.import_module(record['module'])
-    class_ = getattr(module, record['class'])
+    """De-serailize a single sqlalchemy record"""
+    module = importlib.import_module(record["module"])
+    class_ = getattr(module, record["class"])
     record_data = {}
-    for key, value in record['data'].items():
+    for key, value in record["data"].items():
         # Handle the special case, where the type is timestamp, converting the string to the
         # correct data type.
         if isinstance(getattr(class_, key).type, TZTimeStamp):
@@ -119,7 +56,7 @@ def de_serialize_record(record):


 def load_sqlalchemy_response_from_object(data: object):
-    """ Load a sqlalchemy response from an object """
+    """Load a sqlalchemy response from an object"""
     # Usualy the data is a list of objects - or a list of list of objects.
     # e.g.: [ { record }]
     # e.g.: or [ [{record}, {record}]]
@@ -134,57 +71,3 @@ def load_sqlalchemy_response_from_object(data: object):
         return result
     # Sometimes though, we're only expecting a single record, not a list.
     return de_serialize_record(data)
-
-
-def apply_crud_mapping(monkeypatch, crud_mapping: dict, module_path: str):
-    """ Mock the sql response
-    The crud response was generated by temporarily introducing
-    "dump_sqlalchemy_row_data_to_json" and "dump_sqlalchemy_mapped_object_response_to_json"
-    in code - and saving the database responses.
-    """
-
-    if crud_mapping:
-        for item in crud_mapping:
-            if item['serializer'] == "jsonpickle":
-                _jsonpickle_patch_function(monkeypatch,
-                                           item['module'], item['function'], item['json'], module_path)
-            else:
-                _json_patch_function(monkeypatch,
-                                     item['module'], item['function'], item['json'], module_path)
-
-    return {}
-
-
-def _jsonpickle_patch_function(
-        monkeypatch,
-        module_name: str,
-        function_name: str,
-        json_filename: str,
-        module_path: str):
-    """ Patch module_name.function_name to return de-serialized json_filename """
-    def mock_get_data(*_):
-        filename = get_complete_filename(module_path, json_filename)
-        with open(filename, encoding="utf-8") as file_pointer:
-            rows = jsonpickle.decode(file_pointer.read())
-            for row in rows:
-                # Workaround to remain compatible with old tests. Ideally we would just always pickle the row.
-                if isinstance(row, str):
-                    yield jsonpickle.decode(row)
-                    continue
-                yield row
-
-    monkeypatch.setattr(importlib.import_module(module_name), function_name, mock_get_data)
-
-
-def _json_patch_function(monkeypatch,
-                         module_name: str,
-                         function_name: str,
-                         json_filename: str,
-                         module_path: str):
-    """ Patch module_name.function_name to return de-serialized json_filename """
-    def mock_get_data(*_):
-        filename = get_complete_filename(module_path, json_filename)
-        with open(filename, encoding="utf-8") as file_pointer:
-            return json.load(file_pointer)
-
-    monkeypatch.setattr(importlib.import_module(module_name), function_name, mock_get_data)
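
For context, the removed `apply_crud_mapping` helper consumed a list of mapping entries and monkeypatched one function per entry with a canned database response. A minimal sketch of such an entry, inferred from the keys the helper reads (`module`, `function`, `json`, `serializer`); the module, function, and file names below are illustrative placeholders, not values from this repository:

```python
# Hypothetical crud_mapping entry for the removed apply_crud_mapping helper.
# The keys match what the helper reads; the values are placeholders only.
crud_mapping = [
    {
        "module": "app.some_module",      # module whose function gets monkeypatched
        "function": "fetch_something",    # function to replace with a canned response
        "json": "saved_response.json",    # file holding the captured database response
        "serializer": "jsonpickle",       # any other value falls back to plain json loading
    }
]
```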
26 changes: 0 additions & 26 deletions api/app/tests/test_stations.feature

This file was deleted.

84 changes: 0 additions & 84 deletions api/app/tests/test_stations.py

This file was deleted.
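
The two deleted files held the old behaviour-driven tests: a Gherkin feature file plus pytest-bdd step definitions. As a rough, generic illustration of that style (not the contents of the deleted files; the scenario name, step text, and `client` fixture below are placeholders), a pytest-bdd test is wired up like this:

```python
# Generic pytest-bdd shape; scenario name and step wording are illustrative
# placeholders, not the deleted test_stations tests.
from pytest_bdd import scenario, given, then


@scenario("test_stations.feature", "Get weather stations")
def test_stations_scenario():
    """The scenario is driven entirely by the step functions below."""


@given("I request a list of weather stations", target_fixture="response")
def request_stations(client):
    return client.get("/api/stations/")


@then("the response status is 200")
def check_status(response):
    assert response.status_code == 200
```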

61 changes: 61 additions & 0 deletions api/app/tests/test_stations_new.py
@@ -0,0 +1,61 @@
+from aiohttp import ClientSession
+import pytest
+import app.main
+from datetime import datetime, timezone
+from app.tests import load_json_file
+from app.tests.common import default_mock_client_get
+from fastapi.testclient import TestClient
+from httpx import Response
+
+
+@pytest.fixture()
+def client():
+    from app.main import app as test_app
+
+    with TestClient(test_app) as test_client:
+        yield test_client
+
+
+@pytest.mark.parametrize(
+    "url, status, code, name, lat, long",
+    [
+        ("/api/stations/", 200, 331, "ASHNOLA", 49.13905, -120.1844),
+        ("/api/stations/", 200, 322, "AFTON", 50.6733333, -120.4816667),
+        ("/api/stations/", 200, 317, "ALLISON PASS", 49.0623139, -120.7674194),
+    ],
+)
+@pytest.mark.usefixtures("mock_jwt_decode")
+def test_get_stations(
+    client: TestClient,
+    monkeypatch,
+    url,
+    status,
+    code,
+    name,
+    lat,
+    long,
+):
+    monkeypatch.setattr(ClientSession, "get", default_mock_client_get)
+    response: Response = client.get(url)
+    assert response.status_code == status
+    station = next(x for x in response.json()["features"] if x["properties"]["code"] == code)
+    assert station["properties"]["code"] == code, "Code"
+    assert station["properties"]["name"] == name, "Name"
+    assert station["geometry"]["coordinates"][1] == lat, "Latitude"
+    assert station["geometry"]["coordinates"][0] == long, "Longitude"
+    assert len(response.json()["features"]) >= 200
+
+
+@pytest.mark.usefixtures("mock_jwt_decode")
+def test_get_station_details(client: TestClient, monkeypatch):
+    monkeypatch.setattr(ClientSession, "get", default_mock_client_get)
+
+    def mock_get_utc_now():
+        return datetime.fromtimestamp(1618870929583 / 1000, tz=timezone.utc)
+
+    monkeypatch.setattr(app.routers.stations, "get_utc_now", mock_get_utc_now)
+    expected_response = load_json_file(__file__)("test_stations_details_expected_response.json")
+
+    response: Response = client.get("/api/stations/details/")
+    assert response.status_code == 200
+    assert response.json() == expected_response
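
One benefit of the parametrized style is that covering another station is a single extra row in the table above rather than a new scenario and step definitions. A sketch of what such a row could look like; the station code, name, and coordinates here are made-up placeholders, not real data from the API:

```python
# Hypothetical extra case that could be appended to the parametrize list in
# test_get_stations; the station code, name and coordinates are placeholders.
extra_case = ("/api/stations/", 200, 999, "EXAMPLE STATION", 50.0, -120.0)
```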