diff --git a/app/Dockerfile b/app/Dockerfile
index 76c48f2d..e28e593d 100644
--- a/app/Dockerfile
+++ b/app/Dockerfile
@@ -4,7 +4,7 @@
# The build stage that will be used to deploy to the various environments
# needs to be called `release` in order to integrate with the repo's
# top-level Makefile
-FROM python:3-slim AS base
+FROM python:3.11-slim AS base
# Install poetry, the package manager.
# https://python-poetry.org
@@ -57,7 +57,7 @@ COPY . /app
ENV HOST=0.0.0.0
# Run the application.
-CMD ["poetry", "run", "python", "-m", "src"]
+CMD ["poetry", "run", "python", "src/app.py"]
#---------
# Release
@@ -109,4 +109,4 @@ ENV HOST=0.0.0.0
USER ${RUN_USER}
# Run the application.
-CMD ["poetry", "run", "gunicorn", "src.app:create_app()"]
+CMD ["poetry", "run", "python", "src/app.py"]
diff --git a/app/local.env b/app/local.env
index 21c5a1b6..ac6d2e40 100644
--- a/app/local.env
+++ b/app/local.env
@@ -15,8 +15,6 @@ PYTHONPATH=/app/
# commands that can run in or out
# of the Docker container - defaults to outside
-FLASK_APP=src.app:create_app
-
############################
# Logging
############################
diff --git a/app/poetry.lock b/app/poetry.lock
index acb931c9..78072d73 100644
--- a/app/poetry.lock
+++ b/app/poetry.lock
@@ -97,16 +97,6 @@ tests = ["PyYAML (>=3.10)", "marshmallow (>=3.13.0)", "openapi-spec-validator (<
validation = ["openapi-spec-validator (<0.5)", "prance[osv] (>=0.11)"]
yaml = ["PyYAML (>=3.10)"]
-[[package]]
-name = "atomicwrites"
-version = "1.4.1"
-description = "Atomic file writes."
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
-files = [
- {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"},
-]
-
[[package]]
name = "attrs"
version = "23.1.0"
@@ -1289,44 +1279,48 @@ xray = ["aws-xray-sdk (>=0.93,!=0.96)", "setuptools"]
[[package]]
name = "mypy"
-version = "0.971"
+version = "1.5.1"
description = "Optional static typing for Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.8"
files = [
- {file = "mypy-0.971-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2899a3cbd394da157194f913a931edfd4be5f274a88041c9dc2d9cdcb1c315c"},
- {file = "mypy-0.971-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:98e02d56ebe93981c41211c05adb630d1d26c14195d04d95e49cd97dbc046dc5"},
- {file = "mypy-0.971-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:19830b7dba7d5356d3e26e2427a2ec91c994cd92d983142cbd025ebe81d69cf3"},
- {file = "mypy-0.971-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:02ef476f6dcb86e6f502ae39a16b93285fef97e7f1ff22932b657d1ef1f28655"},
- {file = "mypy-0.971-cp310-cp310-win_amd64.whl", hash = "sha256:25c5750ba5609a0c7550b73a33deb314ecfb559c350bb050b655505e8aed4103"},
- {file = "mypy-0.971-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d3348e7eb2eea2472db611486846742d5d52d1290576de99d59edeb7cd4a42ca"},
- {file = "mypy-0.971-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3fa7a477b9900be9b7dd4bab30a12759e5abe9586574ceb944bc29cddf8f0417"},
- {file = "mypy-0.971-cp36-cp36m-win_amd64.whl", hash = "sha256:2ad53cf9c3adc43cf3bea0a7d01a2f2e86db9fe7596dfecb4496a5dda63cbb09"},
- {file = "mypy-0.971-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:855048b6feb6dfe09d3353466004490b1872887150c5bb5caad7838b57328cc8"},
- {file = "mypy-0.971-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:23488a14a83bca6e54402c2e6435467a4138785df93ec85aeff64c6170077fb0"},
- {file = "mypy-0.971-cp37-cp37m-win_amd64.whl", hash = "sha256:4b21e5b1a70dfb972490035128f305c39bc4bc253f34e96a4adf9127cf943eb2"},
- {file = "mypy-0.971-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9796a2ba7b4b538649caa5cecd398d873f4022ed2333ffde58eaf604c4d2cb27"},
- {file = "mypy-0.971-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a361d92635ad4ada1b1b2d3630fc2f53f2127d51cf2def9db83cba32e47c856"},
- {file = "mypy-0.971-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b793b899f7cf563b1e7044a5c97361196b938e92f0a4343a5d27966a53d2ec71"},
- {file = "mypy-0.971-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d1ea5d12c8e2d266b5fb8c7a5d2e9c0219fedfeb493b7ed60cd350322384ac27"},
- {file = "mypy-0.971-cp38-cp38-win_amd64.whl", hash = "sha256:23c7ff43fff4b0df93a186581885c8512bc50fc4d4910e0f838e35d6bb6b5e58"},
- {file = "mypy-0.971-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1f7656b69974a6933e987ee8ffb951d836272d6c0f81d727f1d0e2696074d9e6"},
- {file = "mypy-0.971-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2022bfadb7a5c2ef410d6a7c9763188afdb7f3533f22a0a32be10d571ee4bbe"},
- {file = "mypy-0.971-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef943c72a786b0f8d90fd76e9b39ce81fb7171172daf84bf43eaf937e9f220a9"},
- {file = "mypy-0.971-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d744f72eb39f69312bc6c2abf8ff6656973120e2eb3f3ec4f758ed47e414a4bf"},
- {file = "mypy-0.971-cp39-cp39-win_amd64.whl", hash = "sha256:77a514ea15d3007d33a9e2157b0ba9c267496acf12a7f2b9b9f8446337aac5b0"},
- {file = "mypy-0.971-py3-none-any.whl", hash = "sha256:0d054ef16b071149917085f51f89555a576e2618d5d9dd70bd6eea6410af3ac9"},
- {file = "mypy-0.971.tar.gz", hash = "sha256:40b0f21484238269ae6a57200c807d80debc6459d444c0489a102d7c6a75fa56"},
+ {file = "mypy-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f33592ddf9655a4894aef22d134de7393e95fcbdc2d15c1ab65828eee5c66c70"},
+ {file = "mypy-1.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:258b22210a4a258ccd077426c7a181d789d1121aca6db73a83f79372f5569ae0"},
+ {file = "mypy-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9ec1f695f0c25986e6f7f8778e5ce61659063268836a38c951200c57479cc12"},
+ {file = "mypy-1.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:abed92d9c8f08643c7d831300b739562b0a6c9fcb028d211134fc9ab20ccad5d"},
+ {file = "mypy-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:a156e6390944c265eb56afa67c74c0636f10283429171018446b732f1a05af25"},
+ {file = "mypy-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6ac9c21bfe7bc9f7f1b6fae441746e6a106e48fc9de530dea29e8cd37a2c0cc4"},
+ {file = "mypy-1.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:51cb1323064b1099e177098cb939eab2da42fea5d818d40113957ec954fc85f4"},
+ {file = "mypy-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:596fae69f2bfcb7305808c75c00f81fe2829b6236eadda536f00610ac5ec2243"},
+ {file = "mypy-1.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:32cb59609b0534f0bd67faebb6e022fe534bdb0e2ecab4290d683d248be1b275"},
+ {file = "mypy-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:159aa9acb16086b79bbb0016145034a1a05360626046a929f84579ce1666b315"},
+ {file = "mypy-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f6b0e77db9ff4fda74de7df13f30016a0a663928d669c9f2c057048ba44f09bb"},
+ {file = "mypy-1.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26f71b535dfc158a71264e6dc805a9f8d2e60b67215ca0bfa26e2e1aa4d4d373"},
+ {file = "mypy-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fc3a600f749b1008cc75e02b6fb3d4db8dbcca2d733030fe7a3b3502902f161"},
+ {file = "mypy-1.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:26fb32e4d4afa205b24bf645eddfbb36a1e17e995c5c99d6d00edb24b693406a"},
+ {file = "mypy-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:82cb6193de9bbb3844bab4c7cf80e6227d5225cc7625b068a06d005d861ad5f1"},
+ {file = "mypy-1.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4a465ea2ca12804d5b34bb056be3a29dc47aea5973b892d0417c6a10a40b2d65"},
+ {file = "mypy-1.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9fece120dbb041771a63eb95e4896791386fe287fefb2837258925b8326d6160"},
+ {file = "mypy-1.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d28ddc3e3dfeab553e743e532fb95b4e6afad51d4706dd22f28e1e5e664828d2"},
+ {file = "mypy-1.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:57b10c56016adce71fba6bc6e9fd45d8083f74361f629390c556738565af8eeb"},
+ {file = "mypy-1.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff0cedc84184115202475bbb46dd99f8dcb87fe24d5d0ddfc0fe6b8575c88d2f"},
+ {file = "mypy-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8f772942d372c8cbac575be99f9cc9d9fb3bd95c8bc2de6c01411e2c84ebca8a"},
+ {file = "mypy-1.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5d627124700b92b6bbaa99f27cbe615c8ea7b3402960f6372ea7d65faf376c14"},
+ {file = "mypy-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:361da43c4f5a96173220eb53340ace68cda81845cd88218f8862dfb0adc8cddb"},
+ {file = "mypy-1.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:330857f9507c24de5c5724235e66858f8364a0693894342485e543f5b07c8693"},
+ {file = "mypy-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:c543214ffdd422623e9fedd0869166c2f16affe4ba37463975043ef7d2ea8770"},
+ {file = "mypy-1.5.1-py3-none-any.whl", hash = "sha256:f757063a83970d67c444f6e01d9550a7402322af3557ce7630d3c957386fa8f5"},
+ {file = "mypy-1.5.1.tar.gz", hash = "sha256:b031b9601f1060bf1281feab89697324726ba0c0bae9d7cd7ab4b690940f0b92"},
]
[package.dependencies]
-mypy-extensions = ">=0.4.3"
+mypy-extensions = ">=1.0.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
-typing-extensions = ">=3.10"
+typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
-python2 = ["typed-ast (>=1.4.0,<2)"]
+install-types = ["pip"]
reports = ["lxml"]
[[package]]
@@ -1559,17 +1553,6 @@ files = [
{file = "psycopg_binary-3.1.10-cp39-cp39-win_amd64.whl", hash = "sha256:b30887e631fd67affaed98f6cd2135b44f2d1a6d9bca353a69c3889c78bd7aa8"},
]
-[[package]]
-name = "py"
-version = "1.11.0"
-description = "library with cross-python path, ini-parsing, io, code, log facilities"
-optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
-files = [
- {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
- {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
-]
-
[[package]]
name = "py-partiql-parser"
version = "0.3.7"
@@ -1797,27 +1780,25 @@ plugins = ["importlib-metadata"]
[[package]]
name = "pytest"
-version = "6.2.5"
+version = "7.4.2"
description = "pytest: simple powerful testing with Python"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.7"
files = [
- {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
- {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
+ {file = "pytest-7.4.2-py3-none-any.whl", hash = "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002"},
+ {file = "pytest-7.4.2.tar.gz", hash = "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069"},
]
[package.dependencies]
-atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
-attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
-py = ">=1.8.2"
-toml = "*"
+tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
-testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
+testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-lazy-fixture"
@@ -2176,6 +2157,20 @@ anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
+[[package]]
+name = "starlette-context"
+version = "0.3.6"
+description = "Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id."
+optional = false
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "starlette_context-0.3.6-py3-none-any.whl", hash = "sha256:b14ce373fbb6895a2182a7104b9f63ba20c8db83444005fb9a844dd77ad9895c"},
+ {file = "starlette_context-0.3.6.tar.gz", hash = "sha256:d361a36ba2d4acca3ab680f917b25e281533d725374752d47607a859041958cb"},
+]
+
+[package.dependencies]
+starlette = "*"
+
[[package]]
name = "stevedore"
version = "5.1.0"
@@ -2190,17 +2185,6 @@ files = [
[package.dependencies]
pbr = ">=2.0.0,<2.1.0 || >2.1.0"
-[[package]]
-name = "toml"
-version = "0.10.2"
-description = "Python Library for Tom's Obvious, Minimal Language"
-optional = false
-python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-files = [
- {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
- {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
-]
-
[[package]]
name = "tomli"
version = "2.0.1"
@@ -2650,4 +2634,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "8a0dc94cf02cf1ff2ca09f4e5163d26ead0534f6648196b4188b8303fe54da95"
+content-hash = "b7560789c538007917c286a7d07a430fbcd01560b8dc27da4037a58d80e7a1cf"
diff --git a/app/pyproject.toml b/app/pyproject.toml
index 25233f03..284f754d 100644
--- a/app/pyproject.toml
+++ b/app/pyproject.toml
@@ -21,6 +21,7 @@ marshmallow = "^3.18.0"
gunicorn = "^21.2.0"
fastapi = {extras = ["all"], version = "^0.103.1"}
psycopg = {extras = ["binary"], version = "^3.1.10"}
+starlette-context = "^0.3.6"
[tool.poetry.group.dev.dependencies]
@@ -29,14 +30,14 @@ flake8 = "^5.0.4"
flake8-bugbear = "^22.8.23"
flake8-alfred = "^1.1.1"
isort = "^5.10.1"
-mypy = "^0.971"
+mypy = "1.5.1"
moto = {extras = ["s3"], version = "^4.0.2"}
types-pytz = "^2022.2.1"
coverage = "^6.4.4"
Faker = "^14.2.0"
factory-boy = "^3.2.1"
bandit = "^1.7.4"
-pytest = "^6.0.0"
+pytest = "7.4.2"
pytest-watch = "^4.2.0"
pytest-lazy-fixture = "^0.6.3"
types-pyyaml = "^6.0.12.11"
@@ -83,6 +84,9 @@ warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
+plugins = ["pydantic.mypy"]
+
+
[tool.bandit]
# Ignore audit logging test file since test audit logging requires a lot of operations that trigger bandit warnings
exclude_dirs = ["./tests/src/logging/test_audit.py"]
diff --git a/app/setup.cfg b/app/setup.cfg
index 54346acf..0505d286 100644
--- a/app/setup.cfg
+++ b/app/setup.cfg
@@ -18,3 +18,9 @@ ignore =
E266,
# don't use bare except, B001 is more descriptive
E722
+
+# Tell flake8 not to error when we do "def my_func(some_field = default_value)" when
+# the default value is one of the following types as these are used by FastAPI.
+# These are calls to classes and functions and are immutable and safe in their usage.
+# https://stackoverflow.com/questions/73306462/flake8-throws-b008-fastapi-data-type-definitions
+extend-immutable-calls = Depends, Security, fastapi_db.DbSessionDependency
\ No newline at end of file
diff --git a/app/src/adapters/db/__init__.py b/app/src/adapters/db/__init__.py
index 591de851..913841dd 100644
--- a/app/src/adapters/db/__init__.py
+++ b/app/src/adapters/db/__init__.py
@@ -2,9 +2,9 @@
Database module.
This module contains the DBClient class, which is used to manage database connections.
-This module can be used on it's own or with an application framework such as Flask.
+This module can be used on its own or with an application framework such as FastAPI.
-To use this module with Flask, use the flask_db module.
+To use this module with FastAPI, use the fastapi_db module.
Usage:
import src.adapters.db as db
@@ -26,7 +26,7 @@
from src.adapters.db.client import Connection, DBClient, Session
from src.adapters.db.clients.postgres_client import PostgresDBClient
-# Do not import flask_db here, because this module is not dependent on any specific framework.
-# Code can choose to use this module on its own or with the flask_db module depending on needs.
+# Do not import fastapi_db here, because this module is not dependent on any specific framework.
+# Code can choose to use this module on its own or with the fastapi_db module depending on needs.
__all__ = ["Connection", "DBClient", "Session", "PostgresDBClient"]
diff --git a/app/src/adapters/db/client.py b/app/src/adapters/db/client.py
index 78b83261..251aa310 100644
--- a/app/src/adapters/db/client.py
+++ b/app/src/adapters/db/client.py
@@ -23,7 +23,7 @@
class DBClient(abc.ABC, metaclass=abc.ABCMeta):
"""Database connection manager.
- This class is used to manage database connections for the Flask app.
+ This class is used to manage database connections for the app.
It has methods for getting a new connection or session object.
A derived class must initialize _engine in the __init__ function
diff --git a/app/src/adapters/db/clients/postgres_config.py b/app/src/adapters/db/clients/postgres_config.py
index 961c20d6..41f4bebe 100644
--- a/app/src/adapters/db/clients/postgres_config.py
+++ b/app/src/adapters/db/clients/postgres_config.py
@@ -9,16 +9,16 @@
class PostgresDBConfig(PydanticBaseEnvConfig):
- check_connection_on_init: bool = Field(True, env="DB_CHECK_CONNECTION_ON_INIT")
- aws_region: Optional[str] = Field(None, env="AWS_REGION")
- host: str = Field(env="DB_HOST")
- name: str = Field(env="DB_NAME")
- username: str = Field(env="DB_USER")
- password: Optional[str] = Field(None, env="DB_PASSWORD")
- db_schema: str = Field("public", env="DB_SCHEMA")
- port: int = Field(5432, env="DB_PORT")
- hide_sql_parameter_logs: bool = Field(True, env="HIDE_SQL_PARAMETER_LOGS")
- ssl_mode: str = Field("require", env="DB_SSL_MODE")
+ check_connection_on_init: bool = Field(True, alias="DB_CHECK_CONNECTION_ON_INIT")
+ aws_region: Optional[str] = Field(None, alias="AWS_REGION")
+ host: str = Field(alias="DB_HOST")
+ name: str = Field(alias="DB_NAME")
+ username: str = Field(alias="DB_USER")
+ password: Optional[str] = Field(None, alias="DB_PASSWORD")
+ db_schema: str = Field("public", alias="DB_SCHEMA")
+ port: int = Field(5432, alias="DB_PORT")
+ hide_sql_parameter_logs: bool = Field(True, alias="HIDE_SQL_PARAMETER_LOGS")
+ ssl_mode: str = Field("require", alias="DB_SSL_MODE")
def get_db_config() -> PostgresDBConfig:
diff --git a/app/src/adapters/db/fastapi_db.py b/app/src/adapters/db/fastapi_db.py
new file mode 100644
index 00000000..922e57eb
--- /dev/null
+++ b/app/src/adapters/db/fastapi_db.py
@@ -0,0 +1,87 @@
+"""
+This module has functionality to extend FastAPI with a database client.
+
+To initialize this FastAPI extension, call register_db_client() with an instance
+of a FastAPI app and an instance of a DBClient.
+
+Example:
+ import src.adapters.db as db
+ import src.adapters.db.fastapi_db as fastapi_db
+
+ db_client = db.PostgresDBClient()
+ app = fastapi.FastAPI()
+ fastapi_db.register_db_client(db_client, app)
+
+Then, in a request handler, specify the DB session as a dependency to get a
+new database session that lasts for the duration of the request.
+
+Example:
+ import src.adapters.db as db
+ import src.adapters.db.fastapi_db as fastapi_db
+
+ @app.get("/health")
+ def health(db_session: db.Session = Depends(fastapi_db.DbSessionDependency())):
+ with db_session.begin():
+ ...
+
+The session can also be defined as an annotation in the typing like:
+`db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())]`
+if you want to avoid having the function defined with a default value
+
+Alternatively, if you want to get the database client directly, use the get_db_client
+function which requires you have the application itself via the Request
+
+Example:
+ from fastapi import Request
+ import src.adapters.db.fastapi_db as fastapi_db
+
+ @app.get("/health")
+ def health(request: Request):
+ db_client = fastapi_db.get_db_client(request.app)
+ # db_client.get_connection() or db_client.get_session()
+"""
+
+from typing import Generator
+
+import fastapi
+
+import src.adapters.db as db
+
+_FASTAPI_KEY_PREFIX = "db"
+_DEFAULT_CLIENT_NAME = "default"
+
+
+def register_db_client(
+ db_client: db.DBClient, app: fastapi.FastAPI, client_name: str = _DEFAULT_CLIENT_NAME
+) -> None:
+ fastapi_state_key = f"{_FASTAPI_KEY_PREFIX}{client_name}"
+ setattr(app.state, fastapi_state_key, db_client)
+
+
+def get_db_client(app: fastapi.FastAPI, client_name: str = _DEFAULT_CLIENT_NAME) -> db.DBClient:
+ fastapi_state_key = f"{_FASTAPI_KEY_PREFIX}{client_name}"
+ return getattr(app.state, fastapi_state_key)
+
+
+class DbSessionDependency:
+ """
+ FastAPI dependency class that can be used to fetch a DB session::
+
+ import src.adapters.db as db
+ import src.adapters.db.fastapi_db as fastapi_db
+
+ @app.get("/health")
+ def health(db_session: db.Session = Depends(fastapi_db.DbSessionDependency())):
+ with db_session.begin():
+ ...
+
+ This approach to setting up a dependency allows us to take in a parameter (the client name)
+ See: https://fastapi.tiangolo.com/advanced/advanced-dependencies/#a-callable-instance
+ """
+
+ def __init__(self, client_name: str = _DEFAULT_CLIENT_NAME):
+ self.client_name = client_name
+
+ def __call__(self, request: fastapi.Request) -> Generator[db.Session, None, None]:
+ with get_db_client(request.app, self.client_name).get_session() as session:
+ yield session
diff --git a/app/src/adapters/db/flask_db.py b/app/src/adapters/db/flask_db.py
deleted file mode 100644
index 779e5ad4..00000000
--- a/app/src/adapters/db/flask_db.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""
-This module has functionality to extend Flask with a database client.
-
-To initialize this flask extension, call register_db_client() with an instance
-of a Flask app and an instance of a DBClient.
-
-Example:
- import src.adapters.db as db
- import src.adapters.db.flask_db as flask_db
-
- db_client = db.PostgresDBClient()
- app = APIFlask(__name__)
- flask_db.register_db_client(db_client, app)
-
-Then, in a request handler, use the with_db_session decorator to get a
-new database session that lasts for the duration of the request.
-
-Example:
- import src.adapters.db as db
- import src.adapters.db.flask_db as flask_db
-
- @app.route("/health")
- @flask_db.with_db_session
- def health(db_session: db.Session):
- with db_session.begin():
- ...
-
-
-Alternatively, if you want to get the database client directly, use the get_db
-function.
-
-Example:
- from flask import current_app
- import src.adapters.db.flask_db as flask_db
-
- @app.route("/health")
- def health():
- db_client = flask_db.get_db(current_app)
- # db_client.get_connection() or db_client.get_session()
-"""
-from functools import wraps
-from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
-
-from flask import Flask, current_app
-
-import src.adapters.db as db
-from src.adapters.db.client import DBClient
-
-_FLASK_EXTENSION_KEY_PREFIX = "db"
-_DEFAULT_CLIENT_NAME = "default"
-
-
-def register_db_client(
- db_client: DBClient, app: Flask, client_name: str = _DEFAULT_CLIENT_NAME
-) -> None:
- """Initialize the Flask app.
-
- Add the database to the Flask app's extensions so that it can be
- accessed by request handlers using the current app context.
-
- If you use multiple DB clients, you can differentiate them by
- specifying a client_name.
-
- see get_db
- """
- flask_extension_key = f"{_FLASK_EXTENSION_KEY_PREFIX}{client_name}"
- app.extensions[flask_extension_key] = db_client
-
-
-def get_db(app: Flask, client_name: str = _DEFAULT_CLIENT_NAME) -> DBClient:
- """Get the database connection for the given Flask app.
-
- Use this in request handlers to access the database from the active Flask app.
-
- Specify the same client_name as used in register_db_client to get the correct client
-
- Example:
- from flask import current_app
- import src.adapters.db.flask_db as flask_db
-
- @app.route("/health")
- def health():
- db_client = flask_db.get_db(current_app)
- """
- flask_extension_key = f"{_FLASK_EXTENSION_KEY_PREFIX}{client_name}"
- return app.extensions[flask_extension_key]
-
-
-P = ParamSpec("P")
-T = TypeVar("T")
-
-
-def with_db_session(
- *, client_name: str = _DEFAULT_CLIENT_NAME
-) -> Callable[[Callable[Concatenate[db.Session, P], T]], Callable[P, T]]:
- """Decorator for functions that need a database session.
-
- This decorator will create a new session object and pass it to the function
- as the first positional argument. A transaction is not started automatically.
- To start a transaction use db_session.begin()
-
- Usage:
- @with_db_session()
- def foo(db_session: db.Session):
- ...
-
- @with_db_session()
- def bar(db_session: db.Session, x, y):
- ...
-
- @with_db_session(client_name="legacy_db")
- def fiz(db_session: db.Session, x, y, z):
- ...
- """
-
- def decorator(f: Callable[Concatenate[db.Session, P], T]) -> Callable[P, T]:
- @wraps(f)
- def wrapper(*args: Any, **kwargs: Any) -> T:
- with get_db(current_app, client_name=client_name).get_session() as session:
- return f(session, *args, **kwargs)
-
- return wrapper
-
- return decorator
diff --git a/app/src/api/healthcheck.py b/app/src/api/healthcheck.py
index 4dc5782b..60769670 100644
--- a/app/src/api/healthcheck.py
+++ b/app/src/api/healthcheck.py
@@ -1,33 +1,28 @@
import logging
-from typing import Tuple
-from apiflask import APIBlueprint
-from flask import current_app
+from fastapi import APIRouter, HTTPException, Request, status
+from pydantic import BaseModel, Field
from sqlalchemy import text
-from werkzeug.exceptions import ServiceUnavailable
-import src.adapters.db.flask_db as flask_db
-from src.api import response
-from src.api.schemas import request_schema
+import src.adapters.db.fastapi_db as fastapi_db
logger = logging.getLogger(__name__)
+healthcheck_router = APIRouter(tags=["healthcheck"])
-class HealthcheckSchema(request_schema.OrderedSchema):
- message: str
+class HealthcheckModel(BaseModel):
+ message: str = Field(examples=["Service healthy"])
-healthcheck_blueprint = APIBlueprint("healthcheck", __name__, tag="Health")
-
-@healthcheck_blueprint.get("/health")
-@healthcheck_blueprint.output(HealthcheckSchema)
-@healthcheck_blueprint.doc(responses=[200, ServiceUnavailable.code])
-def health() -> Tuple[dict, int]:
+@healthcheck_router.get("/health")
+def health(request: Request) -> HealthcheckModel:
try:
- with flask_db.get_db(current_app).get_connection() as conn:
+ with fastapi_db.get_db_client(request.app).get_connection() as conn:
assert conn.scalar(text("SELECT 1 AS healthy")) == 1
- return response.ApiResponse(message="Service healthy").asdict(), 200
- except Exception:
+ return HealthcheckModel(message="Service healthy")
+ except Exception as err:
logger.exception("Connection to DB failure")
- return response.ApiResponse(message="Service unavailable").asdict(), ServiceUnavailable.code
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service unavailable"
+ ) from err
diff --git a/app/src/api/index.py b/app/src/api/index.py
new file mode 100644
index 00000000..58e77429
--- /dev/null
+++ b/app/src/api/index.py
@@ -0,0 +1,20 @@
+from fastapi import APIRouter
+from fastapi.responses import HTMLResponse
+
+# This route won't appear on the OpenAPI docs
+index_router = APIRouter(include_in_schema=False)
+
+
+@index_router.get("/")
+def get_index() -> HTMLResponse:
+ content = """
+
+
+
Home
+
+ Home
+ Visit /docs to view the api documentation for this project.
+
+
+ """
+ return HTMLResponse(content=content)
diff --git a/app/src/api/response.py b/app/src/api/response.py
deleted file mode 100644
index 70e2ae8f..00000000
--- a/app/src/api/response.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import dataclasses
-from typing import Optional
-
-from src.api.schemas import response_schema
-from src.db.models.base import Base
-
-
-@dataclasses.dataclass
-class ValidationErrorDetail:
- type: str
- message: str = ""
- rule: Optional[str] = None
- field: Optional[str] = None
- value: Optional[str] = None # Do not store PII data here, as it gets logged in some cases
-
-
-class ValidationException(Exception):
- __slots__ = ["errors", "message", "data"]
-
- def __init__(
- self,
- errors: list[ValidationErrorDetail],
- message: str = "Invalid request",
- data: Optional[dict | list[dict]] = None,
- ):
- self.errors = errors
- self.message = message
- self.data = data or {}
-
-
-@dataclasses.dataclass
-class ApiResponse:
- """Base response model for all API responses."""
-
- message: str
- data: Optional[Base] = None
- warnings: list[ValidationErrorDetail] = dataclasses.field(default_factory=list)
- errors: list[ValidationErrorDetail] = dataclasses.field(default_factory=list)
-
- # This method is used to convert ApiResponse objects to a dictionary
- # This is necessary because APIFlask has a bug that causes an exception to be
- # thrown when returning objects from routes when BASE_RESPONSE_SCHEMA is set
- # (See https://github.com/apiflask/apiflask/issues/384)
- # Once that issue is fixed, this method can be removed and routes can simply
- # return ApiResponse objects directly and allow APIFlask to serealize the objects
- # to JSON automatically.
- def asdict(self) -> dict:
- return response_schema.ResponseSchema().dump(self)
diff --git a/app/src/api/schemas/request_schema.py b/app/src/api/schemas/request_schema.py
deleted file mode 100644
index 08eca1d9..00000000
--- a/app/src/api/schemas/request_schema.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import apiflask
-
-
-# Use ordered schemas to ensure that the order of fields in the generated OpenAPI
-# schema is deterministic.
-# This should no longer be needed once apiflask is updated to use ordered schemas
-# TODO update apiflask with fix to default to ordered schemas
-# See https://github.com/apiflask/apiflask/issues/385#issuecomment-1364463104
-class OrderedSchema(apiflask.Schema):
- class Meta:
- ordered = True
diff --git a/app/src/api/schemas/response_schema.py b/app/src/api/schemas/response_schema.py
deleted file mode 100644
index 5d89e85e..00000000
--- a/app/src/api/schemas/response_schema.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from apiflask import fields
-
-from src.api.schemas import request_schema
-
-
-class ValidationErrorSchema(request_schema.OrderedSchema):
- type = fields.String(metadata={"description": "The type of error"})
- message = fields.String(metadata={"description": "The message to return"})
- rule = fields.String(metadata={"description": "The rule that failed"})
- field = fields.String(metadata={"description": "The field that failed"})
- value = fields.String(metadata={"description": "The value that failed"})
-
-
-class ResponseSchema(request_schema.OrderedSchema):
- message = fields.String(metadata={"description": "The message to return"})
- data = fields.Field(metadata={"description": "The REST resource object"}, dump_default={})
- status_code = fields.Integer(metadata={"description": "The HTTP status code"}, dump_default=200)
- warnings = fields.List(fields.Nested(ValidationErrorSchema), dump_default=[])
- errors = fields.List(fields.Nested(ValidationErrorSchema), dump_default=[])
diff --git a/app/src/api/users/__init__.py b/app/src/api/users/__init__.py
deleted file mode 100644
index 8ddfd549..00000000
--- a/app/src/api/users/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from src.api.users.user_blueprint import user_blueprint
-
-# import user_commands module to register the CLI commands on the user_blueprint
-import src.api.users.user_commands # noqa: F401 E402 isort:skip
-
-# import user_commands module to register the API routes on the user_blueprint
-import src.api.users.user_routes # noqa: F401 E402 isort:skip
-
-
-__all__ = ["user_blueprint"]
diff --git a/app/src/api/users/user_blueprint.py b/app/src/api/users/user_blueprint.py
deleted file mode 100644
index 3155fe4e..00000000
--- a/app/src/api/users/user_blueprint.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from apiflask import APIBlueprint
-
-user_blueprint = APIBlueprint("user", __name__, tag="User", cli_group="user")
diff --git a/app/src/api/users/user_commands.py b/app/src/api/users/user_commands.py
deleted file mode 100644
index 6e69481c..00000000
--- a/app/src/api/users/user_commands.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import logging
-import os.path as path
-from typing import Optional
-
-import click
-
-import src.adapters.db as db
-import src.adapters.db.flask_db as flask_db
-import src.services.users as user_service
-from src.api.users.user_blueprint import user_blueprint
-from src.util.datetime_util import utcnow
-
-logger = logging.getLogger(__name__)
-
-user_blueprint.cli.help = "User commands"
-
-
-@user_blueprint.cli.command("create-csv", help="Create a CSV of all users and their roles")
-@flask_db.with_db_session()
-@click.option(
- "--dir",
- default=".",
- help="Directory to save output file in. Can be an S3 path (e.g. 's3://bucketname/folder/') Defaults to current directory ('.').",
-)
-@click.option(
- "--filename",
- default=None,
- help="Filename to save output file as. Defaults to '[timestamp]-user-roles.csv.",
-)
-def create_csv(db_session: db.Session, dir: str, filename: Optional[str]) -> None:
- if filename is None:
- filename = utcnow().strftime("%Y-%m-%d-%H-%M-%S") + "-user-roles.csv"
- filepath = path.join(dir, filename)
- user_service.create_user_csv(db_session, filepath)
diff --git a/app/src/api/users/user_routes.py b/app/src/api/users/user_routes.py
index 338ae313..7b61fe39 100644
--- a/app/src/api/users/user_routes.py
+++ b/app/src/api/users/user_routes.py
@@ -1,57 +1,57 @@
import logging
+import typing
+import uuid
from typing import Any
+import fastapi.exception_handlers
+from fastapi import APIRouter
+
import src.adapters.db as db
-import src.adapters.db.flask_db as flask_db
-import src.api.response as response
+import src.adapters.db.fastapi_db as fastapi_db
import src.api.users.user_schemas as user_schemas
import src.services.users as user_service
-import src.services.users as users
-from src.api.users.user_blueprint import user_blueprint
-from src.auth.api_key_auth import api_key_auth
+from src.auth.api_key_auth import verify_api_key
from src.db.models.user_models import User
logger = logging.getLogger(__name__)
-@user_blueprint.post("/v1/users")
-@user_blueprint.input(user_schemas.UserSchema)
-@user_blueprint.output(user_schemas.UserSchema, status_code=201)
-@user_blueprint.auth_required(api_key_auth)
-@flask_db.with_db_session()
-def user_post(db_session: db.Session, user_params: users.CreateUserParams) -> dict:
+user_router = APIRouter(tags=["user"], dependencies=[fastapi.Depends(verify_api_key)])
+
+
+@user_router.post("/v1/users", status_code=201, response_model=user_schemas.UserModelOut)
+def user_post(
+ db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())],
+ user_model: user_schemas.UserModel,
+) -> User:
"""
POST /v1/users
"""
- user = user_service.create_user(db_session, user_params)
+ logger.info(user_model)
+ user = user_service.create_user(db_session, user_model)
logger.info("Successfully inserted user", extra=get_user_log_params(user))
- return response.ApiResponse(message="Success", data=user).asdict()
+ return user
-@user_blueprint.patch("/v1/users/")
-# Allow partial updates. partial=true means requests that are missing
-# required fields will not be rejected.
-# https://marshmallow.readthedocs.io/en/stable/quickstart.html#partial-loading
-@user_blueprint.input(user_schemas.UserSchema(partial=True))
-@user_blueprint.output(user_schemas.UserSchema)
-@user_blueprint.auth_required(api_key_auth)
-@flask_db.with_db_session()
+@user_router.patch("/v1/users/{user_id}", response_model=user_schemas.UserModelOut)
def user_patch(
- db_session: db.Session, user_id: str, patch_user_params: users.PatchUserParams
-) -> dict:
+ db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())],
+ user_id: uuid.UUID,
+ patch_user_params: user_schemas.UserModelPatch,
+) -> User:
user = user_service.patch_user(db_session, user_id, patch_user_params)
logger.info("Successfully patched user", extra=get_user_log_params(user))
- return response.ApiResponse(message="Success", data=user).asdict()
+ return user
-@user_blueprint.get("/v1/users/")
-@user_blueprint.output(user_schemas.UserSchema)
-@user_blueprint.auth_required(api_key_auth)
-@flask_db.with_db_session()
-def user_get(db_session: db.Session, user_id: str) -> dict:
+@user_router.get("/v1/users/{user_id}", response_model=user_schemas.UserModelOut)
+def user_get(
+ db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())],
+ user_id: uuid.UUID,
+) -> User:
user = user_service.get_user(db_session, user_id)
logger.info("Successfully fetched user", extra=get_user_log_params(user))
- return response.ApiResponse(message="Success", data=user).asdict()
+ return user
def get_user_log_params(user: User) -> dict[str, Any]:
diff --git a/app/src/api/users/user_schemas.py b/app/src/api/users/user_schemas.py
index b0fd2c14..b3872f73 100644
--- a/app/src/api/users/user_schemas.py
+++ b/app/src/api/users/user_schemas.py
@@ -1,44 +1,57 @@
-from apiflask import fields
-from marshmallow import fields as marshmallow_fields
+from datetime import date, datetime
+from typing import Any
+from uuid import UUID
+
+from pydantic import BaseModel, Field
-from src.api.schemas import request_schema
from src.db.models import user_models
-class RoleSchema(request_schema.OrderedSchema):
- type = marshmallow_fields.Enum(
- user_models.RoleType,
- by_value=True,
- metadata={"description": "The name of the role"},
- )
+class RoleModel(BaseModel):
+ type: user_models.RoleType
- # Note that user_id is not included in the API schema since the role
- # will always be a nested fields of the API user
-
-
-class UserSchema(request_schema.OrderedSchema):
- id = fields.UUID(dump_only=True)
- first_name = fields.String(metadata={"description": "The user's first name"}, required=True)
- middle_name = fields.String(metadata={"description": "The user's middle name"})
- last_name = fields.String(metadata={"description": "The user's last name"}, required=True)
- phone_number = fields.String(
- required=True,
- metadata={
- "description": "The user's phone number",
- "example": "123-456-7890",
- "pattern": r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$",
- },
- )
- date_of_birth = fields.Date(
- metadata={"description": "The users date of birth"},
- required=True,
+
+class UserModel(BaseModel):
+
+ first_name: str
+ middle_name: str | None = None
+ last_name: str
+ phone_number: str = Field(
+ pattern=r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$", examples=["123-456-7890"]
)
- is_active = fields.Boolean(
- metadata={"description": "Whether the user is active"},
- required=True,
+ date_of_birth: date
+ is_active: bool
+ roles: list[RoleModel]
+
+
+UNSET: Any = None
+
+
+class UserModelPatch(UserModel):
+ # TODO - this is unfortunately what we need to do for PATCH to work
+ # and is making me think PATCH should just not be used in favor of PUT endpoints
+ # which may also be simpler for front-ends as well.
+ #
+ # Ideas:
+ # - Just go with a PUT endpoint
+ # - Make this how the main model works, merge the two
+ # - Be fine with duplication?
+ #
+ first_name: str = UNSET
+ # middle_initial does not need to be redefined
+ last_name: str = UNSET
+ phone_number: str = Field(
+ pattern=r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$",
+ examples=["123-456-7890"],
+ default=UNSET,
)
- roles = fields.List(fields.Nested(RoleSchema), required=True)
+ date_of_birth: date = UNSET
+ is_active: bool = UNSET
+ roles: list[RoleModel] = UNSET
+
+
+class UserModelOut(UserModel):
+ id: UUID
- # Output only fields in addition to id field
- created_at = fields.DateTime(dump_only=True)
- updated_at = fields.DateTime(dump_only=True)
+ created_at: datetime
+ updated_at: datetime
diff --git a/app/src/app.py b/app/src/app.py
index 2ca137d3..bc66af1a 100644
--- a/app/src/app.py
+++ b/app/src/app.py
@@ -1,97 +1,108 @@
import logging
-import os
-from typing import Optional
-from apiflask import APIFlask
-from flask import g
-from werkzeug.exceptions import Unauthorized
+import fastapi
+import fastapi.params
+import uvicorn
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import RequestValidationError, ResponseValidationError
+from fastapi.responses import JSONResponse
+from starlette_context.middleware import RawContextMiddleware
import src.adapters.db as db
-import src.adapters.db.flask_db as flask_db
+import src.adapters.db.fastapi_db as fastapi_db
import src.logger
-import src.logger.flask_logger as flask_logger
-from src.api.healthcheck import healthcheck_blueprint
-from src.api.schemas import response_schema
-from src.api.users import user_blueprint
-from src.auth.api_key_auth import User, get_app_security_scheme
+import src.logger.fastapi_logger as fastapi_logger
+from src.api.healthcheck import healthcheck_router
+from src.api.index import index_router
+from src.api.users.user_routes import user_router
+from src.app_config import AppConfig
logger = logging.getLogger(__name__)
-def create_app() -> APIFlask:
- app = APIFlask(__name__)
+def create_app() -> fastapi.FastAPI:
+ src.logger.init(__package__)
-<<<<<<< Updated upstream
- src.logging.init(__package__)
-=======
- root_logger = src.logging.init(__package__)
->>>>>>> Stashed changes
- flask_logger.init_app(logging.root, app)
+ app = fastapi.FastAPI(
+ # Fields which appear on the OpenAPI docs
+ # title appears at the top of the OpenAPI page as well attached to all logs as "app.name"
+ title="Template Application FastAPI",
+ description="Template for a FastAPI Application",
+ contact={
+ "name": "Nava PBC Engineering",
+ "url": "https://www.navapbc.com",
+ "email": "engineering@navapbc.com",
+ },
+ # Global dependencies of every API endpoint
+ dependencies=get_global_dependencies(),
+ )
+
+ fastapi_logger.init_app(logging.root, app)
+
+ configure_exception_handling(app)
db_client = db.PostgresDBClient()
- flask_db.register_db_client(db_client, app)
+ fastapi_db.register_db_client(db_client, app)
+
+ register_routers(app)
- configure_app(app)
- register_blueprints(app)
- register_index(app)
+ # Add this middleware last so that every other middleware has access to the context
+ # object we use to pass logging information around with. Middlewares work like a stack so
+ # the last one added is the first one executed.
+ app.add_middleware(RawContextMiddleware)
+ logger.info("Creating app")
return app
-def current_user(is_user_expected: bool = True) -> Optional[User]:
- current = g.get("current_user")
- if is_user_expected and current is None:
- logger.error("No current user found for request")
- raise Unauthorized
- return current
+def get_global_dependencies() -> list[fastapi.params.Depends]:
+ return [fastapi.Depends(fastapi_logger.add_url_rule_to_request_context)]
-def configure_app(app: APIFlask) -> None:
- # Modify the response schema to instead use the format of our ApiResponse class
- # which adds additional details to the object.
- # https://apiflask.com/schema/#base-response-schema-customization
- app.config["BASE_RESPONSE_SCHEMA"] = response_schema.ResponseSchema
+def register_routers(app: fastapi.FastAPI) -> None:
+ app.include_router(index_router)
+ app.include_router(healthcheck_router)
+ app.include_router(user_router)
- # Set a few values for the Swagger endpoint
- app.config["OPENAPI_VERSION"] = "3.0.3"
- # Set various general OpenAPI config values
- app.info = {
- "title": "Template Application Flask",
- "description": "Template API for a Flask Application",
- "contact": {
- "name": "Nava PBC Engineering",
- "url": "https://www.navapbc.com",
- "email": "engineering@navapbc.com",
- },
- }
+def configure_exception_handling(app: fastapi.FastAPI) -> None:
- # Set the security schema and define the header param
- # where we expect the API token to reside.
- # See: https://apiflask.com/authentication/#use-external-authentication-library
- app.security_schemes = get_app_security_scheme()
+ # Pydantic v2 by default returns a URL for each validation error pointing to their documentation
+ # which makes the errors overly verbose. We trim that out here before returning the error.
+ # See: https://fastapi.tiangolo.com/tutorial/handling-errors/?h=handlin#override-request-validation-exceptions
+ #
+ # If you would like to modify the format of the errors from Pydantic, you can restructure the response object here.
+ @app.exception_handler(RequestValidationError)
+ async def request_validation_exception_handler(
+ request: fastapi.Request, exc: RequestValidationError
+ ) -> JSONResponse:
+ # This matches the default request validation handler, but excludes the URL from the validation error.
+ return JSONResponse(
+ status_code=fastapi.status.HTTP_422_UNPROCESSABLE_ENTITY,
+ content={"detail": jsonable_encoder(exc.errors(), exclude={"url"})},
+ )
+ # TODO - do this a bit better, this is to prevent PII from being logged
+ # - should open a ticket against FastAPI as this looks like a relatively recent change:
+ # - https://github.com/tiangolo/fastapi/pull/10078
+ def override_method(self: ResponseValidationError) -> str:
+ return f"{len(self._errors)} validation errors:\n"
-def register_blueprints(app: APIFlask) -> None:
- app.register_blueprint(healthcheck_blueprint)
- app.register_blueprint(user_blueprint)
+ ResponseValidationError.__str__ = override_method # type: ignore
-def get_project_root_dir() -> str:
- return os.path.join(os.path.dirname(__file__), "..")
+if __name__ == "__main__":
+ app_config = AppConfig()
+ uvicorn_params: dict = {
+ "host": app_config.host,
+ "port": app_config.port,
+ "factory": True,
+ "app_dir": "src/",
+ }
+ if app_config.use_reloader:
+ uvicorn_params["reload"] = True
+ else:
+ uvicorn_params["workers"] = 4
-def register_index(app: APIFlask) -> None:
- @app.route("/")
- @app.doc(hide=True)
- def index() -> str:
- return """
-
-
- Home
-
- Home
- Visit /docs to view the api documentation for this project.
-
-
- """
+ uvicorn.run("app:create_app", **uvicorn_params)
diff --git a/app/src/app_config.py b/app/src/app_config.py
index aa780987..1c3a8339 100644
--- a/app/src/app_config.py
+++ b/app/src/app_config.py
@@ -6,6 +6,8 @@ class AppConfig(PydanticBaseEnvConfig):
# from accessing the application. This is especially important if you are
# running the application locally on a public network. This needs to be
# overriden to 0.0.0.0 when running in a container
- # See https://flask.palletsprojects.com/en/2.2.x/api/#flask.Flask.run
host: str = "127.0.0.1"
port: int = 8080
+
+ # TODO - figure out how we want this configured
+ use_reloader: bool = True
diff --git a/app/src/auth/api_key_auth.py b/app/src/auth/api_key_auth.py
index a1511f01..5806fac7 100644
--- a/app/src/auth/api_key_auth.py
+++ b/app/src/auth/api_key_auth.py
@@ -1,71 +1,15 @@
import logging
import os
-import uuid
-from dataclasses import dataclass
-from typing import Any
-import flask
-from apiflask import HTTPTokenAuth, abort
+from fastapi import HTTPException, Security, status
+from fastapi.security import APIKeyHeader
logger = logging.getLogger(__name__)
-# Initialize the authorization context
-# this needs to be attached to your
-# routes as `your_blueprint..auth_required(api_key_auth)`
-# in order to enable authorization
-api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth")
+API_KEY_HEADER = APIKeyHeader(name="X-Auth")
-def get_app_security_scheme() -> dict[str, Any]:
- return {"ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}}
-
-
-@dataclass
-class User:
- # WARNING: This is a very rudimentary
- # user for demo purposes and is not
- # a production ready approach. It exists
- # purely to define a rough structure / example
- id: uuid.UUID
- sub_id: str
- username: str
-
- def as_dict(self) -> dict[str, Any]:
- # Connexion expects a dictionary it can
- # use .get() on, so convert this to that format
- return {"uid": self.id, "sub": self.sub_id}
-
- def get_user_log_attributes(self) -> dict:
- # Note this gets called during authentication
- # to attach the information to the flask global object
- # which will in turn be attached to the log record
- return {"current_user.id": str(self.id)}
-
-
-API_AUTH_USER = User(uuid.uuid4(), "sub_id_1234", "API auth user")
-
-
-@api_key_auth.verify_token
-def verify_token(token: str) -> dict:
- logger.info("Authenticating provided token")
-
- user = process_token(token)
-
- # Note that the current user can also be found
- # by doing api_key_auth.current_user once in
- # the request context. This is here in case
- # multiple authentication approaches exist
- # in your API, you don't need to check each
- # one in order to figure out which was actually used
- flask.g.current_user = user
- flask.g.current_user_log_attributes = user.get_user_log_attributes()
-
- logger.info("Authentication successful")
-
- return user.as_dict()
-
-
-def process_token(token: str) -> User:
+def verify_api_key(token: str = Security(API_KEY_HEADER)) -> str:
# WARNING: this isn't really a production ready
# auth approach. In reality the user object returned
# here should be pulled from a DB or auth service, but
@@ -74,15 +18,15 @@ def process_token(token: str) -> User:
expected_auth_token = os.getenv("API_AUTH_TOKEN", None)
if not expected_auth_token:
- logger.info(
- "Authentication is not setup, please add an API_AUTH_TOKEN environment variable."
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Authentication is not setup, please add an API_AUTH_TOKEN environment variable.",
)
- abort(401, "Authentication is not setup properly and the user cannot be authenticated")
if token != expected_auth_token:
- logger.info("Authentication failed for provided auth token.")
- abort(
- 401, "The server could not verify that you are authorized to access the URL requested"
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="The server could not verify that you are authorized to access the URL requested",
)
- return API_AUTH_USER
+ return token
diff --git a/app/src/db/migrations/env.py b/app/src/db/migrations/env.py
index 29a80a46..d7134304 100644
--- a/app/src/db/migrations/env.py
+++ b/app/src/db/migrations/env.py
@@ -15,17 +15,7 @@
logger = logging.getLogger("migrations")
# Initialize logging
-with src.logging.init("migrations"):
-<<<<<<< Updated upstream
-=======
-
- if not config.get_main_option("sqlalchemy.url"):
- uri = make_connection_uri(get_db_config())
-
- # Escape percentage signs in the URI.
- # https://alembic.sqlalchemy.org/en/latest/api/config.html#alembic.config.Config.set_main_option
- config.set_main_option("sqlalchemy.url", uri.replace("%", "%%"))
->>>>>>> Stashed changes
+with src.logger.init("migrations"):
# add your model's MetaData object here
# for 'autogenerate' support
diff --git a/app/src/fastapi_app.py b/app/src/fastapi_app.py
deleted file mode 100644
index d6fe27b2..00000000
--- a/app/src/fastapi_app.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import fastapi
-import logging
-import logger as our_logging
-from pydantic import BaseModel, Field
-
-logger = logging.getLogger(__name__)
-
-def create_app() -> fastapi.FastAPI:
- our_logging.init(__package__)
-
- logger.info("Creating app")
- return fastapi.FastAPI()
-
-app = create_app()
-
-@app.get("/")
-def root():
- logger.info("hello - in the root")
- return {"message": "hello"}
-
-class NestedA(BaseModel):
- x: str = Field(examples=["hello there again"])
-
-
-class NestedB(BaseModel):
- y: list[NestedA]
-
-class Item(BaseModel):
- name: str
- description: str | None = None
- price: float
- tax: float | None = None
- z: list[NestedB]
-
-@app.post("/items/")
-def create_item(item: Item) -> Item:
- logger.info(item)
- return item
\ No newline at end of file
diff --git a/app/src/logger/__init__.py b/app/src/logger/__init__.py
index 30f0078d..cda6e6aa 100644
--- a/app/src/logger/__init__.py
+++ b/app/src/logger/__init__.py
@@ -3,15 +3,15 @@
There are two formatters for the log messages: human-readable and JSON.
The formatter that is used is determined by the environment variable
LOG_FORMAT. If the environment variable is not set, the JSON formatter
-is used by default. See src.logging.formatters for more information.
+is used by default. See src.logger.formatters for more information.
The logger also adds a PII mask filter to the root logger. See
-src.logging.pii for more information.
+src.logger.pii for more information.
Usage:
- import src.logging
+ import src.logger
- with src.logging.init("program name"):
+ with src.logger.init("program name"):
...
Once the module has been initialized, the standard logging module can be
@@ -28,6 +28,6 @@
import src.logger.config as config
-def init(program_name: str) -> config.LoggingContext:
+def init(program_name: str) -> config.LoggingContext:
return config.LoggingContext(program_name)
diff --git a/app/src/logger/audit.py b/app/src/logger/audit.py
index 49b29a21..0ae27728 100644
--- a/app/src/logger/audit.py
+++ b/app/src/logger/audit.py
@@ -72,7 +72,9 @@ def handle_audit_event(event_name: str, args: tuple[Any, ...]) -> None:
def log_audit_event(event_name: str, args: Sequence[Any], arg_names: Sequence[str]) -> None:
"""Log a message but only log recently repeated messages at intervals."""
extra = {
- f"audit.args.{arg_name}": arg for arg_name, arg in zip(arg_names, args) if arg_name != "_"
+ f"audit.args.{arg_name}": arg
+ for arg_name, arg in zip(arg_names, args, strict=True)
+ if arg_name != "_"
}
key = (event_name, repr(args))
diff --git a/app/src/logger/config.py b/app/src/logger/config.py
index 587afd8d..7cc70687 100644
--- a/app/src/logger/config.py
+++ b/app/src/logger/config.py
@@ -5,6 +5,8 @@
import sys
from typing import Any, ContextManager, cast
+from pydantic_settings import SettingsConfigDict
+
import src.logger.audit
import src.logger.formatters as formatters
import src.logger.pii as pii
@@ -23,11 +25,9 @@ class LoggingConfig(PydanticBaseEnvConfig):
format: str = "json"
level: str = "INFO"
enable_audit: bool = False
- human_readable_formatter: PydanticBaseEnvConfig = HumanReadableFormatterConfig()
+ human_readable_formatter: HumanReadableFormatterConfig = HumanReadableFormatterConfig()
- class Config:
- env_prefix = "log_"
- env_nested_delimiter = "__"
+ model_config = SettingsConfigDict(env_prefix="log_", env_nested_delimiter="__")
class LoggingContext(ContextManager[None]):
@@ -97,7 +97,7 @@ def _configure_logging(self) -> None:
logging.root.setLevel(config.level)
if config.enable_audit:
- src.logging.audit.init()
+ src.logger.audit.init()
# Configure loggers for third party packages
logging.getLogger("alembic").setLevel(logging.INFO)
@@ -105,6 +105,12 @@ def _configure_logging(self) -> None:
logging.getLogger("sqlalchemy.pool").setLevel(logging.INFO)
logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
+ # Uvicorn sets up its own handlers with different formatting
+ # To keep things consistent, we override their handler to be ours
+ # Alternatively we could configure all of this directly, but that is significantly more configuration.
+ logging.getLogger("uvicorn").handlers = [self.console_handler]
+ logging.getLogger("uvicorn.access").handlers = [self.console_handler]
+
def get_formatter(config: LoggingConfig) -> logging.Formatter:
"""Return the formatter used by the root logger.
diff --git a/app/src/logger/fastapi_logger.py b/app/src/logger/fastapi_logger.py
new file mode 100644
index 00000000..f41195d1
--- /dev/null
+++ b/app/src/logger/fastapi_logger.py
@@ -0,0 +1,161 @@
+"""Module for adding standard logging functionality to a FastAPI app.
+
+This module configures an application's logger to add extra data
+to all log messages. FastAPI application context data such as the
+app name and request context data such as the request method, request url
+rule, and query parameters are added to the log record.
+
+This module also configures the FastAPI application to log every request's start and end.
+
+Usage:
+ import src.logger.fastapi_logger as fastapi_logger
+
+ logger = logging.getLogger(__name__)
+ app = create_app()
+
+ fastapi_logger.init_app(logger, app)
+"""
+import logging
+import time
+import typing
+
+import fastapi
+import starlette_context
+
+logger = logging.getLogger(__name__)
+EXTRA_LOG_DATA_ATTR = "extra_log_data"
+
+
+def init_app(app_logger: logging.Logger, app: fastapi.FastAPI) -> None:
+ """Initialize the FastAPI app logger.
+
+ Adds FastAPI app context data and FastAPI request context data
+ to every log record using log filters.
+ See https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information
+
+ Also configures the app to log every non-404 request using the given logger.
+
+ Usage:
+ import src.logger.fastapi_logger as fastapi_logger
+
+ logger = logging.getLogger(__name__)
+ app = create_app()
+
+ fastapi_logger.init_app(logger, app)
+ """
+
+ # Need to add filters to each of the handlers rather than to the logger itself, since
+ # messages are passed directly to the ancestor loggers’ handlers bypassing any filters
+ # set on the ancestors.
+ # See https://docs.python.org/3/library/logging.html#logging.Logger.propagate
+ for handler in app_logger.handlers:
+ handler.addFilter(_add_request_context_info_to_log_record)
+
+ _add_logging_middleware(app)
+
+ logger.info("Initialized Fast API logger")
+
+
+def _add_logging_middleware(app: fastapi.FastAPI) -> None:
+ """
+ Add middleware that runs before we start processing a request to add
+ additional context to log messages and automatically log the start/end
+
+ IMPORTANT: These are defined in the reverse-order that they execute
+ as middleware work like a stack.
+ """
+
+ # Log the start/end of a request and include the timing.
+ @app.middleware("http")
+ async def log_start_and_end_of_request(
+ request: fastapi.Request,
+ call_next: typing.Callable[[fastapi.Request], typing.Awaitable[fastapi.Response]],
+ ) -> fastapi.Response:
+ request_start_time = time.perf_counter()
+
+ logger.info("start request")
+ response = await call_next(request)
+
+ logger.info(
+ "end request",
+ extra={
+ "response.status_code": response.status_code,
+ "response.content_length": response.headers.get("content-length", None),
+ "response.content_type": response.headers.get("content-type", None),
+ "response.charset": response.charset,
+ "response.time_ms": (time.perf_counter() - request_start_time) * 1000,
+ },
+ )
+ return response
+
+ # Add general information regarding the request (route, request ID, method) to all
+ # log messages for the lifecycle of the request.
+ @app.middleware("http")
+ async def attach_request_context_info(
+ request: fastapi.Request,
+ call_next: typing.Callable[[fastapi.Request], typing.Awaitable[fastapi.Response]],
+ ) -> fastapi.Response:
+ # This will be added to all log messages for the rest of the request lifecycle
+ data = {
+ "app.name": request.app.title,
+ "request.id": request.headers.get("x-amzn-requestid", ""),
+ "request.method": request.method,
+ "request.path": request.scope.get("path", ""),
+ # Starlette does not resolve the URL rule (ie. the specific route) until after
+ # middleware runs, so the url_rule cannot be added here, see: add_url_rule_to_request_context for where that happens
+ }
+
+    # Add query parameter data in the format request.query.<name> = <value>
+ # For example, the query string ?foo=bar&baz=qux would be added as
+ # request.query.foo = bar and request.query.baz = qux
+ # PII should be kept out of the URL, as URLs are logged in access logs.
+ # With that assumption, it is safe to log query parameters.
+ for key, value in request.query_params.items():
+ data[f"request.query.{key}"] = value
+
+ add_extra_data_to_current_request_logs(data)
+
+ return await call_next(request)
+
+
+def add_url_rule_to_request_context(request: fastapi.Request) -> None:
+ """
+    Starlette, the underlying routing library that FastAPI is built on, does not determine
+    the route that will handle a request until after all middleware run. This method instead
+    relies on being used as a dependency (eg. FastAPI(dependencies=[Depends(add_url_rule_to_request_context)]))
+ which will make it always run for every route, but after all of the middlewares.
+
+ See: https://github.com/encode/starlette/issues/685 which describes the issue in Starlette
+ """
+ url_rule = ""
+
+ api_route = request.scope.get("route", None)
+ if api_route:
+ url_rule = api_route.path
+
+ add_extra_data_to_current_request_logs({"request.url_rule": url_rule})
+
+
+def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
+ """Add request context data to the log record.
+
+ If there is no request context, then do not add any data.
+ """
+ if not starlette_context.context.exists():
+ return True
+
+ extra_log_data: dict[str, str] = starlette_context.context.get(EXTRA_LOG_DATA_ATTR, {})
+ record.__dict__.update(extra_log_data)
+
+ return True
+
+
+def add_extra_data_to_current_request_logs(
+ data: dict[str, str | int | float | bool | None]
+) -> None:
+ """Add data to every log record for the current request."""
+ assert starlette_context.context.exists(), "Must be in a request context"
+
+ extra_log_data = starlette_context.context.get(EXTRA_LOG_DATA_ATTR, {})
+ extra_log_data.update(data)
+ starlette_context.context[EXTRA_LOG_DATA_ATTR] = extra_log_data
diff --git a/app/src/logger/flask_logger.py b/app/src/logger/flask_logger.py
deleted file mode 100644
index 6aefa833..00000000
--- a/app/src/logger/flask_logger.py
+++ /dev/null
@@ -1,166 +0,0 @@
-"""Module for adding standard logging functionality to a Flask app.
-
-This module configures an application's logger to add extra data
-to all log messages. Flask application context data such as the
-app name and request context data such as the request method, request url
-rule, and query parameters are added to the log record.
-
-This module also configures the Flask application to log every
-non-404 request.
-
-Usage:
- import src.logging.flask_logger as flask_logger
-
- logger = logging.getLogger(__name__)
- app = create_app()
-
- flask_logger.init_app(logger, app)
-"""
-import logging
-import time
-
-import flask
-
-logger = logging.getLogger(__name__)
-EXTRA_LOG_DATA_ATTR = "extra_log_data"
-
-
-def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
- """Initialize the Flask app logger.
-
- Adds Flask app context data and Flask request context data
- to every log record using log filters.
- See https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information
-
- Also configures the app to log every non-404 request using the given logger.
-
- Usage:
- import src.logging.flask_logger as flask_logger
-
- logger = logging.getLogger(__name__)
- app = create_app()
-
- flask_logger.init_app(logger, app)
- """
-
- # Need to add filters to each of the handlers rather than to the logger itself, since
- # messages are passed directly to the ancestor loggers’ handlers bypassing any filters
- # set on the ancestors.
- # See https://docs.python.org/3/library/logging.html#logging.Logger.propagate
- for handler in app_logger.handlers:
- handler.addFilter(_add_app_context_info_to_log_record)
- handler.addFilter(_add_request_context_info_to_log_record)
-
- # Add request context data to every log record for the current request
- # such as request id, request method, request path, and the matching Flask request url rule
- app.before_request(
- lambda: add_extra_data_to_current_request_logs(_get_request_context_info(flask.request))
- )
-
- app.before_request(_track_request_start_time)
- app.before_request(_log_start_request)
- app.after_request(_log_end_request)
-
- app_logger.info("initialized flask logger")
-
-
-def add_extra_data_to_current_request_logs(
- data: dict[str, str | int | float | bool | None]
-) -> None:
- """Add data to every log record for the current request."""
- assert flask.has_request_context(), "Must be in a request context"
-
- extra_log_data = getattr(flask.g, EXTRA_LOG_DATA_ATTR, {})
- extra_log_data.update(data)
- setattr(flask.g, EXTRA_LOG_DATA_ATTR, extra_log_data)
-
-
-def _track_request_start_time() -> None:
- """Store the request start time in flask.g"""
- flask.g.request_start_time = time.perf_counter()
-
-
-def _log_start_request() -> None:
- """Log the start of a request.
-
- This function handles the Flask's before_request event.
- See https://tedboy.github.io/flask/interface_src.application_object.html#flask.Flask.before_request
-
- Additional info about the request will be in the `extra` field
- added by `_add_request_context_info_to_log_record`
- """
- logger.info("start request")
-
-
-def _log_end_request(response: flask.Response) -> flask.Response:
- """Log the end of a request.
-
- This function handles the Flask's after_request event.
- See https://tedboy.github.io/flask/interface_src.application_object.html#flask.Flask.after_request
-
- Additional info about the request will be in the `extra` field
- added by `_add_request_context_info_to_log_record`
- """
-
- logger.info(
- "end request",
- extra={
- "response.status_code": response.status_code,
- "response.content_length": response.content_length,
- "response.content_type": response.content_type,
- "response.mimetype": response.mimetype,
- "response.time_ms": (time.perf_counter() - flask.g.request_start_time) * 1000,
- },
- )
- return response
-
-
-def _add_app_context_info_to_log_record(record: logging.LogRecord) -> bool:
- """Add app context data to the log record.
-
- If there is no app context, then do not add any data.
- """
- if not flask.has_app_context():
- return True
-
- assert flask.current_app is not None
- record.__dict__ |= _get_app_context_info(flask.current_app)
-
- return True
-
-
-def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
- """Add request context data to the log record.
-
- If there is no request context, then do not add any data.
- """
- if not flask.has_request_context():
- return True
-
- assert flask.request is not None
- extra_log_data: dict[str, str] = getattr(flask.g, EXTRA_LOG_DATA_ATTR, {})
- record.__dict__.update(extra_log_data)
-
- return True
-
-
-def _get_app_context_info(app: flask.Flask) -> dict:
- return {"app.name": app.name}
-
-
-def _get_request_context_info(request: flask.Request) -> dict:
- data = {
- "request.id": request.headers.get("x-amzn-requestid", ""),
- "request.method": request.method,
- "request.path": request.path,
- "request.url_rule": str(request.url_rule),
- }
-
- # Add query parameter data in the format request.query. =
- # For example, the query string ?foo=bar&baz=qux would be added as
- # request.query.foo = bar and request.query.baz = qux
- # PII should be kept out of the URL, as URLs are logged in access logs.
- # With that assumption, it is safe to log query parameters.
- for key, value in request.args.items():
- data[f"request.query.{key}"] = value
- return data
diff --git a/app/src/logger/pii.py b/app/src/logger/pii.py
index 7eb7fb14..e8c7ffe5 100644
--- a/app/src/logger/pii.py
+++ b/app/src/logger/pii.py
@@ -8,7 +8,7 @@
Example:
import logging
- import src.logging.pii as pii
+ import src.logger.pii as pii
handler = logging.StreamHandler()
handler.addFilter(pii.mask_pii)
@@ -22,7 +22,7 @@
Example:
import logging
- import src.logging.pii as pii
+ import src.logger.pii as pii
logger = logging.getLogger(__name__)
logger.addFilter(pii.mask_pii)
diff --git a/app/src/logging/__init__.py b/app/src/logging/__init__.py
deleted file mode 100644
index 3d6fe9d2..00000000
--- a/app/src/logging/__init__.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""Module for initializing logging configuration for the application.
-
-There are two formatters for the log messages: human-readable and JSON.
-The formatter that is used is determined by the environment variable
-LOG_FORMAT. If the environment variable is not set, the JSON formatter
-is used by default. See src.logging.formatters for more information.
-
-The logger also adds a PII mask filter to the root logger. See
-src.logging.pii for more information.
-
-Usage:
- import src.logging
-
- with src.logging.init("program name"):
- ...
-
-Once the module has been initialized, the standard logging module can be
-used to log messages:
-
-Example:
- import logging
-
- logger = logging.getLogger(__name__)
- logger.info("message")
-"""
-from contextlib import contextmanager
-import logging
-import os
-import platform
-import pwd
-import sys
-from typing import Any, ContextManager, cast
-
-import src.logging.config as config
-
-logger = logging.getLogger(__name__)
-_original_argv = tuple(sys.argv)
-
-
-class Log:
- def __init__(self, program_name: str) -> None:
- self.program_name = program_name
- self.root_logger, self.stream_handler = config.configure_logging()
- log_program_info(self.program_name)
-
- def __enter__(self) -> logging.Logger:
- return self.root_logger
-
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
- self.root_logger.removeHandler(self.stream_handler)
-
-@contextmanager
-def init(program_name: str):
- stream_handler = config.configure_logging()
- log_program_info(program_name)
- yield
- logging.root.removeHandler(stream_handler)
-
-def log_program_info(program_name: str) -> None:
- logger.info(
- "start %s: %s %s %s, hostname %s, pid %i, user %i(%s)",
- program_name,
- platform.python_implementation(),
- platform.python_version(),
- platform.system(),
- platform.node(),
- os.getpid(),
- os.getuid(),
- pwd.getpwuid(os.getuid()).pw_name,
- extra={
- "hostname": platform.node(),
- "cpu_count": os.cpu_count(),
- # If mypy is run on a mac, it will throw a module has no attribute error, even though
- # we never actually access it with the conditional.
- #
- # However, we can't just silence this error, because on linux (e.g. CI/CD) that will
- # throw an unused “type: ignore” comment error. Casting to Any instead ensures this
- # passes regardless of where mypy is being run
- "cpu_usable": (
- len(cast(Any, os).sched_getaffinity(0))
- if "sched_getaffinity" in dir(os)
- else "unknown"
- ),
- },
- )
- logger.info("invoked as: %s", " ".join(_original_argv))
diff --git a/app/src/logging/config.py b/app/src/logging/config.py
deleted file mode 100644
index d79b26fc..00000000
--- a/app/src/logging/config.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from contextlib import contextmanager
-import logging
-import sys
-from typing import Generator, Tuple
-
-import src.logging.audit
-import src.logging.formatters as formatters
-import src.logging.pii as pii
-from src.util.env_config import PydanticBaseEnvConfig
-
-logger = logging.getLogger(__name__)
-
-is_initialized = False
-
-
-class HumanReadableFormatterConfig(PydanticBaseEnvConfig):
- message_width: int = formatters.HUMAN_READABLE_FORMATTER_DEFAULT_MESSAGE_WIDTH
-
-
-class LoggingConfig(PydanticBaseEnvConfig):
- format = "json"
- level = "INFO"
- enable_audit = True
- human_readable_formatter = HumanReadableFormatterConfig()
-
- class Config:
- env_prefix = "log_"
- env_nested_delimiter = "__"
-
-@contextmanager
-def configure_logging() -> logging.Handler:
- """Configure logging for the application.
-
- Configures the root module logger to log to stdout.
- Adds a PII mask filter to the root logger.
- Also configures log levels third party packages.
- """
- config = LoggingConfig()
-
- # Loggers can be configured using config functions defined
- # in logging.config or by directly making calls to the main API
- # of the logging module (see https://docs.python.org/3/library/logging.config.html)
- # We opt to use the main API using functions like `addHandler` which is
- # non-destructive, i.e. it does not overwrite any existing handlers.
- # In contrast, logging.config.dictConfig() would overwrite any existing loggers.
- # This is important during testing, since fixtures like `caplog` add handlers that would
- # get overwritten if we call logging.config.dictConfig() during the scope of the test.
- console_handler = logging.StreamHandler(sys.stdout)
- formatter = get_formatter(config)
- console_handler.setFormatter(formatter)
- console_handler.addFilter(pii.mask_pii)
- logging.root.removeHandler(console_handler)
- logging.root.addHandler(console_handler)
- logging.root.setLevel(config.level)
-
- if config.enable_audit:
- src.logging.audit.init()
-
- # Configure loggers for third party packages
- logging.getLogger("alembic").setLevel(logging.INFO)
- logging.getLogger("werkzeug").setLevel(logging.WARN)
- logging.getLogger("sqlalchemy.pool").setLevel(logging.INFO)
- logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
-
- return console_handler
-
-
-
-def get_formatter(config: LoggingConfig) -> logging.Formatter:
- """Return the formatter used by the root logger.
-
- The formatter is determined by the environment variable LOG_FORMAT. If the
- environment variable is not set, the JSON formatter is used by default.
- """
- if config.format == "human-readable":
- return get_human_readable_formatter(config.human_readable_formatter)
- return formatters.JsonFormatter()
-
-
-def get_human_readable_formatter(
- config: HumanReadableFormatterConfig,
-) -> formatters.HumanReadableFormatter:
- """Return the human readable formatter used by the root logger."""
- return formatters.HumanReadableFormatter(message_width=config.message_width)
diff --git a/app/src/services/users/__init__.py b/app/src/services/users/__init__.py
index 9d01b20f..9a001819 100644
--- a/app/src/services/users/__init__.py
+++ b/app/src/services/users/__init__.py
@@ -1,12 +1,9 @@
-from .create_user import CreateUserParams, RoleParams, create_user
+from .create_user import create_user
from .create_user_csv import create_user_csv
from .get_user import get_user
-from .patch_user import PatchUserParams, patch_user
+from .patch_user import patch_user
__all__ = [
- "CreateUserParams",
- "PatchUserParams",
- "RoleParams",
"create_user",
"get_user",
"patch_user",
diff --git a/app/src/services/users/create_user.py b/app/src/services/users/create_user.py
index 9721594a..53050ac3 100644
--- a/app/src/services/users/create_user.py
+++ b/app/src/services/users/create_user.py
@@ -1,40 +1,21 @@
-from datetime import date
-from typing import TypedDict
-
from src.adapters.db import Session
-from src.db.models import user_models
+from src.api.users.user_schemas import UserModel
from src.db.models.user_models import Role, User
-class RoleParams(TypedDict):
- type: user_models.RoleType
-
-
-class CreateUserParams(TypedDict):
- first_name: str
- middle_name: str
- last_name: str
- phone_number: str
- date_of_birth: date
- is_active: bool
- roles: list[RoleParams]
-
-
# TODO: separate controller and service concerns
# https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251
-# TODO: Use classes / objects as inputs to service methods
-# https://github.com/navapbc/template-application-flask/issues/52
-def create_user(db_session: Session, user_params: CreateUserParams) -> User:
+def create_user(db_session: Session, user_model: UserModel) -> User:
with db_session.begin():
# TODO: move this code to service and/or persistence layer
user = User(
- first_name=user_params["first_name"],
- middle_name=user_params["middle_name"],
- last_name=user_params["last_name"],
- phone_number=user_params["phone_number"],
- date_of_birth=user_params["date_of_birth"],
- is_active=user_params["is_active"],
- roles=[Role(type=role["type"]) for role in user_params["roles"]],
+ first_name=user_model.first_name,
+ middle_name=user_model.middle_name,
+ last_name=user_model.last_name,
+ phone_number=user_model.phone_number,
+ date_of_birth=user_model.date_of_birth,
+ is_active=user_model.is_active,
+ roles=[Role(type=role.type) for role in user_model.roles],
)
db_session.add(user)
return user
diff --git a/app/src/services/users/get_user.py b/app/src/services/users/get_user.py
index df4138f5..c6e17e60 100644
--- a/app/src/services/users/get_user.py
+++ b/app/src/services/users/get_user.py
@@ -1,4 +1,6 @@
-import apiflask
+import uuid
+
+from fastapi import HTTPException
from sqlalchemy import orm
from src.adapters.db import Session
@@ -7,15 +9,13 @@
# TODO: separate controller and service concerns
# https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251
-# TODO: Use classes / objects as inputs to service methods
-# https://github.com/navapbc/template-application-flask/issues/52
-def get_user(db_session: Session, user_id: str) -> User:
+def get_user(db_session: Session, user_id: uuid.UUID) -> User:
# TODO: move this to service and/or persistence layer
result = db_session.get(User, user_id, options=[orm.selectinload(User.roles)])
if result is None:
# TODO move HTTP related logic out of service layer to controller layer and just return None from here
# https://github.com/navapbc/template-application-flask/pull/51#discussion_r1053754975
- raise apiflask.HTTPError(404, message=f"Could not find user with ID {user_id}")
+ raise HTTPException(status_code=404, detail=f"Could not find user with ID {user_id}")
return result
diff --git a/app/src/services/users/patch_user.py b/app/src/services/users/patch_user.py
index 3a21c039..06026c12 100644
--- a/app/src/services/users/patch_user.py
+++ b/app/src/services/users/patch_user.py
@@ -1,34 +1,21 @@
import bisect
-from datetime import date
+import uuid
from operator import attrgetter
-from typing import TypedDict
-import apiflask
+from fastapi import HTTPException
from sqlalchemy import orm
from src.adapters.db import Session
+from src.api.users.user_schemas import RoleModel, UserModelPatch
from src.db.models.user_models import Role, User
-from src.services.users.create_user import RoleParams
-
-
-class PatchUserParams(TypedDict, total=False):
- first_name: str
- middle_name: str
- last_name: str
- phone_number: str
- date_of_birth: date
- is_active: bool
- roles: list[RoleParams]
# TODO: separate controller and service concerns
# https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251
-# TODO: Use classes / objects as inputs to service methods
-# https://github.com/navapbc/template-application-flask/issues/52
def patch_user(
db_session: Session,
- user_id: str,
- patch_user_params: PatchUserParams,
+ user_id: uuid.UUID,
+ patch_user_params: UserModelPatch,
) -> User:
with db_session.begin():
@@ -38,21 +25,21 @@ def patch_user(
if user is None:
# TODO move HTTP related logic out of service layer to controller layer and just return None from here
# https://github.com/navapbc/template-application-flask/pull/51#discussion_r1053754975
- raise apiflask.HTTPError(404, message=f"Could not find user with ID {user_id}")
+ raise HTTPException(status_code=404, detail=f"Could not find user with ID {user_id}")
- for key, value in patch_user_params.items():
+ for key, value in patch_user_params.model_dump(exclude_unset=True).items():
if key == "roles":
- _handle_role_patch(db_session, user, patch_user_params["roles"])
+ _handle_role_patch(db_session, user, patch_user_params.roles)
continue
setattr(user, key, value)
return user
-def _handle_role_patch(db_session: Session, user: User, request_roles: list[RoleParams]) -> None:
+def _handle_role_patch(db_session: Session, user: User, request_roles: list[RoleModel]) -> None:
current_role_types = set([role.type for role in user.roles])
- request_role_types = set([role["type"] for role in request_roles])
+ request_role_types = set([role.type for role in request_roles])
roles_to_delete = [role for role in user.roles if role.type not in request_role_types]
diff --git a/app/src/util/env_config.py b/app/src/util/env_config.py
index 40ea8a45..56017bd7 100644
--- a/app/src/util/env_config.py
+++ b/app/src/util/env_config.py
@@ -1,6 +1,6 @@
import os
-from pydantic_settings import BaseSettings
+from pydantic_settings import BaseSettings, SettingsConfigDict
import src
@@ -12,5 +12,4 @@
class PydanticBaseEnvConfig(BaseSettings):
- class Config:
- env_file = env_file
+ model_config = SettingsConfigDict(env_file=env_file)
diff --git a/app/tests/conftest.py b/app/tests/conftest.py
index e6b1c198..6c4d2cf8 100644
--- a/app/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -2,10 +2,10 @@
import _pytest.monkeypatch
import boto3
-import flask
-import flask.testing
+import fastapi
import moto
import pytest
+from fastapi.testclient import TestClient
import src.adapters.db as db
import src.app as app_entry
@@ -115,18 +115,13 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session:
# Make app session scoped so the database connection pool is only created once
# for the test session. This speeds up the tests.
@pytest.fixture(scope="session")
-def app(db_client) -> flask.Flask:
+def app(db_client) -> fastapi.FastAPI:
return app_entry.create_app()
@pytest.fixture
-def client(app: flask.Flask) -> flask.testing.FlaskClient:
- return app.test_client()
-
-
-@pytest.fixture
-def cli_runner(app: flask.Flask) -> flask.testing.CliRunner:
- return app.test_cli_runner()
+def client(app: fastapi.FastAPI) -> TestClient:
+ return TestClient(app)
@pytest.fixture
diff --git a/app/tests/src/adapters/db/test_fastapi_db.py b/app/tests/src/adapters/db/test_fastapi_db.py
new file mode 100644
index 00000000..32dab0bb
--- /dev/null
+++ b/app/tests/src/adapters/db/test_fastapi_db.py
@@ -0,0 +1,56 @@
+import typing
+
+import pytest
+from fastapi import Depends, FastAPI, Request
+from fastapi.testclient import TestClient
+from sqlalchemy import text
+
+import src.adapters.db as db
+import src.adapters.db.fastapi_db as fastapi_db
+
+
+# Define an isolated example FastAPI app fixture specific to this test module
+# to avoid dependencies on any project-specific fixtures in conftest.py
+@pytest.fixture
+def example_app() -> FastAPI:
+ app = FastAPI()
+ db_client = db.PostgresDBClient()
+ fastapi_db.register_db_client(db_client, app)
+ return app
+
+
+def test_get_db(example_app: FastAPI):
+ @example_app.get("/hello")
+ def hello(request: Request) -> dict:
+ with fastapi_db.get_db_client(request.app).get_connection() as conn:
+ return {"data": conn.scalar(text("SELECT 'hello, world'"))}
+
+ response = TestClient(example_app).get("/hello")
+ assert response.json() == {"data": "hello, world"}
+
+
+def test_with_db_session_depends(example_app: FastAPI):
+ @example_app.get("/hello")
+ def hello(db_session: typing.Annotated[db.Session, Depends(fastapi_db.DbSessionDependency())]):
+ with db_session.begin():
+ return {"data": db_session.scalar(text("SELECT 'hello, world'"))}
+
+ response = TestClient(example_app).get("/hello")
+ assert response.json() == {"data": "hello, world"}
+
+
+def test_with_db_session_depends_not_default_name(example_app: FastAPI):
+ db_client = db.PostgresDBClient()
+ fastapi_db.register_db_client(db_client, example_app, client_name="something_else")
+
+ @example_app.get("/hello")
+ def hello(
+ db_session: typing.Annotated[
+ db.Session, Depends(fastapi_db.DbSessionDependency(client_name="something_else"))
+ ]
+ ):
+ with db_session.begin():
+ return {"data": db_session.scalar(text("SELECT 'hello, world'"))}
+
+ response = TestClient(example_app).get("/hello")
+ assert response.json() == {"data": "hello, world"}
diff --git a/app/tests/src/adapters/db/test_flask_db.py b/app/tests/src/adapters/db/test_flask_db.py
deleted file mode 100644
index c95ee12d..00000000
--- a/app/tests/src/adapters/db/test_flask_db.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import pytest
-from flask import Flask, current_app
-from sqlalchemy import text
-
-import src.adapters.db as db
-import src.adapters.db.flask_db as flask_db
-
-
-# Define an isolated example Flask app fixture specific to this test module
-# to avoid dependencies on any project-specific fixtures in conftest.py
-@pytest.fixture
-def example_app() -> Flask:
- app = Flask(__name__)
- db_client = db.PostgresDBClient()
- flask_db.register_db_client(db_client, app)
- return app
-
-
-def test_get_db(example_app: Flask):
- @example_app.route("/hello")
- def hello():
- with flask_db.get_db(current_app).get_connection() as conn:
- return {"data": conn.scalar(text("SELECT 'hello, world'"))}
-
- response = example_app.test_client().get("/hello")
- assert response.get_json() == {"data": "hello, world"}
-
-
-def test_with_db_session(example_app: Flask):
- @example_app.route("/hello")
- @flask_db.with_db_session()
- def hello(db_session: db.Session):
- with db_session.begin():
- return {"data": db_session.scalar(text("SELECT 'hello, world'"))}
-
- response = example_app.test_client().get("/hello")
- assert response.get_json() == {"data": "hello, world"}
-
-
-def test_with_db_session_not_default_name(example_app: Flask):
- db_client = db.PostgresDBClient()
- flask_db.register_db_client(db_client, example_app, client_name="something_else")
-
- @example_app.route("/hello")
- @flask_db.with_db_session(client_name="something_else")
- def hello(db_session: db.Session):
- with db_session.begin():
- return {"data": db_session.scalar(text("SELECT 'hello, world'"))}
-
- response = example_app.test_client().get("/hello")
- assert response.get_json() == {"data": "hello, world"}
diff --git a/app/tests/src/auth/test_api_key_auth.py b/app/tests/src/auth/test_api_key_auth.py
index 152cc63f..cf28bf5f 100644
--- a/app/tests/src/auth/test_api_key_auth.py
+++ b/app/tests/src/auth/test_api_key_auth.py
@@ -1,29 +1,19 @@
import pytest
-from apiflask import HTTPError
-from flask import g
+from fastapi import HTTPException
-from src.auth.api_key_auth import API_AUTH_USER, verify_token
+from src.auth.api_key_auth import verify_api_key
-def test_verify_token_success(app, api_auth_token):
- # Passing it the configured auth token successfully returns a user
- with app.app_context(): # So we can attach the user to the flask app
- user_map = verify_token(api_auth_token)
-
- assert user_map.get("uid") == API_AUTH_USER.id
- assert g.get("current_user") == API_AUTH_USER
-
-
-def test_verify_token_invalid_token(api_auth_token):
+def test_verify_api_key_invalid_token(api_auth_token):
# If you pass it the wrong token
- with pytest.raises(HTTPError):
- verify_token("not the right token")
+ with pytest.raises(HTTPException):
+ verify_api_key("not the right token")
-def test_verify_token_no_configuration(monkeypatch):
+def test_verify_api_key_no_configuration(monkeypatch):
# Remove the API_AUTH_TOKEN env var if set in
# your local environment
monkeypatch.delenv("API_AUTH_TOKEN", raising=False)
# If the auth token is not setup
- with pytest.raises(HTTPError):
- verify_token("any token")
+ with pytest.raises(HTTPException):
+ verify_api_key("any token")
diff --git a/app/src/api/schemas/__init__.py b/app/tests/src/logger/__init__.py
similarity index 100%
rename from app/src/api/schemas/__init__.py
rename to app/tests/src/logger/__init__.py
diff --git a/app/tests/src/logging/test_audit.py b/app/tests/src/logger/test_audit.py
similarity index 98%
rename from app/tests/src/logging/test_audit.py
rename to app/tests/src/logger/test_audit.py
index b361f54b..cb4aa944 100644
--- a/app/tests/src/logging/test_audit.py
+++ b/app/tests/src/logger/test_audit.py
@@ -1,5 +1,5 @@
#
-# Tests for src.logging.audit.
+# Tests for src.logger.audit.
#
import logging
@@ -141,7 +141,7 @@ def test_audit_hook(
pass
assert len(caplog.records) == len(expected_records)
- for record, expected_record in zip(caplog.records, expected_records):
+ for record, expected_record in zip(caplog.records, expected_records, strict=True):
assert record.levelname == "AUDIT"
assert_record_match(record, expected_record)
@@ -161,7 +161,7 @@ def test_os_kill(init_audit_hook, caplog: pytest.LogCaptureFixture):
]
assert len(caplog.records) == len(expected_records)
- for record, expected_record in zip(caplog.records, expected_records):
+ for record, expected_record in zip(caplog.records, expected_records, strict=True):
assert record.levelname == "AUDIT"
assert_record_match(record, expected_record)
@@ -232,7 +232,7 @@ def test_repeated_audit_logs(
]
assert len(caplog.records) == len(expected_records)
- for record, expected_record in zip(caplog.records, expected_records):
+ for record, expected_record in zip(caplog.records, expected_records, strict=True):
assert record.levelname == "AUDIT"
assert_record_match(record, expected_record)
diff --git a/app/tests/src/logging/test_flask_logger.py b/app/tests/src/logger/test_fastapi_logger.py
similarity index 57%
rename from app/tests/src/logging/test_flask_logger.py
rename to app/tests/src/logger/test_fastapi_logger.py
index 67315c80..4b6a803e 100644
--- a/app/tests/src/logging/test_flask_logger.py
+++ b/app/tests/src/logger/test_fastapi_logger.py
@@ -3,14 +3,21 @@
import time
import pytest
-from flask import Flask
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from starlette_context.middleware import RawContextMiddleware
-import src.logger.flask_logger as flask_logger
+import src.logger.fastapi_logger as fastapi_logger
+from src.app import get_global_dependencies
from tests.lib.assertions import assert_dict_contains
@pytest.fixture
def logger():
+ # The test client used in tests calls the httpx library
+    # which does its own logging, which adds noise we don't want.
+ logging.getLogger("httpx").setLevel(logging.ERROR)
+
logger = logging.getLogger("src")
before_level = logger.level
@@ -24,14 +31,16 @@ def logger():
@pytest.fixture
def app(logger):
- app = Flask("test_app_name")
+ app = FastAPI(title="test_app_name", dependencies=get_global_dependencies())
- @app.get("/hello/")
- def hello(name):
+ @app.get("/hello/{name}")
+ def hello(name: str) -> str:
logging.getLogger("src.hello").info(f"hello, {name}!")
return "ok"
- flask_logger.init_app(logger, app)
+ fastapi_logger.init_app(logger, app)
+ app.add_middleware(RawContextMiddleware)
+
return app
@@ -43,10 +52,11 @@ def hello(name):
{"msg": "hello, jane!"},
{
"msg": "end request",
+ "app.name": "test_app_name",
"response.status_code": 200,
- "response.content_length": 2,
- "response.content_type": "text/html; charset=utf-8",
- "response.mimetype": "text/html",
+ "response.content_length": "4",
+ "response.content_type": "application/json",
+ "response.charset": "utf-8",
},
],
id="200",
@@ -57,10 +67,11 @@ def hello(name):
{"msg": "start request"},
{
"msg": "end request",
+ "app.name": "test_app_name",
"response.status_code": 404,
- "response.content_length": 207,
- "response.content_type": "text/html; charset=utf-8",
- "response.mimetype": "text/html",
+ "response.content_length": "22",
+ "response.content_type": "application/json",
+ "response.charset": "utf-8",
},
],
id="404",
@@ -73,69 +84,63 @@ def hello(name):
test_request_lifecycle_logs_data,
)
def test_request_lifecycle_logs(
- app: Flask, caplog: pytest.LogCaptureFixture, route, expected_extras
+ app: FastAPI, caplog: pytest.LogCaptureFixture, route, expected_extras
):
- app.test_client().get(route)
+ TestClient(app).get(route)
# Assert that the log messages are present
# There should be the route log message that is logged in the before_request handler
# as part of every request, followed by the log message in the route handler itself.
# then the log message in the after_request handler.
- assert len(caplog.records) == len(expected_extras)
- for record, expected_extra in zip(caplog.records, expected_extras):
- assert_dict_contains(record.__dict__, expected_extra)
-
-
-def test_app_context_extra_attributes(app: Flask, caplog: pytest.LogCaptureFixture):
- # Assert that extra attributes related to the app context are present in all log records
- expected_extra = {"app.name": "test_app_name"}
-
- app.test_client().get("/hello/jane")
-
- assert len(caplog.records) > 0
- for record in caplog.records:
+ assert len(caplog.records) == len(expected_extras), "|".join(
+ [record.__dict__["msg"] for record in caplog.records]
+ )
+ for record, expected_extra in zip(caplog.records, expected_extras, strict=True):
assert_dict_contains(record.__dict__, expected_extra)
-def test_request_context_extra_attributes(app: Flask, caplog: pytest.LogCaptureFixture):
+def test_request_context_extra_attributes(app: FastAPI, caplog: pytest.LogCaptureFixture):
# Assert that the extra attributes related to the request context are present in all log records
expected_extra = {
"request.id": "",
"request.method": "GET",
"request.path": "/hello/jane",
- "request.url_rule": "/hello/",
+ # "request.url_rule": "/hello/",
"request.query.up": "high",
"request.query.down": "low",
}
- app.test_client().get("/hello/jane?up=high&down=low")
+ TestClient(app).get("/hello/jane?up=high&down=low")
assert len(caplog.records) > 0
for record in caplog.records:
assert_dict_contains(record.__dict__, expected_extra)
+ # After the first log (the start request log), we'll have the url_rule
+ expected_extra["request.url_rule"] = "/hello/{name}"
+
-def test_add_extra_log_data_for_current_request(app: Flask, caplog: pytest.LogCaptureFixture):
- @app.get("/pet/")
- def pet(name):
- flask_logger.add_extra_data_to_current_request_logs({"pet.name": name})
+def test_add_extra_log_data_for_current_request(app: FastAPI, caplog: pytest.LogCaptureFixture):
+ @app.get("/pet/{name}")
+ def pet(name: str) -> str:
+ fastapi_logger.add_extra_data_to_current_request_logs({"pet.name": name})
logging.getLogger("test.pet").info(f"petting {name}")
return "ok"
- app.test_client().get("/pet/kitty")
+ TestClient(app).get("/pet/kitty")
last_record = caplog.records[-1]
assert_dict_contains(last_record.__dict__, {"pet.name": "kitty"})
-def test_log_response_time(app: Flask, caplog: pytest.LogCaptureFixture):
+def test_log_response_time(app: FastAPI, caplog: pytest.LogCaptureFixture):
@app.get("/sleep")
def sleep():
time.sleep(0.1) # 0.1 s = 100 ms
return "ok"
- app.test_client().get("/sleep")
+ TestClient(app).get("/sleep")
last_record = caplog.records[-1]
assert "response.time_ms" in last_record.__dict__
diff --git a/app/tests/src/logging/test_formatters.py b/app/tests/src/logger/test_formatters.py
similarity index 100%
rename from app/tests/src/logging/test_formatters.py
rename to app/tests/src/logger/test_formatters.py
diff --git a/app/tests/src/logging/test_logging.py b/app/tests/src/logger/test_logging.py
similarity index 97%
rename from app/tests/src/logging/test_logging.py
rename to app/tests/src/logger/test_logging.py
index 0a1b530c..afad9cc5 100644
--- a/app/tests/src/logging/test_logging.py
+++ b/app/tests/src/logger/test_logging.py
@@ -12,7 +12,7 @@
def init_test_logger(caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch):
caplog.set_level(logging.DEBUG)
monkeypatch.setenv("LOG_FORMAT", "human-readable")
- with src.logging.init("test_logging"):
+ with src.logger.init("test_logging"):
yield
@@ -27,7 +27,7 @@ def test_init(caplog: pytest.LogCaptureFixture, monkeypatch, log_format, expecte
caplog.set_level(logging.DEBUG)
monkeypatch.setenv("LOG_FORMAT", log_format)
- with src.logging.init("test_logging"):
+ with src.logger.init("test_logging"):
records = caplog.records
assert len(records) == 2
diff --git a/app/tests/src/logging/test_pii.py b/app/tests/src/logger/test_pii.py
similarity index 100%
rename from app/tests/src/logging/test_pii.py
rename to app/tests/src/logger/test_pii.py
diff --git a/app/tests/src/logging/__init__.py b/app/tests/src/logging/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/app/tests/src/route/test_user_route.py b/app/tests/src/route/test_user_route.py
index f8b38efd..3eeda876 100644
--- a/app/tests/src/route/test_user_route.py
+++ b/app/tests/src/route/test_user_route.py
@@ -28,7 +28,7 @@ def base_request():
@pytest.fixture
def created_user(client, api_auth_token, base_request):
response = client.post("/v1/users", json=base_request, headers={"X-Auth": api_auth_token})
- return response.get_json()["data"]
+ return response.json()
test_create_and_get_user_data = [
@@ -45,7 +45,7 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles):
"roles": roles,
}
post_response = client.post("/v1/users", json=request, headers={"X-Auth": api_auth_token})
- post_response_data = post_response.get_json()["data"]
+ post_response_data = post_response.json()
expected_response = {
**request,
"id": post_response_data["id"],
@@ -59,26 +59,41 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles):
assert post_response_data["updated_at"] is not None
# Get the user
- user_id = post_response.get_json()["data"]["id"]
+ user_id = post_response.json()["id"]
get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token})
assert get_response.status_code == 200
- get_response_data = get_response.get_json()["data"]
+ get_response_data = get_response.json()
assert get_response_data == expected_response
test_create_user_bad_request_data = [
pytest.param(
{},
- {
- "first_name": ["Missing data for required field."],
- "last_name": ["Missing data for required field."],
- "phone_number": ["Missing data for required field."],
- "date_of_birth": ["Missing data for required field."],
- "is_active": ["Missing data for required field."],
- "roles": ["Missing data for required field."],
- },
+ [
+ {
+ "input": {},
+ "loc": ["body", "first_name"],
+ "msg": "Field required",
+ "type": "missing",
+ },
+ {"input": {}, "loc": ["body", "last_name"], "msg": "Field required", "type": "missing"},
+ {
+ "input": {},
+ "loc": ["body", "phone_number"],
+ "msg": "Field required",
+ "type": "missing",
+ },
+ {
+ "input": {},
+ "loc": ["body", "date_of_birth"],
+ "msg": "Field required",
+ "type": "missing",
+ },
+ {"input": {}, "loc": ["body", "is_active"], "msg": "Field required", "type": "missing"},
+ {"input": {}, "loc": ["body", "roles"], "msg": "Field required", "type": "missing"},
+ ],
id="missing all required fields",
),
pytest.param(
@@ -91,25 +106,70 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles):
"is_active": 6,
"roles": 7,
},
- {
- "first_name": ["Not a valid string."],
- "middle_name": ["Not a valid string."],
- "last_name": ["Not a valid string."],
- "phone_number": ["Not a valid string."],
- "date_of_birth": ["Not a valid date."],
- "is_active": ["Not a valid boolean."],
- "roles": ["Not a valid list."],
- },
+ [
+ {
+ "input": 1,
+ "loc": ["body", "first_name"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ },
+ {
+ "input": 2,
+ "loc": ["body", "middle_name"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ },
+ {
+ "input": 3,
+ "loc": ["body", "last_name"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ },
+ {
+ "input": 5,
+ "loc": ["body", "phone_number"],
+ "msg": "Input should be a valid string",
+ "type": "string_type",
+ },
+ {
+ "input": 4,
+ "loc": ["body", "date_of_birth"],
+                "msg": "Datetimes provided to dates should have zero time - e.g. be exact dates",
+ "type": "date_from_datetime_inexact",
+ },
+ {
+ "input": 6,
+ "loc": ["body", "is_active"],
+ "msg": "Input should be a valid boolean, unable to interpret input",
+ "type": "bool_parsing",
+ },
+ {
+ "input": 7,
+ "loc": ["body", "roles"],
+ "msg": "Input should be a valid list",
+ "type": "list_type",
+ },
+ ],
id="invalid types",
),
pytest.param(
get_base_request() | {"roles": [{"type": "Mime"}, {"type": "Clown"}]},
- {
- "roles": {
- "0": {"type": ["Must be one of: USER, ADMIN."]},
- "1": {"type": ["Must be one of: USER, ADMIN."]},
- }
- },
+ [
+ {
+ "ctx": {"expected": "'USER' or 'ADMIN'"},
+ "input": "Mime",
+ "loc": ["body", "roles", 0, "type"],
+ "msg": "Input should be 'USER' or 'ADMIN'",
+ "type": "enum",
+ },
+ {
+ "ctx": {"expected": "'USER' or 'ADMIN'"},
+ "input": "Clown",
+ "loc": ["body", "roles", 1, "type"],
+ "msg": "Input should be 'USER' or 'ADMIN'",
+ "type": "enum",
+ },
+ ],
id="invalid role type",
),
]
@@ -120,7 +180,7 @@ def test_create_user_bad_request(client, api_auth_token, request_data, expected_
response = client.post("/v1/users", json=request_data, headers={"X-Auth": api_auth_token})
assert response.status_code == 422
- response_data = response.get_json()["detail"]["json"]
+ response_data = response.json()["detail"]
assert response_data == expected_response_data
@@ -130,7 +190,7 @@ def test_patch_user(client, api_auth_token, created_user):
patch_response = client.patch(
f"/v1/users/{user_id}", json=patch_request, headers={"X-Auth": api_auth_token}
)
- patch_response_data = patch_response.get_json()["data"]
+ patch_response_data = patch_response.json()
expected_response_data = {
**created_user,
**patch_request,
@@ -141,7 +201,7 @@ def test_patch_user(client, api_auth_token, created_user):
assert patch_response_data == expected_response_data
get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token})
- get_response_data = get_response.get_json()["data"]
+ get_response_data = get_response.json()
assert get_response_data == expected_response_data
@@ -155,14 +215,14 @@ def test_patch_user_roles(client, base_request, api_auth_token, initial_roles, u
}
created_user = client.post(
"/v1/users", json=post_request, headers={"X-Auth": api_auth_token}
- ).get_json()["data"]
+ ).json()
user_id = created_user["id"]
patch_request = {"roles": updated_roles}
patch_response = client.patch(
f"/v1/users/{user_id}", json=patch_request, headers={"X-Auth": api_auth_token}
)
- patch_response_data = patch_response.get_json()["data"]
+ patch_response_data = patch_response.json()
expected_response_data = {
**created_user,
**patch_request,
@@ -173,7 +233,7 @@ def test_patch_user_roles(client, base_request, api_auth_token, initial_roles, u
assert patch_response_data == expected_response_data
get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token})
- get_response_data = get_response.get_json()["data"]
+ get_response_data = get_response.json()
assert get_response_data == expected_response_data
@@ -190,10 +250,12 @@ def test_unauthorized(client, method, url, body, api_auth_token):
expected_message = (
"The server could not verify that you are authorized to access the URL requested"
)
- response = getattr(client, method)(url, json=body, headers={"X-Auth": "incorrect token"})
+ response = client.request(
+ method=method, url=url, json=body, headers={"X-Auth": "incorrect token"}
+ )
assert response.status_code == 401
- assert response.get_json()["message"] == expected_message
+ assert response.json()["detail"] == expected_message
test_not_found_data = [
@@ -206,7 +268,7 @@ def test_unauthorized(client, method, url, body, api_auth_token):
def test_not_found(client, api_auth_token, method, body):
user_id = uuid.uuid4()
url = f"/v1/users/{user_id}"
- response = getattr(client, method)(url, json=body, headers={"X-Auth": api_auth_token})
+ response = client.request(method=method, url=url, json=body, headers={"X-Auth": api_auth_token})
assert response.status_code == 404
- assert response.get_json()["message"] == f"Could not find user with ID {user_id}"
+ assert response.json()["detail"] == f"Could not find user with ID {user_id}"
diff --git a/app/tests/src/scripts/__init__.py b/app/tests/src/scripts/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/app/tests/src/scripts/test_create_user_csv.py b/app/tests/src/scripts/test_create_user_csv.py
deleted file mode 100644
index ccd04e1e..00000000
--- a/app/tests/src/scripts/test_create_user_csv.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-import os.path as path
-import re
-
-import flask.testing
-import pytest
-from pytest_lazyfixture import lazy_fixture
-from smart_open import open as smart_open
-
-import src.adapters.db as db
-from src.db.models.user_models import User
-from tests.src.db.models.factories import UserFactory
-
-
-@pytest.fixture
-def prepopulate_user_table(enable_factory_create, db_session: db.Session) -> list[User]:
- # First make sure the table is empty, as other tests may have inserted data
- # and this test expects a clean slate (unlike most tests that are designed to
- # be isolated from other tests)
- db_session.query(User).delete()
- return [
- UserFactory.create(first_name="Jon", last_name="Doe", is_active=True),
- UserFactory.create(first_name="Jane", last_name="Doe", is_active=False),
- UserFactory.create(
- first_name="Alby",
- last_name="Testin",
- is_active=True,
- ),
- ]
-
-
-@pytest.fixture
-def tmp_s3_folder(mock_s3_bucket):
- return f"s3://{mock_s3_bucket}/path/to/"
-
-
-@pytest.mark.parametrize(
- "dir",
- [
- pytest.param(lazy_fixture("tmp_s3_folder"), id="write-to-s3"),
- pytest.param(lazy_fixture("tmp_path"), id="write-to-local"),
- ],
-)
-def test_create_user_csv(
- prepopulate_user_table: list[User],
- cli_runner: flask.testing.FlaskCliRunner,
- dir: str,
-):
- cli_runner.invoke(args=["user", "create-csv", "--dir", dir, "--filename", "test.csv"])
- output = smart_open(path.join(dir, "test.csv")).read()
- expected_output = open(
- path.join(path.dirname(__file__), "test_create_user_csv_expected.csv")
- ).read()
- assert output == expected_output
-
-
-def test_default_filename(cli_runner: flask.testing.FlaskCliRunner, tmp_path: str):
- cli_runner.invoke(args=["user", "create-csv", "--dir", tmp_path])
- filenames = os.listdir(tmp_path)
- assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-user-roles.csv", filenames[0])
diff --git a/app/tests/src/scripts/test_create_user_csv_expected.csv b/app/tests/src/scripts/test_create_user_csv_expected.csv
deleted file mode 100644
index da0371ff..00000000
--- a/app/tests/src/scripts/test_create_user_csv_expected.csv
+++ /dev/null
@@ -1,4 +0,0 @@
-"User Name","Roles","Is User Active?"
-"Jon Doe","ADMIN USER","True"
-"Jane Doe","ADMIN USER","False"
-"Alby Testin","ADMIN USER","True"
diff --git a/docker-compose.yml b/docker-compose.yml
index ed88388b..19e30a09 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -24,7 +24,7 @@ services:
args:
- RUN_UID=${RUN_UID:-4000}
- RUN_USER=${RUN_USER:-app}
- command: ["poetry", "run", "flask", "--app", "src.app", "run", "--host", "0.0.0.0", "--port", "8080", "--reload"]
+ command: ["poetry", "run", "python", "src/app.py"]
container_name: main-app
env_file: ./app/local.env
ports: