diff --git a/app/Dockerfile b/app/Dockerfile index 76c48f2d..e28e593d 100644 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -4,7 +4,7 @@ # The build stage that will be used to deploy to the various environments # needs to be called `release` in order to integrate with the repo's # top-level Makefile -FROM python:3-slim AS base +FROM python:3.11-slim AS base # Install poetry, the package manager. # https://python-poetry.org @@ -57,7 +57,7 @@ COPY . /app ENV HOST=0.0.0.0 # Run the application. -CMD ["poetry", "run", "python", "-m", "src"] +CMD ["poetry", "run", "python", "src/app.py"] #--------- # Release @@ -109,4 +109,4 @@ ENV HOST=0.0.0.0 USER ${RUN_USER} # Run the application. -CMD ["poetry", "run", "gunicorn", "src.app:create_app()"] +CMD ["poetry", "run", "python", "src/app.py"] diff --git a/app/local.env b/app/local.env index 21c5a1b6..ac6d2e40 100644 --- a/app/local.env +++ b/app/local.env @@ -15,8 +15,6 @@ PYTHONPATH=/app/ # commands that can run in or out # of the Docker container - defaults to outside -FLASK_APP=src.app:create_app - ############################ # Logging ############################ diff --git a/app/poetry.lock b/app/poetry.lock index acb931c9..78072d73 100644 --- a/app/poetry.lock +++ b/app/poetry.lock @@ -97,16 +97,6 @@ tests = ["PyYAML (>=3.10)", "marshmallow (>=3.13.0)", "openapi-spec-validator (< validation = ["openapi-spec-validator (<0.5)", "prance[osv] (>=0.11)"] yaml = ["PyYAML (>=3.10)"] -[[package]] -name = "atomicwrites" -version = "1.4.1" -description = "Atomic file writes." -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, -] - [[package]] name = "attrs" version = "23.1.0" @@ -1289,44 +1279,48 @@ xray = ["aws-xray-sdk (>=0.93,!=0.96)", "setuptools"] [[package]] name = "mypy" -version = "0.971" +version = "1.5.1" description = "Optional static typing for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "mypy-0.971-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2899a3cbd394da157194f913a931edfd4be5f274a88041c9dc2d9cdcb1c315c"}, - {file = "mypy-0.971-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:98e02d56ebe93981c41211c05adb630d1d26c14195d04d95e49cd97dbc046dc5"}, - {file = "mypy-0.971-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:19830b7dba7d5356d3e26e2427a2ec91c994cd92d983142cbd025ebe81d69cf3"}, - {file = "mypy-0.971-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:02ef476f6dcb86e6f502ae39a16b93285fef97e7f1ff22932b657d1ef1f28655"}, - {file = "mypy-0.971-cp310-cp310-win_amd64.whl", hash = "sha256:25c5750ba5609a0c7550b73a33deb314ecfb559c350bb050b655505e8aed4103"}, - {file = "mypy-0.971-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d3348e7eb2eea2472db611486846742d5d52d1290576de99d59edeb7cd4a42ca"}, - {file = "mypy-0.971-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3fa7a477b9900be9b7dd4bab30a12759e5abe9586574ceb944bc29cddf8f0417"}, - {file = "mypy-0.971-cp36-cp36m-win_amd64.whl", hash = "sha256:2ad53cf9c3adc43cf3bea0a7d01a2f2e86db9fe7596dfecb4496a5dda63cbb09"}, - {file = "mypy-0.971-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:855048b6feb6dfe09d3353466004490b1872887150c5bb5caad7838b57328cc8"}, - {file = "mypy-0.971-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:23488a14a83bca6e54402c2e6435467a4138785df93ec85aeff64c6170077fb0"}, - {file = "mypy-0.971-cp37-cp37m-win_amd64.whl", hash = "sha256:4b21e5b1a70dfb972490035128f305c39bc4bc253f34e96a4adf9127cf943eb2"}, - {file = "mypy-0.971-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9796a2ba7b4b538649caa5cecd398d873f4022ed2333ffde58eaf604c4d2cb27"}, - {file = "mypy-0.971-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a361d92635ad4ada1b1b2d3630fc2f53f2127d51cf2def9db83cba32e47c856"}, - {file = "mypy-0.971-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b793b899f7cf563b1e7044a5c97361196b938e92f0a4343a5d27966a53d2ec71"}, - {file = "mypy-0.971-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d1ea5d12c8e2d266b5fb8c7a5d2e9c0219fedfeb493b7ed60cd350322384ac27"}, - {file = "mypy-0.971-cp38-cp38-win_amd64.whl", hash = "sha256:23c7ff43fff4b0df93a186581885c8512bc50fc4d4910e0f838e35d6bb6b5e58"}, - {file = "mypy-0.971-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1f7656b69974a6933e987ee8ffb951d836272d6c0f81d727f1d0e2696074d9e6"}, - {file = "mypy-0.971-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2022bfadb7a5c2ef410d6a7c9763188afdb7f3533f22a0a32be10d571ee4bbe"}, - {file = "mypy-0.971-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef943c72a786b0f8d90fd76e9b39ce81fb7171172daf84bf43eaf937e9f220a9"}, - {file = "mypy-0.971-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d744f72eb39f69312bc6c2abf8ff6656973120e2eb3f3ec4f758ed47e414a4bf"}, - {file = "mypy-0.971-cp39-cp39-win_amd64.whl", hash = "sha256:77a514ea15d3007d33a9e2157b0ba9c267496acf12a7f2b9b9f8446337aac5b0"}, - {file = "mypy-0.971-py3-none-any.whl", hash = "sha256:0d054ef16b071149917085f51f89555a576e2618d5d9dd70bd6eea6410af3ac9"}, - {file = "mypy-0.971.tar.gz", hash = "sha256:40b0f21484238269ae6a57200c807d80debc6459d444c0489a102d7c6a75fa56"}, + {file = "mypy-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f33592ddf9655a4894aef22d134de7393e95fcbdc2d15c1ab65828eee5c66c70"}, + {file = "mypy-1.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:258b22210a4a258ccd077426c7a181d789d1121aca6db73a83f79372f5569ae0"}, + {file = "mypy-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9ec1f695f0c25986e6f7f8778e5ce61659063268836a38c951200c57479cc12"}, + {file = "mypy-1.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:abed92d9c8f08643c7d831300b739562b0a6c9fcb028d211134fc9ab20ccad5d"}, + {file = "mypy-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:a156e6390944c265eb56afa67c74c0636f10283429171018446b732f1a05af25"}, + {file = "mypy-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6ac9c21bfe7bc9f7f1b6fae441746e6a106e48fc9de530dea29e8cd37a2c0cc4"}, + {file = "mypy-1.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:51cb1323064b1099e177098cb939eab2da42fea5d818d40113957ec954fc85f4"}, + {file = "mypy-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:596fae69f2bfcb7305808c75c00f81fe2829b6236eadda536f00610ac5ec2243"}, + {file = "mypy-1.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:32cb59609b0534f0bd67faebb6e022fe534bdb0e2ecab4290d683d248be1b275"}, + {file = "mypy-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:159aa9acb16086b79bbb0016145034a1a05360626046a929f84579ce1666b315"}, + {file = "mypy-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f6b0e77db9ff4fda74de7df13f30016a0a663928d669c9f2c057048ba44f09bb"}, + {file = "mypy-1.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26f71b535dfc158a71264e6dc805a9f8d2e60b67215ca0bfa26e2e1aa4d4d373"}, + {file = "mypy-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fc3a600f749b1008cc75e02b6fb3d4db8dbcca2d733030fe7a3b3502902f161"}, + {file = "mypy-1.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:26fb32e4d4afa205b24bf645eddfbb36a1e17e995c5c99d6d00edb24b693406a"}, + {file = "mypy-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:82cb6193de9bbb3844bab4c7cf80e6227d5225cc7625b068a06d005d861ad5f1"}, + {file = "mypy-1.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4a465ea2ca12804d5b34bb056be3a29dc47aea5973b892d0417c6a10a40b2d65"}, + {file = "mypy-1.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9fece120dbb041771a63eb95e4896791386fe287fefb2837258925b8326d6160"}, + {file = "mypy-1.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d28ddc3e3dfeab553e743e532fb95b4e6afad51d4706dd22f28e1e5e664828d2"}, + {file = "mypy-1.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:57b10c56016adce71fba6bc6e9fd45d8083f74361f629390c556738565af8eeb"}, + {file = "mypy-1.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff0cedc84184115202475bbb46dd99f8dcb87fe24d5d0ddfc0fe6b8575c88d2f"}, + {file = "mypy-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8f772942d372c8cbac575be99f9cc9d9fb3bd95c8bc2de6c01411e2c84ebca8a"}, + {file = "mypy-1.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5d627124700b92b6bbaa99f27cbe615c8ea7b3402960f6372ea7d65faf376c14"}, + {file = "mypy-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:361da43c4f5a96173220eb53340ace68cda81845cd88218f8862dfb0adc8cddb"}, + {file = "mypy-1.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:330857f9507c24de5c5724235e66858f8364a0693894342485e543f5b07c8693"}, + {file = "mypy-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:c543214ffdd422623e9fedd0869166c2f16affe4ba37463975043ef7d2ea8770"}, + {file = "mypy-1.5.1-py3-none-any.whl", hash = "sha256:f757063a83970d67c444f6e01d9550a7402322af3557ce7630d3c957386fa8f5"}, + {file = "mypy-1.5.1.tar.gz", hash = "sha256:b031b9601f1060bf1281feab89697324726ba0c0bae9d7cd7ab4b690940f0b92"}, ] [package.dependencies] -mypy-extensions = ">=0.4.3" +mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] -python2 = ["typed-ast (>=1.4.0,<2)"] +install-types = ["pip"] reports = ["lxml"] [[package]] @@ -1559,17 +1553,6 @@ files = [ {file = "psycopg_binary-3.1.10-cp39-cp39-win_amd64.whl", hash = "sha256:b30887e631fd67affaed98f6cd2135b44f2d1a6d9bca353a69c3889c78bd7aa8"}, ] -[[package]] -name = "py" -version = "1.11.0" -description = "library with cross-python path, ini-parsing, io, code, log facilities" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -files = [ - {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, - {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, -] - [[package]] name = "py-partiql-parser" version = "0.3.7" @@ -1797,27 +1780,25 @@ plugins = ["importlib-metadata"] [[package]] name = "pytest" -version = "6.2.5" +version = "7.4.2" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, - {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, + {file = "pytest-7.4.2-py3-none-any.whl", hash = "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002"}, + {file = "pytest-7.4.2.tar.gz", hash = "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069"}, ] [package.dependencies] -atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" -py = ">=1.8.2" -toml = "*" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-lazy-fixture" @@ -2176,6 +2157,20 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] +[[package]] +name = "starlette-context" +version = "0.3.6" +description = "Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "starlette_context-0.3.6-py3-none-any.whl", hash = "sha256:b14ce373fbb6895a2182a7104b9f63ba20c8db83444005fb9a844dd77ad9895c"}, + {file = "starlette_context-0.3.6.tar.gz", hash = "sha256:d361a36ba2d4acca3ab680f917b25e281533d725374752d47607a859041958cb"}, +] + +[package.dependencies] +starlette = "*" + [[package]] name = "stevedore" version = "5.1.0" @@ -2190,17 +2185,6 @@ files = [ [package.dependencies] pbr = ">=2.0.0,<2.1.0 || >2.1.0" -[[package]] -name = "toml" -version = "0.10.2" -description = "Python Library for Tom's Obvious, Minimal Language" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, - {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, -] - [[package]] name = "tomli" version = "2.0.1" @@ -2650,4 +2634,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8a0dc94cf02cf1ff2ca09f4e5163d26ead0534f6648196b4188b8303fe54da95" +content-hash = "b7560789c538007917c286a7d07a430fbcd01560b8dc27da4037a58d80e7a1cf" diff --git a/app/pyproject.toml b/app/pyproject.toml index 25233f03..284f754d 100644 --- a/app/pyproject.toml +++ b/app/pyproject.toml @@ -21,6 +21,7 @@ marshmallow = "^3.18.0" gunicorn = "^21.2.0" fastapi = {extras = ["all"], version = "^0.103.1"} psycopg = {extras = ["binary"], version = "^3.1.10"} +starlette-context = "^0.3.6" [tool.poetry.group.dev.dependencies] @@ -29,14 +30,14 @@ flake8 = "^5.0.4" flake8-bugbear = "^22.8.23" flake8-alfred = "^1.1.1" isort = "^5.10.1" -mypy = "^0.971" +mypy = "1.5.1" moto = {extras = ["s3"], version = "^4.0.2"} types-pytz = "^2022.2.1" coverage = "^6.4.4" Faker = "^14.2.0" factory-boy = "^3.2.1" bandit = "^1.7.4" -pytest = "^6.0.0" +pytest = "7.4.2" pytest-watch = "^4.2.0" pytest-lazy-fixture = "^0.6.3" types-pyyaml = "^6.0.12.11" @@ -83,6 +84,9 @@ warn_redundant_casts = true warn_unreachable = true warn_unused_ignores = true +plugins = ["pydantic.mypy"] + + [tool.bandit] # Ignore audit logging test file since test audit logging requires a lot of operations that trigger bandit warnings exclude_dirs = ["./tests/src/logging/test_audit.py"] diff --git a/app/setup.cfg b/app/setup.cfg index 54346acf..0505d286 100644 --- a/app/setup.cfg +++ b/app/setup.cfg @@ -18,3 +18,9 @@ ignore = E266, # don't use bare except, B001 is more descriptive E722 + +# Tell Flake to not error when we do "def my_func(some_field = default_value)" when +# the default value is one of the following types as these are used by FastAPI. +# These are calls to classes and functions and are immutably safe in their usage. +# https://stackoverflow.com/questions/73306462/flake8-throws-b008-fastapi-data-type-definitions +extend-immutable-calls = Depends, Security, fastapi_db.DbSessionDependency \ No newline at end of file diff --git a/app/src/adapters/db/__init__.py b/app/src/adapters/db/__init__.py index 591de851..913841dd 100644 --- a/app/src/adapters/db/__init__.py +++ b/app/src/adapters/db/__init__.py @@ -2,9 +2,9 @@ Database module. This module contains the DBClient class, which is used to manage database connections. -This module can be used on it's own or with an application framework such as Flask. +This module can be used on it's own or with an application framework such as Fast API. -To use this module with Flask, use the flask_db module. +To use this module with FastAPI, use the fastapi_db module. Usage: import src.adapters.db as db @@ -26,7 +26,7 @@ from src.adapters.db.client import Connection, DBClient, Session from src.adapters.db.clients.postgres_client import PostgresDBClient -# Do not import flask_db here, because this module is not dependent on any specific framework. -# Code can choose to use this module on its own or with the flask_db module depending on needs. +# Do not import fastapi_db here, because this module is not dependent on any specific framework. +# Code can choose to use this module on its own or with the fastapi_db module depending on needs. __all__ = ["Connection", "DBClient", "Session", "PostgresDBClient"] diff --git a/app/src/adapters/db/client.py b/app/src/adapters/db/client.py index 78b83261..251aa310 100644 --- a/app/src/adapters/db/client.py +++ b/app/src/adapters/db/client.py @@ -23,7 +23,7 @@ class DBClient(abc.ABC, metaclass=abc.ABCMeta): """Database connection manager. - This class is used to manage database connections for the Flask app. + This class is used to manage database connections for the app. It has methods for getting a new connection or session object. A derived class must initialize _engine in the __init__ function diff --git a/app/src/adapters/db/clients/postgres_config.py b/app/src/adapters/db/clients/postgres_config.py index 961c20d6..41f4bebe 100644 --- a/app/src/adapters/db/clients/postgres_config.py +++ b/app/src/adapters/db/clients/postgres_config.py @@ -9,16 +9,16 @@ class PostgresDBConfig(PydanticBaseEnvConfig): - check_connection_on_init: bool = Field(True, env="DB_CHECK_CONNECTION_ON_INIT") - aws_region: Optional[str] = Field(None, env="AWS_REGION") - host: str = Field(env="DB_HOST") - name: str = Field(env="DB_NAME") - username: str = Field(env="DB_USER") - password: Optional[str] = Field(None, env="DB_PASSWORD") - db_schema: str = Field("public", env="DB_SCHEMA") - port: int = Field(5432, env="DB_PORT") - hide_sql_parameter_logs: bool = Field(True, env="HIDE_SQL_PARAMETER_LOGS") - ssl_mode: str = Field("require", env="DB_SSL_MODE") + check_connection_on_init: bool = Field(True, alias="DB_CHECK_CONNECTION_ON_INIT") + aws_region: Optional[str] = Field(None, alias="AWS_REGION") + host: str = Field(alias="DB_HOST") + name: str = Field(alias="DB_NAME") + username: str = Field(alias="DB_USER") + password: Optional[str] = Field(None, alias="DB_PASSWORD") + db_schema: str = Field("public", alias="DB_SCHEMA") + port: int = Field(5432, alias="DB_PORT") + hide_sql_parameter_logs: bool = Field(True, alias="HIDE_SQL_PARAMETER_LOGS") + ssl_mode: str = Field("require", alias="DB_SSL_MODE") def get_db_config() -> PostgresDBConfig: diff --git a/app/src/adapters/db/fastapi_db.py b/app/src/adapters/db/fastapi_db.py new file mode 100644 index 00000000..922e57eb --- /dev/null +++ b/app/src/adapters/db/fastapi_db.py @@ -0,0 +1,87 @@ +""" +This module has functionality to extend FastAPI with a database client. + +To initialize this FastAPI extension, call register_db_client() with an instance +of a FastAPI app and an instance of a DBClient. + +Example: + import src.adapters.db as db + import src.adapters.db.fastapi_db as fastapi_db + + db_client = db.PostgresDBClient() + app = fastapi.FastAPI() + fastapi_db.register_db_client(db_client, app) + +Then, in a request handler, specify the DB session as a dependency to get a +new database session that lasts for the duration of the request. + +Example: + import src.adapters.db as db + import src.adapters.db.fastapi_db as fastapi_db + + @app.get("/health") + def health(db_session: db.Session = Depends(fastapi_db.DbSessionDependency())): + with db_session.begin(): + ... + +The session can also be defined as an annotation in the typing like: +`db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())]` +if you want to avoid having the function defined with a default value + +Alternatively, if you want to get the database client directly, use the get_db_client +function which requires you have the application itself via the Request + +Example: + from fastapi import Request + import src.adapters.db.fastapi_db as fastapi_db + + @app.get("/health") + def health(request: Request): + db_client = fastapi_db.get_db_client(request.app) + # db_client.get_connection() or db_client.get_session() +""" + +from typing import Generator + +import fastapi + +import src.adapters.db as db + +_FASTAPI_KEY_PREFIX = "db" +_DEFAULT_CLIENT_NAME = "default" + + +def register_db_client( + db_client: db.DBClient, app: fastapi.FastAPI, client_name: str = _DEFAULT_CLIENT_NAME +) -> None: + fastapi_state_key = f"{_FASTAPI_KEY_PREFIX}{client_name}" + setattr(app.state, fastapi_state_key, db_client) + + +def get_db_client(app: fastapi.FastAPI, client_name: str = _DEFAULT_CLIENT_NAME) -> db.DBClient: + fastapi_state_key = f"{_FASTAPI_KEY_PREFIX}{client_name}" + return getattr(app.state, fastapi_state_key) + + +class DbSessionDependency: + """ + FastAPI dependency class that can be used to fetch a DB session:: + + import src.adapters.db as db + import src.adapters.db.fastapi_db as fastapi_db + + @app.get("/health") + def health(db_session: db.Session = Depends(fastapi_db.DbSessionDependency())): + with db_session.begin(): + ... + + This approach to setting up a dependency allows us to take in a parameter (the client name) + See: https://fastapi.tiangolo.com/advanced/advanced-dependencies/#a-callable-instance + """ + + def __init__(self, client_name: str = _DEFAULT_CLIENT_NAME): + self.client_name = client_name + + def __call__(self, request: fastapi.Request) -> Generator[db.Session, None, None]: + with get_db_client(request.app, self.client_name).get_session() as session: + yield session diff --git a/app/src/adapters/db/flask_db.py b/app/src/adapters/db/flask_db.py deleted file mode 100644 index 779e5ad4..00000000 --- a/app/src/adapters/db/flask_db.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -This module has functionality to extend Flask with a database client. - -To initialize this flask extension, call register_db_client() with an instance -of a Flask app and an instance of a DBClient. - -Example: - import src.adapters.db as db - import src.adapters.db.flask_db as flask_db - - db_client = db.PostgresDBClient() - app = APIFlask(__name__) - flask_db.register_db_client(db_client, app) - -Then, in a request handler, use the with_db_session decorator to get a -new database session that lasts for the duration of the request. - -Example: - import src.adapters.db as db - import src.adapters.db.flask_db as flask_db - - @app.route("/health") - @flask_db.with_db_session - def health(db_session: db.Session): - with db_session.begin(): - ... - - -Alternatively, if you want to get the database client directly, use the get_db -function. - -Example: - from flask import current_app - import src.adapters.db.flask_db as flask_db - - @app.route("/health") - def health(): - db_client = flask_db.get_db(current_app) - # db_client.get_connection() or db_client.get_session() -""" -from functools import wraps -from typing import Any, Callable, Concatenate, ParamSpec, TypeVar - -from flask import Flask, current_app - -import src.adapters.db as db -from src.adapters.db.client import DBClient - -_FLASK_EXTENSION_KEY_PREFIX = "db" -_DEFAULT_CLIENT_NAME = "default" - - -def register_db_client( - db_client: DBClient, app: Flask, client_name: str = _DEFAULT_CLIENT_NAME -) -> None: - """Initialize the Flask app. - - Add the database to the Flask app's extensions so that it can be - accessed by request handlers using the current app context. - - If you use multiple DB clients, you can differentiate them by - specifying a client_name. - - see get_db - """ - flask_extension_key = f"{_FLASK_EXTENSION_KEY_PREFIX}{client_name}" - app.extensions[flask_extension_key] = db_client - - -def get_db(app: Flask, client_name: str = _DEFAULT_CLIENT_NAME) -> DBClient: - """Get the database connection for the given Flask app. - - Use this in request handlers to access the database from the active Flask app. - - Specify the same client_name as used in register_db_client to get the correct client - - Example: - from flask import current_app - import src.adapters.db.flask_db as flask_db - - @app.route("/health") - def health(): - db_client = flask_db.get_db(current_app) - """ - flask_extension_key = f"{_FLASK_EXTENSION_KEY_PREFIX}{client_name}" - return app.extensions[flask_extension_key] - - -P = ParamSpec("P") -T = TypeVar("T") - - -def with_db_session( - *, client_name: str = _DEFAULT_CLIENT_NAME -) -> Callable[[Callable[Concatenate[db.Session, P], T]], Callable[P, T]]: - """Decorator for functions that need a database session. - - This decorator will create a new session object and pass it to the function - as the first positional argument. A transaction is not started automatically. - To start a transaction use db_session.begin() - - Usage: - @with_db_session() - def foo(db_session: db.Session): - ... - - @with_db_session() - def bar(db_session: db.Session, x, y): - ... - - @with_db_session(client_name="legacy_db") - def fiz(db_session: db.Session, x, y, z): - ... - """ - - def decorator(f: Callable[Concatenate[db.Session, P], T]) -> Callable[P, T]: - @wraps(f) - def wrapper(*args: Any, **kwargs: Any) -> T: - with get_db(current_app, client_name=client_name).get_session() as session: - return f(session, *args, **kwargs) - - return wrapper - - return decorator diff --git a/app/src/api/healthcheck.py b/app/src/api/healthcheck.py index 4dc5782b..60769670 100644 --- a/app/src/api/healthcheck.py +++ b/app/src/api/healthcheck.py @@ -1,33 +1,28 @@ import logging -from typing import Tuple -from apiflask import APIBlueprint -from flask import current_app +from fastapi import APIRouter, HTTPException, Request, status +from pydantic import BaseModel, Field from sqlalchemy import text -from werkzeug.exceptions import ServiceUnavailable -import src.adapters.db.flask_db as flask_db -from src.api import response -from src.api.schemas import request_schema +import src.adapters.db.fastapi_db as fastapi_db logger = logging.getLogger(__name__) +healthcheck_router = APIRouter(tags=["healthcheck"]) -class HealthcheckSchema(request_schema.OrderedSchema): - message: str +class HealthcheckModel(BaseModel): + message: str = Field(examples=["Service healthy"]) -healthcheck_blueprint = APIBlueprint("healthcheck", __name__, tag="Health") - -@healthcheck_blueprint.get("/health") -@healthcheck_blueprint.output(HealthcheckSchema) -@healthcheck_blueprint.doc(responses=[200, ServiceUnavailable.code]) -def health() -> Tuple[dict, int]: +@healthcheck_router.get("/health") +def health(request: Request) -> HealthcheckModel: try: - with flask_db.get_db(current_app).get_connection() as conn: + with fastapi_db.get_db_client(request.app).get_connection() as conn: assert conn.scalar(text("SELECT 1 AS healthy")) == 1 - return response.ApiResponse(message="Service healthy").asdict(), 200 - except Exception: + return HealthcheckModel(message="Service healthy") + except Exception as err: logger.exception("Connection to DB failure") - return response.ApiResponse(message="Service unavailable").asdict(), ServiceUnavailable.code + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service unavailable" + ) from err diff --git a/app/src/api/index.py b/app/src/api/index.py new file mode 100644 index 00000000..58e77429 --- /dev/null +++ b/app/src/api/index.py @@ -0,0 +1,20 @@ +from fastapi import APIRouter +from fastapi.responses import HTMLResponse + +# This route won't appear on the OpenAPI docs +index_router = APIRouter(include_in_schema=False) + + +@index_router.get("/") +def get_index() -> HTMLResponse: + content = """ + + + Home + +

Home

+

Visit /docs to view the api documentation for this project.

+ + + """ + return HTMLResponse(content=content) diff --git a/app/src/api/response.py b/app/src/api/response.py deleted file mode 100644 index 70e2ae8f..00000000 --- a/app/src/api/response.py +++ /dev/null @@ -1,48 +0,0 @@ -import dataclasses -from typing import Optional - -from src.api.schemas import response_schema -from src.db.models.base import Base - - -@dataclasses.dataclass -class ValidationErrorDetail: - type: str - message: str = "" - rule: Optional[str] = None - field: Optional[str] = None - value: Optional[str] = None # Do not store PII data here, as it gets logged in some cases - - -class ValidationException(Exception): - __slots__ = ["errors", "message", "data"] - - def __init__( - self, - errors: list[ValidationErrorDetail], - message: str = "Invalid request", - data: Optional[dict | list[dict]] = None, - ): - self.errors = errors - self.message = message - self.data = data or {} - - -@dataclasses.dataclass -class ApiResponse: - """Base response model for all API responses.""" - - message: str - data: Optional[Base] = None - warnings: list[ValidationErrorDetail] = dataclasses.field(default_factory=list) - errors: list[ValidationErrorDetail] = dataclasses.field(default_factory=list) - - # This method is used to convert ApiResponse objects to a dictionary - # This is necessary because APIFlask has a bug that causes an exception to be - # thrown when returning objects from routes when BASE_RESPONSE_SCHEMA is set - # (See https://github.com/apiflask/apiflask/issues/384) - # Once that issue is fixed, this method can be removed and routes can simply - # return ApiResponse objects directly and allow APIFlask to serealize the objects - # to JSON automatically. - def asdict(self) -> dict: - return response_schema.ResponseSchema().dump(self) diff --git a/app/src/api/schemas/request_schema.py b/app/src/api/schemas/request_schema.py deleted file mode 100644 index 08eca1d9..00000000 --- a/app/src/api/schemas/request_schema.py +++ /dev/null @@ -1,11 +0,0 @@ -import apiflask - - -# Use ordered schemas to ensure that the order of fields in the generated OpenAPI -# schema is deterministic. -# This should no longer be needed once apiflask is updated to use ordered schemas -# TODO update apiflask with fix to default to ordered schemas -# See https://github.com/apiflask/apiflask/issues/385#issuecomment-1364463104 -class OrderedSchema(apiflask.Schema): - class Meta: - ordered = True diff --git a/app/src/api/schemas/response_schema.py b/app/src/api/schemas/response_schema.py deleted file mode 100644 index 5d89e85e..00000000 --- a/app/src/api/schemas/response_schema.py +++ /dev/null @@ -1,19 +0,0 @@ -from apiflask import fields - -from src.api.schemas import request_schema - - -class ValidationErrorSchema(request_schema.OrderedSchema): - type = fields.String(metadata={"description": "The type of error"}) - message = fields.String(metadata={"description": "The message to return"}) - rule = fields.String(metadata={"description": "The rule that failed"}) - field = fields.String(metadata={"description": "The field that failed"}) - value = fields.String(metadata={"description": "The value that failed"}) - - -class ResponseSchema(request_schema.OrderedSchema): - message = fields.String(metadata={"description": "The message to return"}) - data = fields.Field(metadata={"description": "The REST resource object"}, dump_default={}) - status_code = fields.Integer(metadata={"description": "The HTTP status code"}, dump_default=200) - warnings = fields.List(fields.Nested(ValidationErrorSchema), dump_default=[]) - errors = fields.List(fields.Nested(ValidationErrorSchema), dump_default=[]) diff --git a/app/src/api/users/__init__.py b/app/src/api/users/__init__.py deleted file mode 100644 index 8ddfd549..00000000 --- a/app/src/api/users/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from src.api.users.user_blueprint import user_blueprint - -# import user_commands module to register the CLI commands on the user_blueprint -import src.api.users.user_commands # noqa: F401 E402 isort:skip - -# import user_commands module to register the API routes on the user_blueprint -import src.api.users.user_routes # noqa: F401 E402 isort:skip - - -__all__ = ["user_blueprint"] diff --git a/app/src/api/users/user_blueprint.py b/app/src/api/users/user_blueprint.py deleted file mode 100644 index 3155fe4e..00000000 --- a/app/src/api/users/user_blueprint.py +++ /dev/null @@ -1,3 +0,0 @@ -from apiflask import APIBlueprint - -user_blueprint = APIBlueprint("user", __name__, tag="User", cli_group="user") diff --git a/app/src/api/users/user_commands.py b/app/src/api/users/user_commands.py deleted file mode 100644 index 6e69481c..00000000 --- a/app/src/api/users/user_commands.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import os.path as path -from typing import Optional - -import click - -import src.adapters.db as db -import src.adapters.db.flask_db as flask_db -import src.services.users as user_service -from src.api.users.user_blueprint import user_blueprint -from src.util.datetime_util import utcnow - -logger = logging.getLogger(__name__) - -user_blueprint.cli.help = "User commands" - - -@user_blueprint.cli.command("create-csv", help="Create a CSV of all users and their roles") -@flask_db.with_db_session() -@click.option( - "--dir", - default=".", - help="Directory to save output file in. Can be an S3 path (e.g. 's3://bucketname/folder/') Defaults to current directory ('.').", -) -@click.option( - "--filename", - default=None, - help="Filename to save output file as. Defaults to '[timestamp]-user-roles.csv.", -) -def create_csv(db_session: db.Session, dir: str, filename: Optional[str]) -> None: - if filename is None: - filename = utcnow().strftime("%Y-%m-%d-%H-%M-%S") + "-user-roles.csv" - filepath = path.join(dir, filename) - user_service.create_user_csv(db_session, filepath) diff --git a/app/src/api/users/user_routes.py b/app/src/api/users/user_routes.py index 338ae313..7b61fe39 100644 --- a/app/src/api/users/user_routes.py +++ b/app/src/api/users/user_routes.py @@ -1,57 +1,57 @@ import logging +import typing +import uuid from typing import Any +import fastapi.exception_handlers +from fastapi import APIRouter + import src.adapters.db as db -import src.adapters.db.flask_db as flask_db -import src.api.response as response +import src.adapters.db.fastapi_db as fastapi_db import src.api.users.user_schemas as user_schemas import src.services.users as user_service -import src.services.users as users -from src.api.users.user_blueprint import user_blueprint -from src.auth.api_key_auth import api_key_auth +from src.auth.api_key_auth import verify_api_key from src.db.models.user_models import User logger = logging.getLogger(__name__) -@user_blueprint.post("/v1/users") -@user_blueprint.input(user_schemas.UserSchema) -@user_blueprint.output(user_schemas.UserSchema, status_code=201) -@user_blueprint.auth_required(api_key_auth) -@flask_db.with_db_session() -def user_post(db_session: db.Session, user_params: users.CreateUserParams) -> dict: +user_router = APIRouter(tags=["user"], dependencies=[fastapi.Depends(verify_api_key)]) + + +@user_router.post("/v1/users", status_code=201, response_model=user_schemas.UserModelOut) +def user_post( + db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())], + user_model: user_schemas.UserModel, +) -> User: """ POST /v1/users """ - user = user_service.create_user(db_session, user_params) + logger.info(user_model) + user = user_service.create_user(db_session, user_model) logger.info("Successfully inserted user", extra=get_user_log_params(user)) - return response.ApiResponse(message="Success", data=user).asdict() + return user -@user_blueprint.patch("/v1/users/") -# Allow partial updates. partial=true means requests that are missing -# required fields will not be rejected. -# https://marshmallow.readthedocs.io/en/stable/quickstart.html#partial-loading -@user_blueprint.input(user_schemas.UserSchema(partial=True)) -@user_blueprint.output(user_schemas.UserSchema) -@user_blueprint.auth_required(api_key_auth) -@flask_db.with_db_session() +@user_router.patch("/v1/users/{user_id}", response_model=user_schemas.UserModelOut) def user_patch( - db_session: db.Session, user_id: str, patch_user_params: users.PatchUserParams -) -> dict: + db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())], + user_id: uuid.UUID, + patch_user_params: user_schemas.UserModelPatch, +) -> User: user = user_service.patch_user(db_session, user_id, patch_user_params) logger.info("Successfully patched user", extra=get_user_log_params(user)) - return response.ApiResponse(message="Success", data=user).asdict() + return user -@user_blueprint.get("/v1/users/") -@user_blueprint.output(user_schemas.UserSchema) -@user_blueprint.auth_required(api_key_auth) -@flask_db.with_db_session() -def user_get(db_session: db.Session, user_id: str) -> dict: +@user_router.get("/v1/users/{user_id}", response_model=user_schemas.UserModelOut) +def user_get( + db_session: typing.Annotated[db.Session, fastapi.Depends(fastapi_db.DbSessionDependency())], + user_id: uuid.UUID, +) -> User: user = user_service.get_user(db_session, user_id) logger.info("Successfully fetched user", extra=get_user_log_params(user)) - return response.ApiResponse(message="Success", data=user).asdict() + return user def get_user_log_params(user: User) -> dict[str, Any]: diff --git a/app/src/api/users/user_schemas.py b/app/src/api/users/user_schemas.py index b0fd2c14..b3872f73 100644 --- a/app/src/api/users/user_schemas.py +++ b/app/src/api/users/user_schemas.py @@ -1,44 +1,57 @@ -from apiflask import fields -from marshmallow import fields as marshmallow_fields +from datetime import date, datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field -from src.api.schemas import request_schema from src.db.models import user_models -class RoleSchema(request_schema.OrderedSchema): - type = marshmallow_fields.Enum( - user_models.RoleType, - by_value=True, - metadata={"description": "The name of the role"}, - ) +class RoleModel(BaseModel): + type: user_models.RoleType - # Note that user_id is not included in the API schema since the role - # will always be a nested fields of the API user - - -class UserSchema(request_schema.OrderedSchema): - id = fields.UUID(dump_only=True) - first_name = fields.String(metadata={"description": "The user's first name"}, required=True) - middle_name = fields.String(metadata={"description": "The user's middle name"}) - last_name = fields.String(metadata={"description": "The user's last name"}, required=True) - phone_number = fields.String( - required=True, - metadata={ - "description": "The user's phone number", - "example": "123-456-7890", - "pattern": r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$", - }, - ) - date_of_birth = fields.Date( - metadata={"description": "The users date of birth"}, - required=True, + +class UserModel(BaseModel): + + first_name: str + middle_name: str | None = None + last_name: str + phone_number: str = Field( + pattern=r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$", examples=["123-456-7890"] ) - is_active = fields.Boolean( - metadata={"description": "Whether the user is active"}, - required=True, + date_of_birth: date + is_active: bool + roles: list[RoleModel] + + +UNSET: Any = None + + +class UserModelPatch(UserModel): + # TODO - this is unfortunately what we need to do for PATCH to work + # and is making me think PATCH should just not be used in favor of PUT endpoints + # which may also be simpler for front-ends as well. + # + # Ideas: + # - Just go with a PUT endpoint + # - Make this how the main model works, merge the two + # - Be fine with duplication? + # + first_name: str = UNSET + # middle_initial does not need to be redefined + last_name: str = UNSET + phone_number: str = Field( + pattern=r"^([0-9]|\*){3}\-([0-9]|\*){3}\-[0-9]{4}$", + examples=["123-456-7890"], + default=UNSET, ) - roles = fields.List(fields.Nested(RoleSchema), required=True) + date_of_birth: date = UNSET + is_active: bool = UNSET + roles: list[RoleModel] = UNSET + + +class UserModelOut(UserModel): + id: UUID - # Output only fields in addition to id field - created_at = fields.DateTime(dump_only=True) - updated_at = fields.DateTime(dump_only=True) + created_at: datetime + updated_at: datetime diff --git a/app/src/app.py b/app/src/app.py index 2ca137d3..bc66af1a 100644 --- a/app/src/app.py +++ b/app/src/app.py @@ -1,97 +1,108 @@ import logging -import os -from typing import Optional -from apiflask import APIFlask -from flask import g -from werkzeug.exceptions import Unauthorized +import fastapi +import fastapi.params +import uvicorn +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError, ResponseValidationError +from fastapi.responses import JSONResponse +from starlette_context.middleware import RawContextMiddleware import src.adapters.db as db -import src.adapters.db.flask_db as flask_db +import src.adapters.db.fastapi_db as fastapi_db import src.logger -import src.logger.flask_logger as flask_logger -from src.api.healthcheck import healthcheck_blueprint -from src.api.schemas import response_schema -from src.api.users import user_blueprint -from src.auth.api_key_auth import User, get_app_security_scheme +import src.logger.fastapi_logger as fastapi_logger +from src.api.healthcheck import healthcheck_router +from src.api.index import index_router +from src.api.users.user_routes import user_router +from src.app_config import AppConfig logger = logging.getLogger(__name__) -def create_app() -> APIFlask: - app = APIFlask(__name__) +def create_app() -> fastapi.FastAPI: + src.logger.init(__package__) -<<<<<<< Updated upstream - src.logging.init(__package__) -======= - root_logger = src.logging.init(__package__) ->>>>>>> Stashed changes - flask_logger.init_app(logging.root, app) + app = fastapi.FastAPI( + # Fields which appear on the OpenAPI docs + # title appears at the top of the OpenAPI page as well attached to all logs as "app.name" + title="Template Application FastAPI", + description="Template for a FastAPI Application", + contact={ + "name": "Nava PBC Engineering", + "url": "https://www.navapbc.com", + "email": "engineering@navapbc.com", + }, + # Global dependencies of every API endpoint + dependencies=get_global_dependencies(), + ) + + fastapi_logger.init_app(logging.root, app) + + configure_exception_handling(app) db_client = db.PostgresDBClient() - flask_db.register_db_client(db_client, app) + fastapi_db.register_db_client(db_client, app) + + register_routers(app) - configure_app(app) - register_blueprints(app) - register_index(app) + # Add this middleware last so that every other middleware has access to the context + # object we use to pass logging information around with. Middlewares work like a stack so + # the last one added is the first one executed. + app.add_middleware(RawContextMiddleware) + logger.info("Creating app") return app -def current_user(is_user_expected: bool = True) -> Optional[User]: - current = g.get("current_user") - if is_user_expected and current is None: - logger.error("No current user found for request") - raise Unauthorized - return current +def get_global_dependencies() -> list[fastapi.params.Depends]: + return [fastapi.Depends(fastapi_logger.add_url_rule_to_request_context)] -def configure_app(app: APIFlask) -> None: - # Modify the response schema to instead use the format of our ApiResponse class - # which adds additional details to the object. - # https://apiflask.com/schema/#base-response-schema-customization - app.config["BASE_RESPONSE_SCHEMA"] = response_schema.ResponseSchema +def register_routers(app: fastapi.FastAPI) -> None: + app.include_router(index_router) + app.include_router(healthcheck_router) + app.include_router(user_router) - # Set a few values for the Swagger endpoint - app.config["OPENAPI_VERSION"] = "3.0.3" - # Set various general OpenAPI config values - app.info = { - "title": "Template Application Flask", - "description": "Template API for a Flask Application", - "contact": { - "name": "Nava PBC Engineering", - "url": "https://www.navapbc.com", - "email": "engineering@navapbc.com", - }, - } +def configure_exception_handling(app: fastapi.FastAPI) -> None: - # Set the security schema and define the header param - # where we expect the API token to reside. - # See: https://apiflask.com/authentication/#use-external-authentication-library - app.security_schemes = get_app_security_scheme() + # Pydantic v2 by default returns a URL for each validation error pointing to their documentation + # which makes the errors overly verbose. We trim that out here before returning the error. + # See: https://fastapi.tiangolo.com/tutorial/handling-errors/?h=handlin#override-request-validation-exceptions + # + # If you would like to modify the format of the errors from Pydantic, you can restructure the response object here. + @app.exception_handler(RequestValidationError) + async def request_validation_exception_handler( + request: fastapi.Request, exc: RequestValidationError + ) -> JSONResponse: + # This matches the default request validation handler, but excludes the URL from the validation error. + return JSONResponse( + status_code=fastapi.status.HTTP_422_UNPROCESSABLE_ENTITY, + content={"detail": jsonable_encoder(exc.errors(), exclude={"url"})}, + ) + # TODO - do this a bit better, this is to prevent PII from being logged + # - should open a ticket against FastAPI as this looks like a relatively recent change: + # - https://github.com/tiangolo/fastapi/pull/10078 + def override_method(self: ResponseValidationError) -> str: + return f"{len(self._errors)} validation errors:\n" -def register_blueprints(app: APIFlask) -> None: - app.register_blueprint(healthcheck_blueprint) - app.register_blueprint(user_blueprint) + ResponseValidationError.__str__ = override_method # type: ignore -def get_project_root_dir() -> str: - return os.path.join(os.path.dirname(__file__), "..") +if __name__ == "__main__": + app_config = AppConfig() + uvicorn_params: dict = { + "host": app_config.host, + "port": app_config.port, + "factory": True, + "app_dir": "src/", + } + if app_config.use_reloader: + uvicorn_params["reload"] = True + else: + uvicorn_params["workers"] = 4 -def register_index(app: APIFlask) -> None: - @app.route("/") - @app.doc(hide=True) - def index() -> str: - return """ - - - Home - -

Home

-

Visit /docs to view the api documentation for this project.

- - - """ + uvicorn.run("app:create_app", **uvicorn_params) diff --git a/app/src/app_config.py b/app/src/app_config.py index aa780987..1c3a8339 100644 --- a/app/src/app_config.py +++ b/app/src/app_config.py @@ -6,6 +6,8 @@ class AppConfig(PydanticBaseEnvConfig): # from accessing the application. This is especially important if you are # running the application locally on a public network. This needs to be # overriden to 0.0.0.0 when running in a container - # See https://flask.palletsprojects.com/en/2.2.x/api/#flask.Flask.run host: str = "127.0.0.1" port: int = 8080 + + # TODO - figure out how we want this configured + use_reloader: bool = True diff --git a/app/src/auth/api_key_auth.py b/app/src/auth/api_key_auth.py index a1511f01..5806fac7 100644 --- a/app/src/auth/api_key_auth.py +++ b/app/src/auth/api_key_auth.py @@ -1,71 +1,15 @@ import logging import os -import uuid -from dataclasses import dataclass -from typing import Any -import flask -from apiflask import HTTPTokenAuth, abort +from fastapi import HTTPException, Security, status +from fastapi.security import APIKeyHeader logger = logging.getLogger(__name__) -# Initialize the authorization context -# this needs to be attached to your -# routes as `your_blueprint..auth_required(api_key_auth)` -# in order to enable authorization -api_key_auth = HTTPTokenAuth("ApiKey", header="X-Auth") +API_KEY_HEADER = APIKeyHeader(name="X-Auth") -def get_app_security_scheme() -> dict[str, Any]: - return {"ApiKeyAuth": {"type": "apiKey", "in": "header", "name": "X-Auth"}} - - -@dataclass -class User: - # WARNING: This is a very rudimentary - # user for demo purposes and is not - # a production ready approach. It exists - # purely to define a rough structure / example - id: uuid.UUID - sub_id: str - username: str - - def as_dict(self) -> dict[str, Any]: - # Connexion expects a dictionary it can - # use .get() on, so convert this to that format - return {"uid": self.id, "sub": self.sub_id} - - def get_user_log_attributes(self) -> dict: - # Note this gets called during authentication - # to attach the information to the flask global object - # which will in turn be attached to the log record - return {"current_user.id": str(self.id)} - - -API_AUTH_USER = User(uuid.uuid4(), "sub_id_1234", "API auth user") - - -@api_key_auth.verify_token -def verify_token(token: str) -> dict: - logger.info("Authenticating provided token") - - user = process_token(token) - - # Note that the current user can also be found - # by doing api_key_auth.current_user once in - # the request context. This is here in case - # multiple authentication approaches exist - # in your API, you don't need to check each - # one in order to figure out which was actually used - flask.g.current_user = user - flask.g.current_user_log_attributes = user.get_user_log_attributes() - - logger.info("Authentication successful") - - return user.as_dict() - - -def process_token(token: str) -> User: +def verify_api_key(token: str = Security(API_KEY_HEADER)) -> str: # WARNING: this isn't really a production ready # auth approach. In reality the user object returned # here should be pulled from a DB or auth service, but @@ -74,15 +18,15 @@ def process_token(token: str) -> User: expected_auth_token = os.getenv("API_AUTH_TOKEN", None) if not expected_auth_token: - logger.info( - "Authentication is not setup, please add an API_AUTH_TOKEN environment variable." + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication is not setup, please add an API_AUTH_TOKEN environment variable.", ) - abort(401, "Authentication is not setup properly and the user cannot be authenticated") if token != expected_auth_token: - logger.info("Authentication failed for provided auth token.") - abort( - 401, "The server could not verify that you are authorized to access the URL requested" + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="The server could not verify that you are authorized to access the URL requested", ) - return API_AUTH_USER + return token diff --git a/app/src/db/migrations/env.py b/app/src/db/migrations/env.py index 29a80a46..d7134304 100644 --- a/app/src/db/migrations/env.py +++ b/app/src/db/migrations/env.py @@ -15,17 +15,7 @@ logger = logging.getLogger("migrations") # Initialize logging -with src.logging.init("migrations"): -<<<<<<< Updated upstream -======= - - if not config.get_main_option("sqlalchemy.url"): - uri = make_connection_uri(get_db_config()) - - # Escape percentage signs in the URI. - # https://alembic.sqlalchemy.org/en/latest/api/config.html#alembic.config.Config.set_main_option - config.set_main_option("sqlalchemy.url", uri.replace("%", "%%")) ->>>>>>> Stashed changes +with src.logger.init("migrations"): # add your model's MetaData object here # for 'autogenerate' support diff --git a/app/src/fastapi_app.py b/app/src/fastapi_app.py deleted file mode 100644 index d6fe27b2..00000000 --- a/app/src/fastapi_app.py +++ /dev/null @@ -1,38 +0,0 @@ -import fastapi -import logging -import logger as our_logging -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - -def create_app() -> fastapi.FastAPI: - our_logging.init(__package__) - - logger.info("Creating app") - return fastapi.FastAPI() - -app = create_app() - -@app.get("/") -def root(): - logger.info("hello - in the root") - return {"message": "hello"} - -class NestedA(BaseModel): - x: str = Field(examples=["hello there again"]) - - -class NestedB(BaseModel): - y: list[NestedA] - -class Item(BaseModel): - name: str - description: str | None = None - price: float - tax: float | None = None - z: list[NestedB] - -@app.post("/items/") -def create_item(item: Item) -> Item: - logger.info(item) - return item \ No newline at end of file diff --git a/app/src/logger/__init__.py b/app/src/logger/__init__.py index 30f0078d..cda6e6aa 100644 --- a/app/src/logger/__init__.py +++ b/app/src/logger/__init__.py @@ -3,15 +3,15 @@ There are two formatters for the log messages: human-readable and JSON. The formatter that is used is determined by the environment variable LOG_FORMAT. If the environment variable is not set, the JSON formatter -is used by default. See src.logging.formatters for more information. +is used by default. See src.logger.formatters for more information. The logger also adds a PII mask filter to the root logger. See -src.logging.pii for more information. +src.logger.pii for more information. Usage: - import src.logging + import src.logger - with src.logging.init("program name"): + with src.logger.init("program name"): ... Once the module has been initialized, the standard logging module can be @@ -28,6 +28,6 @@ import src.logger.config as config -def init(program_name: str) -> config.LoggingContext: +def init(program_name: str) -> config.LoggingContext: return config.LoggingContext(program_name) diff --git a/app/src/logger/audit.py b/app/src/logger/audit.py index 49b29a21..0ae27728 100644 --- a/app/src/logger/audit.py +++ b/app/src/logger/audit.py @@ -72,7 +72,9 @@ def handle_audit_event(event_name: str, args: tuple[Any, ...]) -> None: def log_audit_event(event_name: str, args: Sequence[Any], arg_names: Sequence[str]) -> None: """Log a message but only log recently repeated messages at intervals.""" extra = { - f"audit.args.{arg_name}": arg for arg_name, arg in zip(arg_names, args) if arg_name != "_" + f"audit.args.{arg_name}": arg + for arg_name, arg in zip(arg_names, args, strict=True) + if arg_name != "_" } key = (event_name, repr(args)) diff --git a/app/src/logger/config.py b/app/src/logger/config.py index 587afd8d..7cc70687 100644 --- a/app/src/logger/config.py +++ b/app/src/logger/config.py @@ -5,6 +5,8 @@ import sys from typing import Any, ContextManager, cast +from pydantic_settings import SettingsConfigDict + import src.logger.audit import src.logger.formatters as formatters import src.logger.pii as pii @@ -23,11 +25,9 @@ class LoggingConfig(PydanticBaseEnvConfig): format: str = "json" level: str = "INFO" enable_audit: bool = False - human_readable_formatter: PydanticBaseEnvConfig = HumanReadableFormatterConfig() + human_readable_formatter: HumanReadableFormatterConfig = HumanReadableFormatterConfig() - class Config: - env_prefix = "log_" - env_nested_delimiter = "__" + model_config = SettingsConfigDict(env_prefix="log_", env_nested_delimiter="__") class LoggingContext(ContextManager[None]): @@ -97,7 +97,7 @@ def _configure_logging(self) -> None: logging.root.setLevel(config.level) if config.enable_audit: - src.logging.audit.init() + src.logger.audit.init() # Configure loggers for third party packages logging.getLogger("alembic").setLevel(logging.INFO) @@ -105,6 +105,12 @@ def _configure_logging(self) -> None: logging.getLogger("sqlalchemy.pool").setLevel(logging.INFO) logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) + # Uvicorn sets up its own handlers with different formatting + # To keep things consistent, we override their handler to be ours + # Alternatively we could configure all of this directly, but that is significantly more configuration. + logging.getLogger("uvicorn").handlers = [self.console_handler] + logging.getLogger("uvicorn.access").handlers = [self.console_handler] + def get_formatter(config: LoggingConfig) -> logging.Formatter: """Return the formatter used by the root logger. diff --git a/app/src/logger/fastapi_logger.py b/app/src/logger/fastapi_logger.py new file mode 100644 index 00000000..f41195d1 --- /dev/null +++ b/app/src/logger/fastapi_logger.py @@ -0,0 +1,161 @@ +"""Module for adding standard logging functionality to a FastAPI app. + +This module configures an application's logger to add extra data +to all log messages. FastAPI application context data such as the +app name and request context data such as the request method, request url +rule, and query parameters are added to the log record. + +This module also configures the FastAPI application to log every requests start and end. + +Usage: + import src.logger.fastapi_logger as fastapi_logger + + logger = logging.getLogger(__name__) + app = create_app() + + fastapi_logger.init_app(logger, app) +""" +import logging +import time +import typing + +import fastapi +import starlette_context + +logger = logging.getLogger(__name__) +EXTRA_LOG_DATA_ATTR = "extra_log_data" + + +def init_app(app_logger: logging.Logger, app: fastapi.FastAPI) -> None: + """Initialize the FastAPI app logger. + + Adds FastAPI app context data and FastAPI request context data + to every log record using log filters. + See https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information + + Also configures the app to log every non-404 request using the given logger. + + Usage: + import src.logger.fastapi_logger as fastapi_logger + + logger = logging.getLogger(__name__) + app = create_app() + + fastapi_logger.init_app(logger, app) + """ + + # Need to add filters to each of the handlers rather than to the logger itself, since + # messages are passed directly to the ancestor loggers’ handlers bypassing any filters + # set on the ancestors. + # See https://docs.python.org/3/library/logging.html#logging.Logger.propagate + for handler in app_logger.handlers: + handler.addFilter(_add_request_context_info_to_log_record) + + _add_logging_middleware(app) + + logger.info("Initialized Fast API logger") + + +def _add_logging_middleware(app: fastapi.FastAPI) -> None: + """ + Add middleware that runs before we start processing a request to add + additional context to log messages and automatically log the start/end + + IMPORTANT: These are defined in the reverse-order that they execute + as middleware work like a stack. + """ + + # Log the start/end of a request and include the timing. + @app.middleware("http") + async def log_start_and_end_of_request( + request: fastapi.Request, + call_next: typing.Callable[[fastapi.Request], typing.Awaitable[fastapi.Response]], + ) -> fastapi.Response: + request_start_time = time.perf_counter() + + logger.info("start request") + response = await call_next(request) + + logger.info( + "end request", + extra={ + "response.status_code": response.status_code, + "response.content_length": response.headers.get("content-length", None), + "response.content_type": response.headers.get("content-type", None), + "response.charset": response.charset, + "response.time_ms": (time.perf_counter() - request_start_time) * 1000, + }, + ) + return response + + # Add general information regarding the request (route, request ID, method) to all + # log messages for the lifecycle of the request. + @app.middleware("http") + async def attach_request_context_info( + request: fastapi.Request, + call_next: typing.Callable[[fastapi.Request], typing.Awaitable[fastapi.Response]], + ) -> fastapi.Response: + # This will be added to all log messages for the rest of the request lifecycle + data = { + "app.name": request.app.title, + "request.id": request.headers.get("x-amzn-requestid", ""), + "request.method": request.method, + "request.path": request.scope.get("path", ""), + # Starlette does not resolve the URL rule (ie. the specific route) until after + # middleware runs, so the url_rule cannot be added here, see: add_url_rule_to_request_context for where that happens + } + + # Add query parameter data in the format request.query. = + # For example, the query string ?foo=bar&baz=qux would be added as + # request.query.foo = bar and request.query.baz = qux + # PII should be kept out of the URL, as URLs are logged in access logs. + # With that assumption, it is safe to log query parameters. + for key, value in request.query_params.items(): + data[f"request.query.{key}"] = value + + add_extra_data_to_current_request_logs(data) + + return await call_next(request) + + +def add_url_rule_to_request_context(request: fastapi.Request) -> None: + """ + Starlette, the underlying routing library that FastAPI does not determine the + route that will handle a request until after all middleware run. This method instead + relies on being used as a dependency (eg. FastAPI(dependencies=[Depends(add_url_rule_to_request_context)]) + which will make it always run for every route, but after all of the middlewares. + + See: https://github.com/encode/starlette/issues/685 which describes the issue in Starlette + """ + url_rule = "" + + api_route = request.scope.get("route", None) + if api_route: + url_rule = api_route.path + + add_extra_data_to_current_request_logs({"request.url_rule": url_rule}) + + +def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool: + """Add request context data to the log record. + + If there is no request context, then do not add any data. + """ + if not starlette_context.context.exists(): + return True + + extra_log_data: dict[str, str] = starlette_context.context.get(EXTRA_LOG_DATA_ATTR, {}) + record.__dict__.update(extra_log_data) + + return True + + +def add_extra_data_to_current_request_logs( + data: dict[str, str | int | float | bool | None] +) -> None: + """Add data to every log record for the current request.""" + assert starlette_context.context.exists(), "Must be in a request context" + + extra_log_data = starlette_context.context.get(EXTRA_LOG_DATA_ATTR, {}) + extra_log_data.update(data) + starlette_context.context[EXTRA_LOG_DATA_ATTR] = extra_log_data diff --git a/app/src/logger/flask_logger.py b/app/src/logger/flask_logger.py deleted file mode 100644 index 6aefa833..00000000 --- a/app/src/logger/flask_logger.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Module for adding standard logging functionality to a Flask app. - -This module configures an application's logger to add extra data -to all log messages. Flask application context data such as the -app name and request context data such as the request method, request url -rule, and query parameters are added to the log record. - -This module also configures the Flask application to log every -non-404 request. - -Usage: - import src.logging.flask_logger as flask_logger - - logger = logging.getLogger(__name__) - app = create_app() - - flask_logger.init_app(logger, app) -""" -import logging -import time - -import flask - -logger = logging.getLogger(__name__) -EXTRA_LOG_DATA_ATTR = "extra_log_data" - - -def init_app(app_logger: logging.Logger, app: flask.Flask) -> None: - """Initialize the Flask app logger. - - Adds Flask app context data and Flask request context data - to every log record using log filters. - See https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information - - Also configures the app to log every non-404 request using the given logger. - - Usage: - import src.logging.flask_logger as flask_logger - - logger = logging.getLogger(__name__) - app = create_app() - - flask_logger.init_app(logger, app) - """ - - # Need to add filters to each of the handlers rather than to the logger itself, since - # messages are passed directly to the ancestor loggers’ handlers bypassing any filters - # set on the ancestors. - # See https://docs.python.org/3/library/logging.html#logging.Logger.propagate - for handler in app_logger.handlers: - handler.addFilter(_add_app_context_info_to_log_record) - handler.addFilter(_add_request_context_info_to_log_record) - - # Add request context data to every log record for the current request - # such as request id, request method, request path, and the matching Flask request url rule - app.before_request( - lambda: add_extra_data_to_current_request_logs(_get_request_context_info(flask.request)) - ) - - app.before_request(_track_request_start_time) - app.before_request(_log_start_request) - app.after_request(_log_end_request) - - app_logger.info("initialized flask logger") - - -def add_extra_data_to_current_request_logs( - data: dict[str, str | int | float | bool | None] -) -> None: - """Add data to every log record for the current request.""" - assert flask.has_request_context(), "Must be in a request context" - - extra_log_data = getattr(flask.g, EXTRA_LOG_DATA_ATTR, {}) - extra_log_data.update(data) - setattr(flask.g, EXTRA_LOG_DATA_ATTR, extra_log_data) - - -def _track_request_start_time() -> None: - """Store the request start time in flask.g""" - flask.g.request_start_time = time.perf_counter() - - -def _log_start_request() -> None: - """Log the start of a request. - - This function handles the Flask's before_request event. - See https://tedboy.github.io/flask/interface_src.application_object.html#flask.Flask.before_request - - Additional info about the request will be in the `extra` field - added by `_add_request_context_info_to_log_record` - """ - logger.info("start request") - - -def _log_end_request(response: flask.Response) -> flask.Response: - """Log the end of a request. - - This function handles the Flask's after_request event. - See https://tedboy.github.io/flask/interface_src.application_object.html#flask.Flask.after_request - - Additional info about the request will be in the `extra` field - added by `_add_request_context_info_to_log_record` - """ - - logger.info( - "end request", - extra={ - "response.status_code": response.status_code, - "response.content_length": response.content_length, - "response.content_type": response.content_type, - "response.mimetype": response.mimetype, - "response.time_ms": (time.perf_counter() - flask.g.request_start_time) * 1000, - }, - ) - return response - - -def _add_app_context_info_to_log_record(record: logging.LogRecord) -> bool: - """Add app context data to the log record. - - If there is no app context, then do not add any data. - """ - if not flask.has_app_context(): - return True - - assert flask.current_app is not None - record.__dict__ |= _get_app_context_info(flask.current_app) - - return True - - -def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool: - """Add request context data to the log record. - - If there is no request context, then do not add any data. - """ - if not flask.has_request_context(): - return True - - assert flask.request is not None - extra_log_data: dict[str, str] = getattr(flask.g, EXTRA_LOG_DATA_ATTR, {}) - record.__dict__.update(extra_log_data) - - return True - - -def _get_app_context_info(app: flask.Flask) -> dict: - return {"app.name": app.name} - - -def _get_request_context_info(request: flask.Request) -> dict: - data = { - "request.id": request.headers.get("x-amzn-requestid", ""), - "request.method": request.method, - "request.path": request.path, - "request.url_rule": str(request.url_rule), - } - - # Add query parameter data in the format request.query. = - # For example, the query string ?foo=bar&baz=qux would be added as - # request.query.foo = bar and request.query.baz = qux - # PII should be kept out of the URL, as URLs are logged in access logs. - # With that assumption, it is safe to log query parameters. - for key, value in request.args.items(): - data[f"request.query.{key}"] = value - return data diff --git a/app/src/logger/pii.py b/app/src/logger/pii.py index 7eb7fb14..e8c7ffe5 100644 --- a/app/src/logger/pii.py +++ b/app/src/logger/pii.py @@ -8,7 +8,7 @@ Example: import logging - import src.logging.pii as pii + import src.logger.pii as pii handler = logging.StreamHandler() handler.addFilter(pii.mask_pii) @@ -22,7 +22,7 @@ Example: import logging - import src.logging.pii as pii + import src.logger.pii as pii logger = logging.getLogger(__name__) logger.addFilter(pii.mask_pii) diff --git a/app/src/logging/__init__.py b/app/src/logging/__init__.py deleted file mode 100644 index 3d6fe9d2..00000000 --- a/app/src/logging/__init__.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Module for initializing logging configuration for the application. - -There are two formatters for the log messages: human-readable and JSON. -The formatter that is used is determined by the environment variable -LOG_FORMAT. If the environment variable is not set, the JSON formatter -is used by default. See src.logging.formatters for more information. - -The logger also adds a PII mask filter to the root logger. See -src.logging.pii for more information. - -Usage: - import src.logging - - with src.logging.init("program name"): - ... - -Once the module has been initialized, the standard logging module can be -used to log messages: - -Example: - import logging - - logger = logging.getLogger(__name__) - logger.info("message") -""" -from contextlib import contextmanager -import logging -import os -import platform -import pwd -import sys -from typing import Any, ContextManager, cast - -import src.logging.config as config - -logger = logging.getLogger(__name__) -_original_argv = tuple(sys.argv) - - -class Log: - def __init__(self, program_name: str) -> None: - self.program_name = program_name - self.root_logger, self.stream_handler = config.configure_logging() - log_program_info(self.program_name) - - def __enter__(self) -> logging.Logger: - return self.root_logger - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.root_logger.removeHandler(self.stream_handler) - -@contextmanager -def init(program_name: str): - stream_handler = config.configure_logging() - log_program_info(program_name) - yield - logging.root.removeHandler(stream_handler) - -def log_program_info(program_name: str) -> None: - logger.info( - "start %s: %s %s %s, hostname %s, pid %i, user %i(%s)", - program_name, - platform.python_implementation(), - platform.python_version(), - platform.system(), - platform.node(), - os.getpid(), - os.getuid(), - pwd.getpwuid(os.getuid()).pw_name, - extra={ - "hostname": platform.node(), - "cpu_count": os.cpu_count(), - # If mypy is run on a mac, it will throw a module has no attribute error, even though - # we never actually access it with the conditional. - # - # However, we can't just silence this error, because on linux (e.g. CI/CD) that will - # throw an unused “type: ignore” comment error. Casting to Any instead ensures this - # passes regardless of where mypy is being run - "cpu_usable": ( - len(cast(Any, os).sched_getaffinity(0)) - if "sched_getaffinity" in dir(os) - else "unknown" - ), - }, - ) - logger.info("invoked as: %s", " ".join(_original_argv)) diff --git a/app/src/logging/config.py b/app/src/logging/config.py deleted file mode 100644 index d79b26fc..00000000 --- a/app/src/logging/config.py +++ /dev/null @@ -1,84 +0,0 @@ -from contextlib import contextmanager -import logging -import sys -from typing import Generator, Tuple - -import src.logging.audit -import src.logging.formatters as formatters -import src.logging.pii as pii -from src.util.env_config import PydanticBaseEnvConfig - -logger = logging.getLogger(__name__) - -is_initialized = False - - -class HumanReadableFormatterConfig(PydanticBaseEnvConfig): - message_width: int = formatters.HUMAN_READABLE_FORMATTER_DEFAULT_MESSAGE_WIDTH - - -class LoggingConfig(PydanticBaseEnvConfig): - format = "json" - level = "INFO" - enable_audit = True - human_readable_formatter = HumanReadableFormatterConfig() - - class Config: - env_prefix = "log_" - env_nested_delimiter = "__" - -@contextmanager -def configure_logging() -> logging.Handler: - """Configure logging for the application. - - Configures the root module logger to log to stdout. - Adds a PII mask filter to the root logger. - Also configures log levels third party packages. - """ - config = LoggingConfig() - - # Loggers can be configured using config functions defined - # in logging.config or by directly making calls to the main API - # of the logging module (see https://docs.python.org/3/library/logging.config.html) - # We opt to use the main API using functions like `addHandler` which is - # non-destructive, i.e. it does not overwrite any existing handlers. - # In contrast, logging.config.dictConfig() would overwrite any existing loggers. - # This is important during testing, since fixtures like `caplog` add handlers that would - # get overwritten if we call logging.config.dictConfig() during the scope of the test. - console_handler = logging.StreamHandler(sys.stdout) - formatter = get_formatter(config) - console_handler.setFormatter(formatter) - console_handler.addFilter(pii.mask_pii) - logging.root.removeHandler(console_handler) - logging.root.addHandler(console_handler) - logging.root.setLevel(config.level) - - if config.enable_audit: - src.logging.audit.init() - - # Configure loggers for third party packages - logging.getLogger("alembic").setLevel(logging.INFO) - logging.getLogger("werkzeug").setLevel(logging.WARN) - logging.getLogger("sqlalchemy.pool").setLevel(logging.INFO) - logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) - - return console_handler - - - -def get_formatter(config: LoggingConfig) -> logging.Formatter: - """Return the formatter used by the root logger. - - The formatter is determined by the environment variable LOG_FORMAT. If the - environment variable is not set, the JSON formatter is used by default. - """ - if config.format == "human-readable": - return get_human_readable_formatter(config.human_readable_formatter) - return formatters.JsonFormatter() - - -def get_human_readable_formatter( - config: HumanReadableFormatterConfig, -) -> formatters.HumanReadableFormatter: - """Return the human readable formatter used by the root logger.""" - return formatters.HumanReadableFormatter(message_width=config.message_width) diff --git a/app/src/services/users/__init__.py b/app/src/services/users/__init__.py index 9d01b20f..9a001819 100644 --- a/app/src/services/users/__init__.py +++ b/app/src/services/users/__init__.py @@ -1,12 +1,9 @@ -from .create_user import CreateUserParams, RoleParams, create_user +from .create_user import create_user from .create_user_csv import create_user_csv from .get_user import get_user -from .patch_user import PatchUserParams, patch_user +from .patch_user import patch_user __all__ = [ - "CreateUserParams", - "PatchUserParams", - "RoleParams", "create_user", "get_user", "patch_user", diff --git a/app/src/services/users/create_user.py b/app/src/services/users/create_user.py index 9721594a..53050ac3 100644 --- a/app/src/services/users/create_user.py +++ b/app/src/services/users/create_user.py @@ -1,40 +1,21 @@ -from datetime import date -from typing import TypedDict - from src.adapters.db import Session -from src.db.models import user_models +from src.api.users.user_schemas import UserModel from src.db.models.user_models import Role, User -class RoleParams(TypedDict): - type: user_models.RoleType - - -class CreateUserParams(TypedDict): - first_name: str - middle_name: str - last_name: str - phone_number: str - date_of_birth: date - is_active: bool - roles: list[RoleParams] - - # TODO: separate controller and service concerns # https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251 -# TODO: Use classes / objects as inputs to service methods -# https://github.com/navapbc/template-application-flask/issues/52 -def create_user(db_session: Session, user_params: CreateUserParams) -> User: +def create_user(db_session: Session, user_model: UserModel) -> User: with db_session.begin(): # TODO: move this code to service and/or persistence layer user = User( - first_name=user_params["first_name"], - middle_name=user_params["middle_name"], - last_name=user_params["last_name"], - phone_number=user_params["phone_number"], - date_of_birth=user_params["date_of_birth"], - is_active=user_params["is_active"], - roles=[Role(type=role["type"]) for role in user_params["roles"]], + first_name=user_model.first_name, + middle_name=user_model.middle_name, + last_name=user_model.last_name, + phone_number=user_model.phone_number, + date_of_birth=user_model.date_of_birth, + is_active=user_model.is_active, + roles=[Role(type=role.type) for role in user_model.roles], ) db_session.add(user) return user diff --git a/app/src/services/users/get_user.py b/app/src/services/users/get_user.py index df4138f5..c6e17e60 100644 --- a/app/src/services/users/get_user.py +++ b/app/src/services/users/get_user.py @@ -1,4 +1,6 @@ -import apiflask +import uuid + +from fastapi import HTTPException from sqlalchemy import orm from src.adapters.db import Session @@ -7,15 +9,13 @@ # TODO: separate controller and service concerns # https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251 -# TODO: Use classes / objects as inputs to service methods -# https://github.com/navapbc/template-application-flask/issues/52 -def get_user(db_session: Session, user_id: str) -> User: +def get_user(db_session: Session, user_id: uuid.UUID) -> User: # TODO: move this to service and/or persistence layer result = db_session.get(User, user_id, options=[orm.selectinload(User.roles)]) if result is None: # TODO move HTTP related logic out of service layer to controller layer and just return None from here # https://github.com/navapbc/template-application-flask/pull/51#discussion_r1053754975 - raise apiflask.HTTPError(404, message=f"Could not find user with ID {user_id}") + raise HTTPException(status_code=404, detail=f"Could not find user with ID {user_id}") return result diff --git a/app/src/services/users/patch_user.py b/app/src/services/users/patch_user.py index 3a21c039..06026c12 100644 --- a/app/src/services/users/patch_user.py +++ b/app/src/services/users/patch_user.py @@ -1,34 +1,21 @@ import bisect -from datetime import date +import uuid from operator import attrgetter -from typing import TypedDict -import apiflask +from fastapi import HTTPException from sqlalchemy import orm from src.adapters.db import Session +from src.api.users.user_schemas import RoleModel, UserModelPatch from src.db.models.user_models import Role, User -from src.services.users.create_user import RoleParams - - -class PatchUserParams(TypedDict, total=False): - first_name: str - middle_name: str - last_name: str - phone_number: str - date_of_birth: date - is_active: bool - roles: list[RoleParams] # TODO: separate controller and service concerns # https://github.com/navapbc/template-application-flask/issues/49#issue-1505008251 -# TODO: Use classes / objects as inputs to service methods -# https://github.com/navapbc/template-application-flask/issues/52 def patch_user( db_session: Session, - user_id: str, - patch_user_params: PatchUserParams, + user_id: uuid.UUID, + patch_user_params: UserModelPatch, ) -> User: with db_session.begin(): @@ -38,21 +25,21 @@ def patch_user( if user is None: # TODO move HTTP related logic out of service layer to controller layer and just return None from here # https://github.com/navapbc/template-application-flask/pull/51#discussion_r1053754975 - raise apiflask.HTTPError(404, message=f"Could not find user with ID {user_id}") + raise HTTPException(status_code=404, detail=f"Could not find user with ID {user_id}") - for key, value in patch_user_params.items(): + for key, value in patch_user_params.model_dump(exclude_unset=True).items(): if key == "roles": - _handle_role_patch(db_session, user, patch_user_params["roles"]) + _handle_role_patch(db_session, user, patch_user_params.roles) continue setattr(user, key, value) return user -def _handle_role_patch(db_session: Session, user: User, request_roles: list[RoleParams]) -> None: +def _handle_role_patch(db_session: Session, user: User, request_roles: list[RoleModel]) -> None: current_role_types = set([role.type for role in user.roles]) - request_role_types = set([role["type"] for role in request_roles]) + request_role_types = set([role.type for role in request_roles]) roles_to_delete = [role for role in user.roles if role.type not in request_role_types] diff --git a/app/src/util/env_config.py b/app/src/util/env_config.py index 40ea8a45..56017bd7 100644 --- a/app/src/util/env_config.py +++ b/app/src/util/env_config.py @@ -1,6 +1,6 @@ import os -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict import src @@ -12,5 +12,4 @@ class PydanticBaseEnvConfig(BaseSettings): - class Config: - env_file = env_file + model_config = SettingsConfigDict(env_file=env_file) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index e6b1c198..6c4d2cf8 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -2,10 +2,10 @@ import _pytest.monkeypatch import boto3 -import flask -import flask.testing +import fastapi import moto import pytest +from fastapi.testclient import TestClient import src.adapters.db as db import src.app as app_entry @@ -115,18 +115,13 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session: # Make app session scoped so the database connection pool is only created once # for the test session. This speeds up the tests. @pytest.fixture(scope="session") -def app(db_client) -> flask.Flask: +def app(db_client) -> fastapi.FastAPI: return app_entry.create_app() @pytest.fixture -def client(app: flask.Flask) -> flask.testing.FlaskClient: - return app.test_client() - - -@pytest.fixture -def cli_runner(app: flask.Flask) -> flask.testing.CliRunner: - return app.test_cli_runner() +def client(app: fastapi.FastAPI) -> TestClient: + return TestClient(app) @pytest.fixture diff --git a/app/tests/src/adapters/db/test_fastapi_db.py b/app/tests/src/adapters/db/test_fastapi_db.py new file mode 100644 index 00000000..32dab0bb --- /dev/null +++ b/app/tests/src/adapters/db/test_fastapi_db.py @@ -0,0 +1,56 @@ +import typing + +import pytest +from fastapi import Depends, FastAPI, Request +from fastapi.testclient import TestClient +from sqlalchemy import text + +import src.adapters.db as db +import src.adapters.db.fastapi_db as fastapi_db + + +# Define an isolated example FastAPI app fixture specific to this test module +# to avoid dependencies on any project-specific fixtures in conftest.py +@pytest.fixture +def example_app() -> FastAPI: + app = FastAPI() + db_client = db.PostgresDBClient() + fastapi_db.register_db_client(db_client, app) + return app + + +def test_get_db(example_app: FastAPI): + @example_app.get("/hello") + def hello(request: Request) -> dict: + with fastapi_db.get_db_client(request.app).get_connection() as conn: + return {"data": conn.scalar(text("SELECT 'hello, world'"))} + + response = TestClient(example_app).get("/hello") + assert response.json() == {"data": "hello, world"} + + +def test_with_db_session_depends(example_app: FastAPI): + @example_app.get("/hello") + def hello(db_session: typing.Annotated[db.Session, Depends(fastapi_db.DbSessionDependency())]): + with db_session.begin(): + return {"data": db_session.scalar(text("SELECT 'hello, world'"))} + + response = TestClient(example_app).get("/hello") + assert response.json() == {"data": "hello, world"} + + +def test_with_db_session_depends_not_default_name(example_app: FastAPI): + db_client = db.PostgresDBClient() + fastapi_db.register_db_client(db_client, example_app, client_name="something_else") + + @example_app.get("/hello") + def hello( + db_session: typing.Annotated[ + db.Session, Depends(fastapi_db.DbSessionDependency(client_name="something_else")) + ] + ): + with db_session.begin(): + return {"data": db_session.scalar(text("SELECT 'hello, world'"))} + + response = TestClient(example_app).get("/hello") + assert response.json() == {"data": "hello, world"} diff --git a/app/tests/src/adapters/db/test_flask_db.py b/app/tests/src/adapters/db/test_flask_db.py deleted file mode 100644 index c95ee12d..00000000 --- a/app/tests/src/adapters/db/test_flask_db.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from flask import Flask, current_app -from sqlalchemy import text - -import src.adapters.db as db -import src.adapters.db.flask_db as flask_db - - -# Define an isolated example Flask app fixture specific to this test module -# to avoid dependencies on any project-specific fixtures in conftest.py -@pytest.fixture -def example_app() -> Flask: - app = Flask(__name__) - db_client = db.PostgresDBClient() - flask_db.register_db_client(db_client, app) - return app - - -def test_get_db(example_app: Flask): - @example_app.route("/hello") - def hello(): - with flask_db.get_db(current_app).get_connection() as conn: - return {"data": conn.scalar(text("SELECT 'hello, world'"))} - - response = example_app.test_client().get("/hello") - assert response.get_json() == {"data": "hello, world"} - - -def test_with_db_session(example_app: Flask): - @example_app.route("/hello") - @flask_db.with_db_session() - def hello(db_session: db.Session): - with db_session.begin(): - return {"data": db_session.scalar(text("SELECT 'hello, world'"))} - - response = example_app.test_client().get("/hello") - assert response.get_json() == {"data": "hello, world"} - - -def test_with_db_session_not_default_name(example_app: Flask): - db_client = db.PostgresDBClient() - flask_db.register_db_client(db_client, example_app, client_name="something_else") - - @example_app.route("/hello") - @flask_db.with_db_session(client_name="something_else") - def hello(db_session: db.Session): - with db_session.begin(): - return {"data": db_session.scalar(text("SELECT 'hello, world'"))} - - response = example_app.test_client().get("/hello") - assert response.get_json() == {"data": "hello, world"} diff --git a/app/tests/src/auth/test_api_key_auth.py b/app/tests/src/auth/test_api_key_auth.py index 152cc63f..cf28bf5f 100644 --- a/app/tests/src/auth/test_api_key_auth.py +++ b/app/tests/src/auth/test_api_key_auth.py @@ -1,29 +1,19 @@ import pytest -from apiflask import HTTPError -from flask import g +from fastapi import HTTPException -from src.auth.api_key_auth import API_AUTH_USER, verify_token +from src.auth.api_key_auth import verify_api_key -def test_verify_token_success(app, api_auth_token): - # Passing it the configured auth token successfully returns a user - with app.app_context(): # So we can attach the user to the flask app - user_map = verify_token(api_auth_token) - - assert user_map.get("uid") == API_AUTH_USER.id - assert g.get("current_user") == API_AUTH_USER - - -def test_verify_token_invalid_token(api_auth_token): +def test_verify_api_key_invalid_token(api_auth_token): # If you pass it the wrong token - with pytest.raises(HTTPError): - verify_token("not the right token") + with pytest.raises(HTTPException): + verify_api_key("not the right token") -def test_verify_token_no_configuration(monkeypatch): +def test_verify_api_key_no_configuration(monkeypatch): # Remove the API_AUTH_TOKEN env var if set in # your local environment monkeypatch.delenv("API_AUTH_TOKEN", raising=False) # If the auth token is not setup - with pytest.raises(HTTPError): - verify_token("any token") + with pytest.raises(HTTPException): + verify_api_key("any token") diff --git a/app/src/api/schemas/__init__.py b/app/tests/src/logger/__init__.py similarity index 100% rename from app/src/api/schemas/__init__.py rename to app/tests/src/logger/__init__.py diff --git a/app/tests/src/logging/test_audit.py b/app/tests/src/logger/test_audit.py similarity index 98% rename from app/tests/src/logging/test_audit.py rename to app/tests/src/logger/test_audit.py index b361f54b..cb4aa944 100644 --- a/app/tests/src/logging/test_audit.py +++ b/app/tests/src/logger/test_audit.py @@ -1,5 +1,5 @@ # -# Tests for src.logging.audit. +# Tests for src.logger.audit. # import logging @@ -141,7 +141,7 @@ def test_audit_hook( pass assert len(caplog.records) == len(expected_records) - for record, expected_record in zip(caplog.records, expected_records): + for record, expected_record in zip(caplog.records, expected_records, strict=True): assert record.levelname == "AUDIT" assert_record_match(record, expected_record) @@ -161,7 +161,7 @@ def test_os_kill(init_audit_hook, caplog: pytest.LogCaptureFixture): ] assert len(caplog.records) == len(expected_records) - for record, expected_record in zip(caplog.records, expected_records): + for record, expected_record in zip(caplog.records, expected_records, strict=True): assert record.levelname == "AUDIT" assert_record_match(record, expected_record) @@ -232,7 +232,7 @@ def test_repeated_audit_logs( ] assert len(caplog.records) == len(expected_records) - for record, expected_record in zip(caplog.records, expected_records): + for record, expected_record in zip(caplog.records, expected_records, strict=True): assert record.levelname == "AUDIT" assert_record_match(record, expected_record) diff --git a/app/tests/src/logging/test_flask_logger.py b/app/tests/src/logger/test_fastapi_logger.py similarity index 57% rename from app/tests/src/logging/test_flask_logger.py rename to app/tests/src/logger/test_fastapi_logger.py index 67315c80..4b6a803e 100644 --- a/app/tests/src/logging/test_flask_logger.py +++ b/app/tests/src/logger/test_fastapi_logger.py @@ -3,14 +3,21 @@ import time import pytest -from flask import Flask +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette_context.middleware import RawContextMiddleware -import src.logger.flask_logger as flask_logger +import src.logger.fastapi_logger as fastapi_logger +from src.app import get_global_dependencies from tests.lib.assertions import assert_dict_contains @pytest.fixture def logger(): + # The test client used in tests calls the httpx library + # which does its own logging which add noise we don't want. + logging.getLogger("httpx").setLevel(logging.ERROR) + logger = logging.getLogger("src") before_level = logger.level @@ -24,14 +31,16 @@ def logger(): @pytest.fixture def app(logger): - app = Flask("test_app_name") + app = FastAPI(title="test_app_name", dependencies=get_global_dependencies()) - @app.get("/hello/") - def hello(name): + @app.get("/hello/{name}") + def hello(name: str) -> str: logging.getLogger("src.hello").info(f"hello, {name}!") return "ok" - flask_logger.init_app(logger, app) + fastapi_logger.init_app(logger, app) + app.add_middleware(RawContextMiddleware) + return app @@ -43,10 +52,11 @@ def hello(name): {"msg": "hello, jane!"}, { "msg": "end request", + "app.name": "test_app_name", "response.status_code": 200, - "response.content_length": 2, - "response.content_type": "text/html; charset=utf-8", - "response.mimetype": "text/html", + "response.content_length": "4", + "response.content_type": "application/json", + "response.charset": "utf-8", }, ], id="200", @@ -57,10 +67,11 @@ def hello(name): {"msg": "start request"}, { "msg": "end request", + "app.name": "test_app_name", "response.status_code": 404, - "response.content_length": 207, - "response.content_type": "text/html; charset=utf-8", - "response.mimetype": "text/html", + "response.content_length": "22", + "response.content_type": "application/json", + "response.charset": "utf-8", }, ], id="404", @@ -73,69 +84,63 @@ def hello(name): test_request_lifecycle_logs_data, ) def test_request_lifecycle_logs( - app: Flask, caplog: pytest.LogCaptureFixture, route, expected_extras + app: FastAPI, caplog: pytest.LogCaptureFixture, route, expected_extras ): - app.test_client().get(route) + TestClient(app).get(route) # Assert that the log messages are present # There should be the route log message that is logged in the before_request handler # as part of every request, followed by the log message in the route handler itself. # then the log message in the after_request handler. - assert len(caplog.records) == len(expected_extras) - for record, expected_extra in zip(caplog.records, expected_extras): - assert_dict_contains(record.__dict__, expected_extra) - - -def test_app_context_extra_attributes(app: Flask, caplog: pytest.LogCaptureFixture): - # Assert that extra attributes related to the app context are present in all log records - expected_extra = {"app.name": "test_app_name"} - - app.test_client().get("/hello/jane") - - assert len(caplog.records) > 0 - for record in caplog.records: + assert len(caplog.records) == len(expected_extras), "|".join( + [record.__dict__["msg"] for record in caplog.records] + ) + for record, expected_extra in zip(caplog.records, expected_extras, strict=True): assert_dict_contains(record.__dict__, expected_extra) -def test_request_context_extra_attributes(app: Flask, caplog: pytest.LogCaptureFixture): +def test_request_context_extra_attributes(app: FastAPI, caplog: pytest.LogCaptureFixture): # Assert that the extra attributes related to the request context are present in all log records expected_extra = { "request.id": "", "request.method": "GET", "request.path": "/hello/jane", - "request.url_rule": "/hello/", + # "request.url_rule": "/hello/", "request.query.up": "high", "request.query.down": "low", } - app.test_client().get("/hello/jane?up=high&down=low") + TestClient(app).get("/hello/jane?up=high&down=low") assert len(caplog.records) > 0 for record in caplog.records: assert_dict_contains(record.__dict__, expected_extra) + # After the first log (the start request log), we'll have the url_rule + expected_extra["request.url_rule"] = "/hello/{name}" + -def test_add_extra_log_data_for_current_request(app: Flask, caplog: pytest.LogCaptureFixture): - @app.get("/pet/") - def pet(name): - flask_logger.add_extra_data_to_current_request_logs({"pet.name": name}) +def test_add_extra_log_data_for_current_request(app: FastAPI, caplog: pytest.LogCaptureFixture): + @app.get("/pet/{name}") + def pet(name: str) -> str: + fastapi_logger.add_extra_data_to_current_request_logs({"pet.name": name}) logging.getLogger("test.pet").info(f"petting {name}") return "ok" - app.test_client().get("/pet/kitty") + TestClient(app).get("/pet/kitty") last_record = caplog.records[-1] assert_dict_contains(last_record.__dict__, {"pet.name": "kitty"}) -def test_log_response_time(app: Flask, caplog: pytest.LogCaptureFixture): +def test_log_response_time(app: FastAPI, caplog: pytest.LogCaptureFixture): @app.get("/sleep") def sleep(): time.sleep(0.1) # 0.1 s = 100 ms return "ok" - app.test_client().get("/sleep") + TestClient(app).get("/sleep") last_record = caplog.records[-1] assert "response.time_ms" in last_record.__dict__ diff --git a/app/tests/src/logging/test_formatters.py b/app/tests/src/logger/test_formatters.py similarity index 100% rename from app/tests/src/logging/test_formatters.py rename to app/tests/src/logger/test_formatters.py diff --git a/app/tests/src/logging/test_logging.py b/app/tests/src/logger/test_logging.py similarity index 97% rename from app/tests/src/logging/test_logging.py rename to app/tests/src/logger/test_logging.py index 0a1b530c..afad9cc5 100644 --- a/app/tests/src/logging/test_logging.py +++ b/app/tests/src/logger/test_logging.py @@ -12,7 +12,7 @@ def init_test_logger(caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch): caplog.set_level(logging.DEBUG) monkeypatch.setenv("LOG_FORMAT", "human-readable") - with src.logging.init("test_logging"): + with src.logger.init("test_logging"): yield @@ -27,7 +27,7 @@ def test_init(caplog: pytest.LogCaptureFixture, monkeypatch, log_format, expecte caplog.set_level(logging.DEBUG) monkeypatch.setenv("LOG_FORMAT", log_format) - with src.logging.init("test_logging"): + with src.logger.init("test_logging"): records = caplog.records assert len(records) == 2 diff --git a/app/tests/src/logging/test_pii.py b/app/tests/src/logger/test_pii.py similarity index 100% rename from app/tests/src/logging/test_pii.py rename to app/tests/src/logger/test_pii.py diff --git a/app/tests/src/logging/__init__.py b/app/tests/src/logging/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/tests/src/route/test_user_route.py b/app/tests/src/route/test_user_route.py index f8b38efd..3eeda876 100644 --- a/app/tests/src/route/test_user_route.py +++ b/app/tests/src/route/test_user_route.py @@ -28,7 +28,7 @@ def base_request(): @pytest.fixture def created_user(client, api_auth_token, base_request): response = client.post("/v1/users", json=base_request, headers={"X-Auth": api_auth_token}) - return response.get_json()["data"] + return response.json() test_create_and_get_user_data = [ @@ -45,7 +45,7 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles): "roles": roles, } post_response = client.post("/v1/users", json=request, headers={"X-Auth": api_auth_token}) - post_response_data = post_response.get_json()["data"] + post_response_data = post_response.json() expected_response = { **request, "id": post_response_data["id"], @@ -59,26 +59,41 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles): assert post_response_data["updated_at"] is not None # Get the user - user_id = post_response.get_json()["data"]["id"] + user_id = post_response.json()["id"] get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token}) assert get_response.status_code == 200 - get_response_data = get_response.get_json()["data"] + get_response_data = get_response.json() assert get_response_data == expected_response test_create_user_bad_request_data = [ pytest.param( {}, - { - "first_name": ["Missing data for required field."], - "last_name": ["Missing data for required field."], - "phone_number": ["Missing data for required field."], - "date_of_birth": ["Missing data for required field."], - "is_active": ["Missing data for required field."], - "roles": ["Missing data for required field."], - }, + [ + { + "input": {}, + "loc": ["body", "first_name"], + "msg": "Field required", + "type": "missing", + }, + {"input": {}, "loc": ["body", "last_name"], "msg": "Field required", "type": "missing"}, + { + "input": {}, + "loc": ["body", "phone_number"], + "msg": "Field required", + "type": "missing", + }, + { + "input": {}, + "loc": ["body", "date_of_birth"], + "msg": "Field required", + "type": "missing", + }, + {"input": {}, "loc": ["body", "is_active"], "msg": "Field required", "type": "missing"}, + {"input": {}, "loc": ["body", "roles"], "msg": "Field required", "type": "missing"}, + ], id="missing all required fields", ), pytest.param( @@ -91,25 +106,70 @@ def test_create_and_get_user(client, base_request, api_auth_token, roles): "is_active": 6, "roles": 7, }, - { - "first_name": ["Not a valid string."], - "middle_name": ["Not a valid string."], - "last_name": ["Not a valid string."], - "phone_number": ["Not a valid string."], - "date_of_birth": ["Not a valid date."], - "is_active": ["Not a valid boolean."], - "roles": ["Not a valid list."], - }, + [ + { + "input": 1, + "loc": ["body", "first_name"], + "msg": "Input should be a valid string", + "type": "string_type", + }, + { + "input": 2, + "loc": ["body", "middle_name"], + "msg": "Input should be a valid string", + "type": "string_type", + }, + { + "input": 3, + "loc": ["body", "last_name"], + "msg": "Input should be a valid string", + "type": "string_type", + }, + { + "input": 5, + "loc": ["body", "phone_number"], + "msg": "Input should be a valid string", + "type": "string_type", + }, + { + "input": 4, + "loc": ["body", "date_of_birth"], + "msg": "Datetimes provided to dates should have zero time - e.g. be exact " "dates", + "type": "date_from_datetime_inexact", + }, + { + "input": 6, + "loc": ["body", "is_active"], + "msg": "Input should be a valid boolean, unable to interpret input", + "type": "bool_parsing", + }, + { + "input": 7, + "loc": ["body", "roles"], + "msg": "Input should be a valid list", + "type": "list_type", + }, + ], id="invalid types", ), pytest.param( get_base_request() | {"roles": [{"type": "Mime"}, {"type": "Clown"}]}, - { - "roles": { - "0": {"type": ["Must be one of: USER, ADMIN."]}, - "1": {"type": ["Must be one of: USER, ADMIN."]}, - } - }, + [ + { + "ctx": {"expected": "'USER' or 'ADMIN'"}, + "input": "Mime", + "loc": ["body", "roles", 0, "type"], + "msg": "Input should be 'USER' or 'ADMIN'", + "type": "enum", + }, + { + "ctx": {"expected": "'USER' or 'ADMIN'"}, + "input": "Clown", + "loc": ["body", "roles", 1, "type"], + "msg": "Input should be 'USER' or 'ADMIN'", + "type": "enum", + }, + ], id="invalid role type", ), ] @@ -120,7 +180,7 @@ def test_create_user_bad_request(client, api_auth_token, request_data, expected_ response = client.post("/v1/users", json=request_data, headers={"X-Auth": api_auth_token}) assert response.status_code == 422 - response_data = response.get_json()["detail"]["json"] + response_data = response.json()["detail"] assert response_data == expected_response_data @@ -130,7 +190,7 @@ def test_patch_user(client, api_auth_token, created_user): patch_response = client.patch( f"/v1/users/{user_id}", json=patch_request, headers={"X-Auth": api_auth_token} ) - patch_response_data = patch_response.get_json()["data"] + patch_response_data = patch_response.json() expected_response_data = { **created_user, **patch_request, @@ -141,7 +201,7 @@ def test_patch_user(client, api_auth_token, created_user): assert patch_response_data == expected_response_data get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token}) - get_response_data = get_response.get_json()["data"] + get_response_data = get_response.json() assert get_response_data == expected_response_data @@ -155,14 +215,14 @@ def test_patch_user_roles(client, base_request, api_auth_token, initial_roles, u } created_user = client.post( "/v1/users", json=post_request, headers={"X-Auth": api_auth_token} - ).get_json()["data"] + ).json() user_id = created_user["id"] patch_request = {"roles": updated_roles} patch_response = client.patch( f"/v1/users/{user_id}", json=patch_request, headers={"X-Auth": api_auth_token} ) - patch_response_data = patch_response.get_json()["data"] + patch_response_data = patch_response.json() expected_response_data = { **created_user, **patch_request, @@ -173,7 +233,7 @@ def test_patch_user_roles(client, base_request, api_auth_token, initial_roles, u assert patch_response_data == expected_response_data get_response = client.get(f"/v1/users/{user_id}", headers={"X-Auth": api_auth_token}) - get_response_data = get_response.get_json()["data"] + get_response_data = get_response.json() assert get_response_data == expected_response_data @@ -190,10 +250,12 @@ def test_unauthorized(client, method, url, body, api_auth_token): expected_message = ( "The server could not verify that you are authorized to access the URL requested" ) - response = getattr(client, method)(url, json=body, headers={"X-Auth": "incorrect token"}) + response = client.request( + method=method, url=url, json=body, headers={"X-Auth": "incorrect token"} + ) assert response.status_code == 401 - assert response.get_json()["message"] == expected_message + assert response.json()["detail"] == expected_message test_not_found_data = [ @@ -206,7 +268,7 @@ def test_unauthorized(client, method, url, body, api_auth_token): def test_not_found(client, api_auth_token, method, body): user_id = uuid.uuid4() url = f"/v1/users/{user_id}" - response = getattr(client, method)(url, json=body, headers={"X-Auth": api_auth_token}) + response = client.request(method=method, url=url, json=body, headers={"X-Auth": api_auth_token}) assert response.status_code == 404 - assert response.get_json()["message"] == f"Could not find user with ID {user_id}" + assert response.json()["detail"] == f"Could not find user with ID {user_id}" diff --git a/app/tests/src/scripts/__init__.py b/app/tests/src/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/tests/src/scripts/test_create_user_csv.py b/app/tests/src/scripts/test_create_user_csv.py deleted file mode 100644 index ccd04e1e..00000000 --- a/app/tests/src/scripts/test_create_user_csv.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import os.path as path -import re - -import flask.testing -import pytest -from pytest_lazyfixture import lazy_fixture -from smart_open import open as smart_open - -import src.adapters.db as db -from src.db.models.user_models import User -from tests.src.db.models.factories import UserFactory - - -@pytest.fixture -def prepopulate_user_table(enable_factory_create, db_session: db.Session) -> list[User]: - # First make sure the table is empty, as other tests may have inserted data - # and this test expects a clean slate (unlike most tests that are designed to - # be isolated from other tests) - db_session.query(User).delete() - return [ - UserFactory.create(first_name="Jon", last_name="Doe", is_active=True), - UserFactory.create(first_name="Jane", last_name="Doe", is_active=False), - UserFactory.create( - first_name="Alby", - last_name="Testin", - is_active=True, - ), - ] - - -@pytest.fixture -def tmp_s3_folder(mock_s3_bucket): - return f"s3://{mock_s3_bucket}/path/to/" - - -@pytest.mark.parametrize( - "dir", - [ - pytest.param(lazy_fixture("tmp_s3_folder"), id="write-to-s3"), - pytest.param(lazy_fixture("tmp_path"), id="write-to-local"), - ], -) -def test_create_user_csv( - prepopulate_user_table: list[User], - cli_runner: flask.testing.FlaskCliRunner, - dir: str, -): - cli_runner.invoke(args=["user", "create-csv", "--dir", dir, "--filename", "test.csv"]) - output = smart_open(path.join(dir, "test.csv")).read() - expected_output = open( - path.join(path.dirname(__file__), "test_create_user_csv_expected.csv") - ).read() - assert output == expected_output - - -def test_default_filename(cli_runner: flask.testing.FlaskCliRunner, tmp_path: str): - cli_runner.invoke(args=["user", "create-csv", "--dir", tmp_path]) - filenames = os.listdir(tmp_path) - assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-user-roles.csv", filenames[0]) diff --git a/app/tests/src/scripts/test_create_user_csv_expected.csv b/app/tests/src/scripts/test_create_user_csv_expected.csv deleted file mode 100644 index da0371ff..00000000 --- a/app/tests/src/scripts/test_create_user_csv_expected.csv +++ /dev/null @@ -1,4 +0,0 @@ -"User Name","Roles","Is User Active?" -"Jon Doe","ADMIN USER","True" -"Jane Doe","ADMIN USER","False" -"Alby Testin","ADMIN USER","True" diff --git a/docker-compose.yml b/docker-compose.yml index ed88388b..19e30a09 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,7 @@ services: args: - RUN_UID=${RUN_UID:-4000} - RUN_USER=${RUN_USER:-app} - command: ["poetry", "run", "flask", "--app", "src.app", "run", "--host", "0.0.0.0", "--port", "8080", "--reload"] + command: ["poetry", "run", "python", "src/app.py"] container_name: main-app env_file: ./app/local.env ports: