[nit][pre-commit][mypy] add mypy type checking to precommit (#506)
* add pre-commit to clean up code

* mypy checking

* add typing to everything
nowei authored Nov 21, 2024
1 parent ec4391a commit 33b9c63
Showing 27 changed files with 630 additions and 488 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -14,3 +14,10 @@ repos:
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'main'
hooks:
- id: mypy
args: [--ignore-missing-imports]
additional_dependencies:
- pydantic
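With this hook in place, `pre-commit` runs mypy over the staged Python files on every commit, which is what motivates the annotation changes in the rest of this diff. A minimal before/after sketch of the kind of signature tightening involved (the function names are illustrative, and the `--disallow-untyped-defs` behavior assumes a stricter configuration than the single `--ignore-missing-imports` flag shown above):

```python
from typing import Any


# Before: under flags such as --disallow-untyped-defs, mypy reports the missing
# parameter and return annotations on a handler written like this.
def handle_untyped(event, _context):
    return {"statusCode": 200, "body": "Success"}


# After: the same handler with explicit types, mirroring the changes below.
def handle_typed(event: dict[str, Any], _context: object) -> dict[str, Any]:
    return {"statusCode": 200, "body": "Success"}
```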
14 changes: 9 additions & 5 deletions python/src/functions/create_archive.py
@@ -5,6 +5,7 @@
from typing import Any

import boto3
import structlog
from aws_lambda_typing.context import Context
from mypy_boto3_s3.client import S3Client
from pydantic import BaseModel
@@ -20,7 +21,7 @@ class CreateArchiveLambdaPayload(BaseModel):


@reset_contextvars
def handle(event: dict[str, Any], _context: Context):
def handle(event: dict[str, Any], _context: Context) -> dict[str, Any]:
"""Lambda handler for creating an archive of CSV files in S3
Args:
@@ -64,8 +65,11 @@ def handle(event: dict[str, Any], _context: Context):


def create_archive(
org_id: int, reporting_period_id: int, s3_client: S3Client, logger=None
):
org_id: int,
reporting_period_id: int,
s3_client: S3Client,
logger: structlog.stdlib.BoundLogger = None,
) -> None:
"""Create a zip archive of CSV files in S3"""

if logger is None:
@@ -91,8 +95,8 @@ def create_archive(
with tempfile.NamedTemporaryFile() as file:
with zipfile.ZipFile(file, "w") as zipf:
for target_file in target_files:
obj = s3_client.get_object(Bucket=S3_BUCKET, Key=target_file)
zipf.writestr(target_file, obj["Body"].read())
target_obj = s3_client.get_object(Bucket=S3_BUCKET, Key=target_file)
zipf.writestr(target_file, target_obj["Body"].read())

zipf.close()
file.flush()
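The annotated `handle` and `create_archive` signatures above follow a typed-handler-plus-injectable-logger pattern; a minimal, self-contained sketch of that pattern (the `org_id` wiring and the explicit `Optional` spelling are assumptions for illustration, not lifted from the repository):

```python
from typing import Any, Optional

import structlog
from aws_lambda_typing.context import Context


def handle(event: dict[str, Any], _context: Context) -> dict[str, Any]:
    """Typed entry point: the explicit return type documents the payload shape for mypy."""
    create_archive(org_id=int(event["org_id"]))  # "org_id" key is hypothetical wiring
    return {"statusCode": 200, "body": "Success"}


def create_archive(
    org_id: int,
    logger: Optional[structlog.stdlib.BoundLogger] = None,
) -> None:
    """Falls back to a fresh structlog logger when one is not injected (e.g. from tests)."""
    if logger is None:
        logger = structlog.get_logger()
    logger.info("creating archive", org_id=org_id)
```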
6 changes: 3 additions & 3 deletions python/src/functions/generate_presigned_url_and_send_email.py
@@ -1,5 +1,5 @@
import os
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import boto3
import chevron
@@ -28,7 +28,7 @@ class SendTreasuryEmailLambdaPayload(BaseModel):


@reset_contextvars
def handle(event: SendTreasuryEmailLambdaPayload, context: Context):
def handle(event: SendTreasuryEmailLambdaPayload, context: Context) -> dict[str, Any]:
"""Lambda handler for emailing Treasury reports
Given a user and organization object- send an email to the user that
@@ -102,7 +102,7 @@ def generate_email(
def process_event(
payload: SendTreasuryEmailLambdaPayload,
logger: structlog.stdlib.BoundLogger,
):
) -> bool:
"""
This function is structured as followed:
1) Check to see if the s3 object exists:
49 changes: 26 additions & 23 deletions python/src/functions/generate_treasury_report.py
@@ -2,7 +2,7 @@
import os
import tempfile
from datetime import datetime
from typing import IO, Dict, List, Set, Union
from typing import IO, Any, Dict, List, Set, Type, Union

import boto3
import structlog
@@ -57,7 +57,7 @@ class ProjectLambdaPayload(BaseModel):


@reset_contextvars
def handle(event: ProjectLambdaPayload, context: Context):
def handle(event: ProjectLambdaPayload, context: Context) -> Dict[str, Any]:
"""Lambda handler for generating Treasury Reports
This function creates/outputs 3 files (all with the same name but different
@@ -94,7 +94,9 @@ def handle(event: ProjectLambdaPayload, context: Context):
return {"statusCode": 200, "body": "Success"}


def process_event(payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundLogger):
def process_event(
payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundLogger
) -> Dict[str, Any]:
"""
This function is structured as followed:
1) Load the metadata
@@ -262,6 +264,7 @@ def process_event(payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundL
),
file=binary_json_file,
)
return {"statusCode": 200, "body": "Success"}


def download_output_file(
@@ -310,7 +313,7 @@ def download_output_file(


def get_existing_output_metadata(
s3_client,
s3_client: S3Client,
organization: OrganizationObj,
project_use_code: str,
logger: structlog.stdlib.BoundLogger,
@@ -330,37 +333,37 @@
),
existing_file_binary,
)
if existing_file_binary:
existing_file_binary.close()
with open(existing_file_binary.name, mode="rt") as existing_file_json:
existing_project_agency_id_to_row_number = json.load(
existing_file_json
)
except ClientError as e:
error = e.response.get("Error") or {}
if error.get("Code") == "404":
logger.info(
"There is no existing metadata file for this treasury report"
)
existing_file_binary = None
else:
raise

if existing_file_binary:
existing_file_binary.close()
with open(existing_file_binary.name, mode="rt") as existing_file_json:
existing_project_agency_id_to_row_number = json.load(existing_file_json)

return existing_project_agency_id_to_row_number


def get_projects_to_remove(
workbook: Workbook,
ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow],
ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]],
agency_id: str,
):
) -> Set[str]:
"""
Get the set of '{project_id}_{agency_id}' ids to remove from the
existing output file.
"""
project_agency_ids_to_remove = set()
# Get projects
_, projects = validate_project_sheet(
workbook[PROJECT_SHEET], ProjectRowSchema, VERSION
workbook[PROJECT_SHEET], ProjectRowSchema, VERSION.value
)
# Store projects to remove
for project in projects:
@@ -371,16 +374,16 @@ def get_projects_to_remove(


def get_outdated_projects_to_remove(
s3_client,
s3_client: S3Client,
uploads_by_agency_id: Dict[AgencyId, UploadObj],
ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow],
ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]],
logger: structlog.stdlib.BoundLogger,
):
) -> Set[str]:
"""
Open the files in the outdated_file_info_list and get the projects to
remove.
"""
project_agency_ids_to_remove = set()
project_agency_ids_to_remove: Set[str] = set()
for agency_id, file_info in uploads_by_agency_id.items():
with tempfile.NamedTemporaryFile() as file:
# Download projects from S3
@@ -415,7 +418,7 @@ def update_project_agency_ids_to_row_map(
project_agency_ids_to_remove: Set[str],
highest_row_num: int,
sheet: Worksheet,
):
) -> int:
"""
Delete rows corresponding to project_agency_ids_to_remove in the existing
output file. Also update project_agency_id_to_row_map with the new row
@@ -445,14 +448,14 @@ def insert_project_row(
sheet: Worksheet,
row_num: int,
row: Union[Project1ARow, Project1BRow, Project1CRow],
):
) -> None:
"""
Append project to the xlsx file
sheet is Optional only for tests
"""
row_schema = row.model_json_schema()["properties"]
row_dict = row.dict()
row_dict = row.model_dump()
row_with_output_cols = {}
for prop in row_dict.keys():
prop_meta = row_schema.get(prop)
@@ -473,19 +476,19 @@ def combine_project_rows(
output_sheet: Worksheet,
project_use_code: str,
highest_row_num: int,
ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow],
ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]],
project_id_agency_id_to_upload_date: Dict[str, datetime],
project_id_agency_id_to_row_num: Dict[str, int],
created_at: datetime,
agency_id: str,
):
) -> int:
"""
Combine projects together and check for conflicts.
If there is a conflict, choose the most recent project based on created_at time.
"""
# Get projects
result = validate_project_sheet(
project_workbook[PROJECT_SHEET], ProjectRowSchema, VERSION
project_workbook[PROJECT_SHEET], ProjectRowSchema, VERSION.value
)
projects: List[Union[Project1ARow, Project1BRow, Project1CRow]] = result[1]
# Get project rows from workbook
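Several signatures above now take `ProjectRowSchema` as `Union[Type[Project1ARow], ...]`, a union of classes rather than instances, because the row class itself is what gets passed to the sheet validator. A minimal sketch of why the `Type[...]` spelling matters under mypy (the simplified models and the `validate_rows` helper are illustrative stand-ins, not the project's real `validate_project_sheet`):

```python
from typing import Any, Dict, List, Type, Union

from pydantic import BaseModel


class Project1ARow(BaseModel):
    project_id: str


class Project1BRow(BaseModel):
    project_id: str


def validate_rows(
    raw_rows: List[Dict[str, Any]],
    ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow]],
) -> List[Union[Project1ARow, Project1BRow]]:
    # The parameter is the class itself, so it is annotated with Type[...];
    # annotating it as the instance union would make this call site a mypy error.
    return [ProjectRowSchema(**row) for row in raw_rows]


projects = validate_rows([{"project_id": "P-1"}], Project1ARow)
```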
36 changes: 21 additions & 15 deletions python/src/functions/subrecipient_treasury_report_gen.py
@@ -1,6 +1,7 @@
import json
import os
import tempfile
import typing
from datetime import datetime
from typing import Any, Dict, Optional

@@ -36,7 +37,7 @@ class SubrecipientLambdaPayload(BaseModel):


@reset_contextvars
def handle(event: Dict[str, Any], context: Context):
def handle(event: Dict[str, Any], context: Context) -> Dict[str, Any]:
"""Lambda handler for generating subrecipients file for treasury report
Args:
@@ -75,7 +76,7 @@ def handle(event: Dict[str, Any], context: Context):

def process_event(
payload: SubrecipientLambdaPayload, logger: structlog.stdlib.BoundLogger
):
) -> Dict[str, Any]:
"""
This function should:
1. Parse necessary inputs from the event
@@ -97,7 +98,10 @@ def process_event(
logger.warning(
f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not have any subrecipients listed"
)
return
return {
"statusCode": 400,
"body": f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not have any subrecipients listed",
}

workbook = get_output_workbook(s3_client, output_template_id)
write_subrecipients_to_workbook(
@@ -125,11 +129,11 @@ def get_output_workbook(s3_client: S3Client, output_template_id: int) -> Workboo


def get_recent_subrecipients(
s3_client,
organization_id,
reporting_period_id,
s3_client: S3Client,
organization_id: int,
reporting_period_id: int,
logger: structlog.stdlib.BoundLogger,
) -> Optional[dict]:
) -> Optional[dict[str, Any]]:
recent_subrecipients = {}

with tempfile.NamedTemporaryFile() as recent_subrecipients_file:
@@ -149,7 +153,7 @@ def get_recent_subrecipients(
logger.info(
f"No subrecipients for organization {organization_id} and reporting period {reporting_period_id}"
)
return
return {}
else:
raise

@@ -161,12 +165,12 @@ def get_recent_subrecipients(
logger.exception(
f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not contain valid JSON"
)
return
return {}

return recent_subrecipients


def no_subrecipients_in_file(recent_subrecipients: dict):
def no_subrecipients_in_file(recent_subrecipients: dict[str, Any]) -> bool:
"""
Helper method to determine if the recent_subrecipients JSON object in
the recent subrecipients file downloaded from S3 has actual subrecipients in it or not
@@ -179,10 +183,10 @@


def write_subrecipients_to_workbook(
recent_subrecipients: dict,
recent_subrecipients: dict[str, Any],
workbook: Workbook,
logger: structlog.stdlib.BoundLogger,
):
) -> None:
"""
Given an output template, in the form of a `workbook` preloaded with openpyxl,
go through a list of `recent_subrecipients` and write information for each of them into the workbook
@@ -202,7 +206,9 @@ def write_subrecipients_to_workbook(
for k, v in getSubrecipientRowClass(
version_string=most_recent_upload["version"]
).model_fields.items():
output_column = v.json_schema_extra["output_column"]
output_column = typing.cast(Dict[str, str], v.json_schema_extra)[
"output_column"
]
if not output_column:
logger.error(f"No output column specified for field name {k}, skipping")
continue
@@ -220,7 +226,7 @@ def write_subrecipients_to_workbook(
row_to_edit += 1


def get_most_recent_upload(subrecipient):
def get_most_recent_upload(subrecipient: Dict[str, Any]) -> Any:
"""
Small helper method to sort subrecipientUploads for a given subrecipient by updated date,
and return the most recent one
@@ -237,7 +243,7 @@ def upload_workbook(
workbook: Workbook,
s3client: S3Client,
organization: OrganizationObj,
):
) -> None:
"""
Handles upload of workbook to S3, both in xlsx and csv formats
"""
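The `typing.cast` introduced in `write_subrecipients_to_workbook` is there because Pydantic types `json_schema_extra` as a union (a dict, a callable, or `None`), so subscripting it directly does not type-check. A minimal sketch of the pattern (the `name` field and its `output_column` value are illustrative):

```python
import typing
from typing import Dict

from pydantic import BaseModel, Field


class SubrecipientRow(BaseModel):
    # Pydantic stores json_schema_extra as dict | callable | None, so mypy will not
    # allow indexing it without first narrowing (or casting) the type.
    name: str = Field(json_schema_extra={"output_column": "B"})


for field_name, field_info in SubrecipientRow.model_fields.items():
    extra = typing.cast(Dict[str, str], field_info.json_schema_extra)
    print(field_name, extra["output_column"])
```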
14 changes: 8 additions & 6 deletions python/src/functions/validate_workbook.py
@@ -1,6 +1,6 @@
import json
import tempfile
from typing import IO, List, Union, Dict
from typing import IO, Dict, List, Union
from urllib.parse import unquote_plus

import boto3
@@ -10,14 +10,16 @@
from mypy_boto3_s3.client import S3Client

from src.lib.logging import get_logger, reset_contextvars
from src.lib.workbook_validator import validate
from src.lib.s3_helper import download_s3_object
from src.lib.workbook_validator import Subrecipients, validate

type ValidationResults = Dict[str, Union[List[Dict[str, str]], str, None]]
type ValidationResults = Dict[
str, Union[List[Dict[str, str]], str, None, Subrecipients]
]


@reset_contextvars
def handle(event: S3Event, context: Context):
def handle(event: S3Event, context: Context) -> None:
"""Lambda handler for validating workbooks uploaded to S3
Args:
@@ -79,8 +81,8 @@ def validate_workbook(file: IO[bytes]) -> ValidationResults:


def save_validation_results(
client: S3Client, bucket, key: str, results: ValidationResults
):
client: S3Client, bucket: str, key: str, results: ValidationResults
) -> None:
"""Persists workbook validation results to S3.
Args:
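The widened `ValidationResults` alias above uses the PEP 695 `type` statement, which needs Python 3.12 or newer; on older interpreters the same alias would be spelled with `typing.TypeAlias`. A minimal sketch of the alias pattern (the member types here are a trimmed-down, illustrative subset, and the `"errors"`/`"version"` keys are hypothetical):

```python
from typing import Dict, List, Union

# PEP 695 type-alias statement (Python 3.12+); older code would write
# `ValidationResults: TypeAlias = Dict[...]` instead.
type ValidationResults = Dict[str, Union[List[Dict[str, str]], str, None]]


def empty_results() -> ValidationResults:
    # An "all clear" result: no validation errors and no version string recorded.
    return {"errors": [], "version": None}
```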