From 33b9c632e51aee478f461651ec4ab0fc2bc241c0 Mon Sep 17 00:00:00 2001 From: Andrew Wei <32493749+nowei@users.noreply.github.com> Date: Thu, 21 Nov 2024 08:37:01 -0800 Subject: [PATCH] [nit][pre-commit][mypy] add mypy type checking to precommit (#506) * add pre-commit to clean up code * mypy checking * add typing to everything --- .pre-commit-config.yaml | 7 + python/src/functions/create_archive.py | 14 +- .../generate_presigned_url_and_send_email.py | 6 +- .../src/functions/generate_treasury_report.py | 49 ++++--- .../subrecipient_treasury_report_gen.py | 36 +++-- python/src/functions/validate_workbook.py | 14 +- python/src/lib/constants.py | 2 +- python/src/lib/email.py | 7 +- python/src/lib/logging.py | 11 +- python/src/lib/output_template_comparator.py | 41 +++--- python/src/lib/s3_helper.py | 13 +- python/src/lib/treasury_generation_common.py | 2 +- python/src/lib/workbook_utils.py | 13 +- python/src/lib/workbook_validator.py | 70 +++++---- python/src/schemas/project_types.py | 2 +- python/src/schemas/schema_V2024_04_01.py | 64 +++++--- python/src/schemas/schema_V2024_05_24.py | 78 ++++++---- python/src/schemas/schema_versions.py | 53 +++---- python/tests/conftest.py | 137 +++++++++++------- python/tests/src/lib/test_create_archive.py | 6 +- .../lib/test_output_template_comparator.py | 12 +- .../test_subrecipient_treasury_report_gen.py | 105 ++++++++------ .../tests/src/lib/test_treasury_report_1A.py | 80 +++++----- .../tests/src/lib/test_treasury_report_1B.py | 80 +++++----- .../tests/src/lib/test_treasury_report_1C.py | 80 +++++----- .../tests/src/lib/test_workbook_validator.py | 127 ++++++++-------- scripts/sample_lambda.py | 9 +- 27 files changed, 630 insertions(+), 488 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a7f555a..49155a0a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,10 @@ repos: args: [--fix] # Run the formatter. - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'main' + hooks: + - id: mypy + args: [--ignore-missing-imports] + additional_dependencies: + - pydantic diff --git a/python/src/functions/create_archive.py b/python/src/functions/create_archive.py index 38047cff..b777bc15 100644 --- a/python/src/functions/create_archive.py +++ b/python/src/functions/create_archive.py @@ -5,6 +5,7 @@ from typing import Any import boto3 +import structlog from aws_lambda_typing.context import Context from mypy_boto3_s3.client import S3Client from pydantic import BaseModel @@ -20,7 +21,7 @@ class CreateArchiveLambdaPayload(BaseModel): @reset_contextvars -def handle(event: dict[str, Any], _context: Context): +def handle(event: dict[str, Any], _context: Context) -> dict[str, Any]: """Lambda handler for creating an archive of CSV files in S3 Args: @@ -64,8 +65,11 @@ def handle(event: dict[str, Any], _context: Context): def create_archive( - org_id: int, reporting_period_id: int, s3_client: S3Client, logger=None -): + org_id: int, + reporting_period_id: int, + s3_client: S3Client, + logger: structlog.stdlib.BoundLogger = None, +) -> None: """Create a zip archive of CSV files in S3""" if logger is None: @@ -91,8 +95,8 @@ def create_archive( with tempfile.NamedTemporaryFile() as file: with zipfile.ZipFile(file, "w") as zipf: for target_file in target_files: - obj = s3_client.get_object(Bucket=S3_BUCKET, Key=target_file) - zipf.writestr(target_file, obj["Body"].read()) + target_obj = s3_client.get_object(Bucket=S3_BUCKET, Key=target_file) + zipf.writestr(target_file, target_obj["Body"].read()) zipf.close() file.flush() diff --git a/python/src/functions/generate_presigned_url_and_send_email.py b/python/src/functions/generate_presigned_url_and_send_email.py index 8155b43a..78d98837 100644 --- a/python/src/functions/generate_presigned_url_and_send_email.py +++ b/python/src/functions/generate_presigned_url_and_send_email.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import boto3 import chevron @@ -28,7 +28,7 @@ class SendTreasuryEmailLambdaPayload(BaseModel): @reset_contextvars -def handle(event: SendTreasuryEmailLambdaPayload, context: Context): +def handle(event: SendTreasuryEmailLambdaPayload, context: Context) -> dict[str, Any]: """Lambda handler for emailing Treasury reports Given a user and organization object- send an email to the user that @@ -102,7 +102,7 @@ def generate_email( def process_event( payload: SendTreasuryEmailLambdaPayload, logger: structlog.stdlib.BoundLogger, -): +) -> bool: """ This function is structured as followed: 1) Check to see if the s3 object exists: diff --git a/python/src/functions/generate_treasury_report.py b/python/src/functions/generate_treasury_report.py index f4ca6863..1cd4dc0b 100644 --- a/python/src/functions/generate_treasury_report.py +++ b/python/src/functions/generate_treasury_report.py @@ -2,7 +2,7 @@ import os import tempfile from datetime import datetime -from typing import IO, Dict, List, Set, Union +from typing import IO, Any, Dict, List, Set, Type, Union import boto3 import structlog @@ -57,7 +57,7 @@ class ProjectLambdaPayload(BaseModel): @reset_contextvars -def handle(event: ProjectLambdaPayload, context: Context): +def handle(event: ProjectLambdaPayload, context: Context) -> Dict[str, Any]: """Lambda handler for generating Treasury Reports This function creates/outputs 3 files (all with the same name but different @@ -94,7 +94,9 @@ def handle(event: ProjectLambdaPayload, context: Context): return {"statusCode": 200, "body": "Success"} -def process_event(payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundLogger): +def process_event( + payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundLogger +) -> Dict[str, Any]: """ This function is structured as followed: 1) Load the metadata @@ -262,6 +264,7 @@ def process_event(payload: ProjectLambdaPayload, logger: structlog.stdlib.BoundL ), file=binary_json_file, ) + return {"statusCode": 200, "body": "Success"} def download_output_file( @@ -310,7 +313,7 @@ def download_output_file( def get_existing_output_metadata( - s3_client, + s3_client: S3Client, organization: OrganizationObj, project_use_code: str, logger: structlog.stdlib.BoundLogger, @@ -330,29 +333,29 @@ def get_existing_output_metadata( ), existing_file_binary, ) + if existing_file_binary: + existing_file_binary.close() + with open(existing_file_binary.name, mode="rt") as existing_file_json: + existing_project_agency_id_to_row_number = json.load( + existing_file_json + ) except ClientError as e: error = e.response.get("Error") or {} if error.get("Code") == "404": logger.info( "There is no existing metadata file for this treasury report" ) - existing_file_binary = None else: raise - if existing_file_binary: - existing_file_binary.close() - with open(existing_file_binary.name, mode="rt") as existing_file_json: - existing_project_agency_id_to_row_number = json.load(existing_file_json) - return existing_project_agency_id_to_row_number def get_projects_to_remove( workbook: Workbook, - ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow], + ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]], agency_id: str, -): +) -> Set[str]: """ Get the set of '{project_id}_{agency_id}' ids to remove from the existing output file. @@ -360,7 +363,7 @@ def get_projects_to_remove( project_agency_ids_to_remove = set() # Get projects _, projects = validate_project_sheet( - workbook[PROJECT_SHEET], ProjectRowSchema, VERSION + workbook[PROJECT_SHEET], ProjectRowSchema, VERSION.value ) # Store projects to remove for project in projects: @@ -371,16 +374,16 @@ def get_projects_to_remove( def get_outdated_projects_to_remove( - s3_client, + s3_client: S3Client, uploads_by_agency_id: Dict[AgencyId, UploadObj], - ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow], + ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]], logger: structlog.stdlib.BoundLogger, -): +) -> Set[str]: """ Open the files in the outdated_file_info_list and get the projects to remove. """ - project_agency_ids_to_remove = set() + project_agency_ids_to_remove: Set[str] = set() for agency_id, file_info in uploads_by_agency_id.items(): with tempfile.NamedTemporaryFile() as file: # Download projects from S3 @@ -415,7 +418,7 @@ def update_project_agency_ids_to_row_map( project_agency_ids_to_remove: Set[str], highest_row_num: int, sheet: Worksheet, -): +) -> int: """ Delete rows corresponding to project_agency_ids_to_remove in the existing output file. Also update project_agency_id_to_row_map with the new row @@ -445,14 +448,14 @@ def insert_project_row( sheet: Worksheet, row_num: int, row: Union[Project1ARow, Project1BRow, Project1CRow], -): +) -> None: """ Append project to the xlsx file sheet is Optional only for tests """ row_schema = row.model_json_schema()["properties"] - row_dict = row.dict() + row_dict = row.model_dump() row_with_output_cols = {} for prop in row_dict.keys(): prop_meta = row_schema.get(prop) @@ -473,19 +476,19 @@ def combine_project_rows( output_sheet: Worksheet, project_use_code: str, highest_row_num: int, - ProjectRowSchema: Union[Project1ARow, Project1BRow, Project1CRow], + ProjectRowSchema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]], project_id_agency_id_to_upload_date: Dict[str, datetime], project_id_agency_id_to_row_num: Dict[str, int], created_at: datetime, agency_id: str, -): +) -> int: """ Combine projects together and check for conflicts. If there is a conflict, choose the most recent project based on created_at time. """ # Get projects result = validate_project_sheet( - project_workbook[PROJECT_SHEET], ProjectRowSchema, VERSION + project_workbook[PROJECT_SHEET], ProjectRowSchema, VERSION.value ) projects: List[Union[Project1ARow, Project1BRow, Project1CRow]] = result[1] # Get project rows from workbook diff --git a/python/src/functions/subrecipient_treasury_report_gen.py b/python/src/functions/subrecipient_treasury_report_gen.py index 3a76bfb4..4e685ae0 100644 --- a/python/src/functions/subrecipient_treasury_report_gen.py +++ b/python/src/functions/subrecipient_treasury_report_gen.py @@ -1,6 +1,7 @@ import json import os import tempfile +import typing from datetime import datetime from typing import Any, Dict, Optional @@ -36,7 +37,7 @@ class SubrecipientLambdaPayload(BaseModel): @reset_contextvars -def handle(event: Dict[str, Any], context: Context): +def handle(event: Dict[str, Any], context: Context) -> Dict[str, Any]: """Lambda handler for generating subrecipients file for treasury report Args: @@ -75,7 +76,7 @@ def handle(event: Dict[str, Any], context: Context): def process_event( payload: SubrecipientLambdaPayload, logger: structlog.stdlib.BoundLogger -): +) -> Dict[str, Any]: """ This function should: 1. Parse necessary inputs from the event @@ -97,7 +98,10 @@ def process_event( logger.warning( f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not have any subrecipients listed" ) - return + return { + "statusCode": 400, + "body": f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not have any subrecipients listed", + } workbook = get_output_workbook(s3_client, output_template_id) write_subrecipients_to_workbook( @@ -125,11 +129,11 @@ def get_output_workbook(s3_client: S3Client, output_template_id: int) -> Workboo def get_recent_subrecipients( - s3_client, - organization_id, - reporting_period_id, + s3_client: S3Client, + organization_id: int, + reporting_period_id: int, logger: structlog.stdlib.BoundLogger, -) -> Optional[dict]: +) -> Optional[dict[str, Any]]: recent_subrecipients = {} with tempfile.NamedTemporaryFile() as recent_subrecipients_file: @@ -149,7 +153,7 @@ def get_recent_subrecipients( logger.info( f"No subrecipients for organization {organization_id} and reporting period {reporting_period_id}" ) - return + return {} else: raise @@ -161,12 +165,12 @@ def get_recent_subrecipients( logger.exception( f"Subrecipients file for organization {organization_id} and reporting period {reporting_period_id} does not contain valid JSON" ) - return + return {} return recent_subrecipients -def no_subrecipients_in_file(recent_subrecipients: dict): +def no_subrecipients_in_file(recent_subrecipients: dict[str, Any]) -> bool: """ Helper method to determine if the recent_subrecipients JSON object in the recent subrecipients file downloaded from S3 has actual subrecipients in it or not @@ -179,10 +183,10 @@ def no_subrecipients_in_file(recent_subrecipients: dict): def write_subrecipients_to_workbook( - recent_subrecipients: dict, + recent_subrecipients: dict[str, Any], workbook: Workbook, logger: structlog.stdlib.BoundLogger, -): +) -> None: """ Given an output template, in the form of a `workbook` preloaded with openpyxl, go through a list of `recent_subrecipients` and write information for each of them into the workbook @@ -202,7 +206,9 @@ def write_subrecipients_to_workbook( for k, v in getSubrecipientRowClass( version_string=most_recent_upload["version"] ).model_fields.items(): - output_column = v.json_schema_extra["output_column"] + output_column = typing.cast(Dict[str, str], v.json_schema_extra)[ + "output_column" + ] if not output_column: logger.error(f"No output column specified for field name {k}, skipping") continue @@ -220,7 +226,7 @@ def write_subrecipients_to_workbook( row_to_edit += 1 -def get_most_recent_upload(subrecipient): +def get_most_recent_upload(subrecipient: Dict[str, Any]) -> Any: """ Small helper method to sort subrecipientUploads for a given subrecipient by updated date, and return the most recent one @@ -237,7 +243,7 @@ def upload_workbook( workbook: Workbook, s3client: S3Client, organization: OrganizationObj, -): +) -> None: """ Handles upload of workbook to S3, both in xlsx and csv formats """ diff --git a/python/src/functions/validate_workbook.py b/python/src/functions/validate_workbook.py index 90100c70..107e184d 100644 --- a/python/src/functions/validate_workbook.py +++ b/python/src/functions/validate_workbook.py @@ -1,6 +1,6 @@ import json import tempfile -from typing import IO, List, Union, Dict +from typing import IO, Dict, List, Union from urllib.parse import unquote_plus import boto3 @@ -10,14 +10,16 @@ from mypy_boto3_s3.client import S3Client from src.lib.logging import get_logger, reset_contextvars -from src.lib.workbook_validator import validate from src.lib.s3_helper import download_s3_object +from src.lib.workbook_validator import Subrecipients, validate -type ValidationResults = Dict[str, Union[List[Dict[str, str]], str, None]] +type ValidationResults = Dict[ + str, Union[List[Dict[str, str]], str, None, Subrecipients] +] @reset_contextvars -def handle(event: S3Event, context: Context): +def handle(event: S3Event, context: Context) -> None: """Lambda handler for validating workbooks uploaded to S3 Args: @@ -79,8 +81,8 @@ def validate_workbook(file: IO[bytes]) -> ValidationResults: def save_validation_results( - client: S3Client, bucket, key: str, results: ValidationResults -): + client: S3Client, bucket: str, key: str, results: ValidationResults +) -> None: """Persists workbook validation results to S3. Args: diff --git a/python/src/lib/constants.py b/python/src/lib/constants.py index f6c14f00..b43a47f7 100644 --- a/python/src/lib/constants.py +++ b/python/src/lib/constants.py @@ -9,7 +9,7 @@ class OutputTemplateFilename(Enum): CPF1CMultiPurposeCommunityTemplate = "CPF1CMultiPurposeCommunityTemplate" CPFSubrecipientTemplate = "CPFSubrecipientTemplate" - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/python/src/lib/email.py b/python/src/lib/email.py index d5aa484b..7f71cb99 100644 --- a/python/src/lib/email.py +++ b/python/src/lib/email.py @@ -1,13 +1,18 @@ import os import boto3 +import structlog from botocore.exceptions import ClientError CHARSET = "UTF-8" def send_email( - dest_email: str, email_html: str, email_text: str, subject: str, logger + dest_email: str, + email_html: str, + email_text: str, + subject: str, + logger: structlog.stdlib.BoundLogger, ) -> bool: # Email user email_client = boto3.client("ses") diff --git a/python/src/lib/logging.py b/python/src/lib/logging.py index 8b707274..62efa534 100644 --- a/python/src/lib/logging.py +++ b/python/src/lib/logging.py @@ -1,10 +1,11 @@ import functools import logging import sys +from typing import Any, Callable, Dict, List import structlog -shared_processors = [ +shared_processors: List[Callable] = [ structlog.contextvars.merge_contextvars, structlog.processors.add_log_level, structlog.processors.TimeStamper(fmt="iso", key="ts"), @@ -17,7 +18,7 @@ ), ] -processors = shared_processors + [] +processors: List[Callable] = shared_processors + [] if sys.stderr.isatty(): processors += [ structlog.dev.ConsoleRenderer(), @@ -36,11 +37,11 @@ ) -def get_logger(*args, **kwargs) -> structlog.stdlib.BoundLogger: +def get_logger(*args: str, **kwargs: str) -> structlog.stdlib.BoundLogger: return structlog.get_logger(*args, **kwargs) -def reset_contextvars(func): +def reset_contextvars(func: Callable) -> Callable: """Decorator that resets context-local log values prior to each call to the decorated function. @@ -51,7 +52,7 @@ def reset_contextvars(func): """ @functools.wraps(func) - def inner(*args, **kwargs): + def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: structlog.contextvars.unbind_contextvars() return func(*args, **kwargs) diff --git a/python/src/lib/output_template_comparator.py b/python/src/lib/output_template_comparator.py index 5fe76d90..bac7917f 100644 --- a/python/src/lib/output_template_comparator.py +++ b/python/src/lib/output_template_comparator.py @@ -5,9 +5,10 @@ import zipfile from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Set +from typing import IO, Any, Dict, List, Set from openpyxl import load_workbook +from openpyxl.worksheet.worksheet import Worksheet HEADER_ROW_INDEX = 4 @@ -16,22 +17,22 @@ class CPFDiffReport: new_sheets: Dict[str, List[str]] = field(default_factory=dict) removed_sheets: Dict[str, List[str]] = field(default_factory=dict) - row_count_changed: Dict[str, List[str]] = field(default_factory=dict) - column_count_changed: Dict[str, List[str]] = field(default_factory=dict) + row_count_changed: Dict[str, str] = field(default_factory=dict) + column_count_changed: Dict[str, str] = field(default_factory=dict) column_differences: Dict[str, List[str]] = field(default_factory=dict) cell_value_changed: Dict[str, List[str]] = field(default_factory=dict) new_files: List[str] = field(default_factory=list) removed_files: List[str] = field(default_factory=list) - def __post_init__(self): + def __post_init__(self) -> None: self.new_sheets = defaultdict(list) self.removed_sheets = defaultdict(list) - self.row_count_changed = defaultdict(list) - self.column_count_changed = defaultdict(list) + self.row_count_changed = defaultdict(lambda: "") + self.column_count_changed = defaultdict(lambda: "") self.column_differences = defaultdict(list) self.cell_value_changed = defaultdict(list) - def summary_report(self): + def summary_report(self) -> str: return f""" New files: {self._format_item(self.new_files)} Removed files: {self._format_item(self.removed_files)} @@ -44,20 +45,20 @@ def summary_report(self): """ @staticmethod - def _format_item(item: any): + def _format_item(item: Any) -> str: if isinstance(item, list): return "\n".join(item) elif isinstance(item, dict): return "\n".join([f"{k}: {v}" for k, v in item.items()]) else: - return item + return str(item) class CPFFileArchive: """Class for working with a CPF file archive""" _zip_file: zipfile.ZipFile - _file_name_map: dict + _file_name_map: dict[str, str] def __init__(self, zip_file: zipfile.ZipFile): self._zip_file = zip_file @@ -67,11 +68,11 @@ def normalized_file_names(self) -> Set[str]: """Returns a dictionary of normalized file names""" return set(self._file_name_map.keys()) - def file_by_name(self, name) -> zipfile.ZipExtFile: + def file_by_name(self, name: str) -> IO[bytes]: """Returns a file object by name""" return self._zip_file.open(self._file_name_map[name]) - def _normalize_names(self): + def _normalize_names(self) -> None: """Normalizes file names""" suffix_regex = r"\s\(\d+\)" normalized_files = {} @@ -83,7 +84,7 @@ def _normalize_names(self): def compare_workbooks( latest_archive: CPFFileArchive, previous_archive: CPFFileArchive -) -> tuple: +) -> tuple[set[str], set[str], set[str]]: latest_files = latest_archive.normalized_file_names() previous_files = previous_archive.normalized_file_names() new_files = latest_files - previous_files @@ -92,7 +93,9 @@ def compare_workbooks( return common_files, new_files, removed_files -def compare_sheet_columns(previous_sheet, latest_sheet): +def compare_sheet_columns( + previous_sheet: Worksheet, latest_sheet: Worksheet +) -> tuple[set[Any], set[Any], dict[Any, tuple[int, int]]]: previous_column_headers = [ previous_sheet.cell(row=HEADER_ROW_INDEX, column=column).value for column in range(1, previous_sheet.max_column + 1) @@ -116,7 +119,11 @@ def compare_sheet_columns(previous_sheet, latest_sheet): return added_columns, removed_columns, header_map -def compare_cell_values(previous_sheet, latest_sheet, header_map): +def compare_cell_values( + previous_sheet: Worksheet, + latest_sheet: Worksheet, + header_map: dict[str, tuple[int, int]], +) -> list[str]: differences = [] for row in range(HEADER_ROW_INDEX, latest_sheet.max_row + 1): for header, (previous_column, latest_column) in header_map.items(): @@ -131,7 +138,7 @@ def compare_cell_values(previous_sheet, latest_sheet, header_map): def compare( latest_zip_files: zipfile.ZipFile, previous_zip_files: zipfile.ZipFile -) -> List[str]: +) -> CPFDiffReport: # read the files from the zip files # compare the files # return the differences @@ -203,7 +210,7 @@ def compare( return differences -def load_files(zip_path) -> zipfile.ZipFile: +def load_files(zip_path: str) -> zipfile.ZipFile: """Loads XLSX files from a zip file. Args: diff --git a/python/src/lib/s3_helper.py b/python/src/lib/s3_helper.py index fc81e487..bca00002 100644 --- a/python/src/lib/s3_helper.py +++ b/python/src/lib/s3_helper.py @@ -8,7 +8,9 @@ from src.lib.logging import get_logger -def download_s3_object(client: S3Client, bucket: str, key: str, destination: IO[bytes]): +def download_s3_object( + client: S3Client, bucket: str, key: str, destination: IO[bytes] +) -> None: """Downloads an S3 object to a local file. Args: @@ -31,8 +33,8 @@ def upload_generated_file_to_s3( client: S3Client, bucket: str, key: str, - file: Union[IO[bytes], tempfile._TemporaryFileWrapper], -): + file: Union[IO[bytes], "tempfile._TemporaryFileWrapper[str]"], +) -> None: """Persists file to S3. Args: @@ -67,7 +69,8 @@ def get_presigned_url( ) -> Optional[str]: logger = get_logger() try: - response = s3_client.head_object( + # Check file exists + _resp = s3_client.head_object( Bucket=bucket, Key=key, ) @@ -75,6 +78,7 @@ def get_presigned_url( logger.exception(f"Unable to retrieve head object for key: {key}") return None + response = None try: response = s3_client.generate_presigned_url( "get_object", @@ -86,6 +90,5 @@ def get_presigned_url( ) except ClientError: logger.exception(f"Unable to retrieve presigned URL for key: {key}") - return None return response diff --git a/python/src/lib/treasury_generation_common.py b/python/src/lib/treasury_generation_common.py index ea810753..6523fb70 100644 --- a/python/src/lib/treasury_generation_common.py +++ b/python/src/lib/treasury_generation_common.py @@ -42,7 +42,7 @@ def get_output_template( output_template_id: int, project: str, destination: IO[bytes], -): +) -> None: """Downloads an empty output template from S3. Args: diff --git a/python/src/lib/workbook_utils.py b/python/src/lib/workbook_utils.py index 77143c8f..8392d4c9 100644 --- a/python/src/lib/workbook_utils.py +++ b/python/src/lib/workbook_utils.py @@ -10,7 +10,7 @@ """ -def escape_for_csv(text: Optional[str]): +def escape_for_csv(text: Optional[str]) -> Optional[str]: if not text: return text text = text.replace("\n", " -- ") @@ -22,7 +22,7 @@ def convert_xlsx_to_csv( csv_file: "_TemporaryFileWrapper[str]", file: Workbook, num_rows: int, -): +) -> "_TemporaryFileWrapper[str]": """ Convert xlsx file to csv. """ @@ -33,7 +33,12 @@ def convert_xlsx_to_csv( for eachrow in sheet.rows: if row_num > num_rows: break - csv_file_handle.writerow([escape_for_csv(cell.value) for cell in eachrow]) + csv_file_handle.writerow( + [ + escape_for_csv(str(cell.value) if cell.value else None) + for cell in eachrow + ] + ) row_num = row_num + 1 return csv_file @@ -46,7 +51,7 @@ def convert_xlsx_to_csv( """ -def find_last_populated_row(worksheet: Worksheet, start_row: int, column: str): +def find_last_populated_row(worksheet: Worksheet, start_row: int, column: str) -> int: last_populated_row = start_row for row in range(start_row, worksheet.max_row + 1): diff --git a/python/src/lib/workbook_validator.py b/python/src/lib/workbook_validator.py index 47d540c3..0b9f2ba9 100644 --- a/python/src/lib/workbook_validator.py +++ b/python/src/lib/workbook_validator.py @@ -1,3 +1,4 @@ +import typing from enum import Enum from typing import IO, Any, Dict, Iterable, List, Optional, Tuple, Type, Union @@ -15,6 +16,7 @@ Project1BRow, Project1CRow, SubrecipientRow, + V2024_05_24_ProjectRows, Version, getCoverSheetRowClass, getSchemaByProject, @@ -23,9 +25,6 @@ getVersionFromString, ) -type Errors = List[WorkbookError] -type Subrecipients = List[SubrecipientRow] - LOGIC_SHEET = "Logic" COVER_SHEET = "Cover" PROJECT_SHEET = "Project" @@ -69,15 +68,21 @@ def __init__( self.severity = severity -def map_values_to_headers(headers: Tuple, values: Iterable[Any]): +type Errors = List[WorkbookError] +type Subrecipients = List[SubrecipientRow] + + +def map_values_to_headers( + headers: Tuple[str, ...], values: Iterable[Any] +) -> Dict[str, Any]: return dict(zip(headers, values)) -def is_empty_row(row_values: Tuple): +def is_empty_row(row_values: Tuple[Optional[Any], ...]) -> bool: return all(value in (None, "") for value in row_values) -def get_headers(sheet: Worksheet, cell_range: str) -> tuple: +def get_headers(sheet: Worksheet, cell_range: str) -> Tuple[Any, ...]: return tuple(header_cell.value for header_cell in sheet[cell_range][0]) @@ -92,15 +97,18 @@ def get_project_use_code( values, it raises an error. """ metadata = getSchemaMetadata(version_string) - if not row_dict: + actual_dict: Dict[str, str] + if row_dict: + actual_dict = row_dict + else: cover_header = get_headers(cover_sheet, metadata["Cover"]["header_range"]) cover_row = map(lambda cell: cell.value, cover_sheet[2]) - row_dict = map_values_to_headers(cover_header, cover_row) + actual_dict = map_values_to_headers(cover_header, cover_row) version = getVersionFromString(version_string) codeKey = "Expenditure Category Group" if version != Version.V2024_05_24: codeKey = "Project Use Code" - code = row_dict[codeKey] + code = actual_dict[codeKey] return ProjectType.from_project_name(code) @@ -224,6 +232,8 @@ def validate(workbook: IO[bytes]) -> Tuple[Errors, Optional[str], Subrecipients, ) ], "Unkown", + [], + "", ) finally: @@ -234,9 +244,9 @@ def validate_workbook( workbook: Workbook, return_data: bool = False ) -> Tuple[ Errors, - Optional[str], - List[Union[Project1ARow, Project1BRow, Project1CRow]], - List[SubrecipientRow], + Optional[ProjectType], + Subrecipients, + str, ]: """Validates a given Excel workbook according to CPF validation rules. @@ -290,8 +300,6 @@ def validate_workbook( projects, subrecipients, version_string ) - subrecipients = [subrecipient.model_dump() for subrecipient in subrecipients] - return (errors, project_use_code, subrecipients, version_string) @@ -372,13 +380,15 @@ def validate_cover_sheet( ) return (errors, None, None) - project_schema = getSchemaByProject(version_string, project_use_code) + project_schema = getSchemaByProject( + getVersionFromString(version_string), project_use_code + ) return (errors, project_schema, project_use_code) def validate_project_sheet( project_sheet: Worksheet, - project_schema: Type[Union[Project1ARow, Project1BRow, Project1CRow]], + project_schema: Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]], version_string: str, ) -> Tuple[Errors, List[Union[Project1ARow, Project1BRow, Project1CRow]]]: errors = [] @@ -409,8 +419,8 @@ def validate_project_sheet( errors += [ WorkbookError( message="Upload doesn’t include any project records.", - row=INITIAL_STARTING_ROW + 1, - col=0, + row=str(INITIAL_STARTING_ROW + 1), + col=str(0), tab=PROJECT_SHEET, field_name="", severity=ErrorLevel.ERR.name, @@ -457,6 +467,10 @@ def validate_projects_subrecipients( if getVersionFromString(version_string) != Version.active_version(): return [] + projects_v2024: List[V2024_05_24_ProjectRows] = typing.cast( + List[V2024_05_24_ProjectRows], projects + ) + errors = [] subrecipients_by_uei_tin = {} for subrecipient in subrecipients: @@ -464,23 +478,25 @@ def validate_projects_subrecipients( (subrecipient.EIN__c, subrecipient.Unique_Entity_Identifier__c) ] = subrecipient - for project in projects: + for project in projects_v2024: if ( subrecipients_by_uei_tin.get( (project.Subrecipient_TIN__c, project.Subrecipient_UEI__c) ) is None ): - col_name_tin = project.__class__.model_fields[ - "Subrecipient_TIN__c" - ].json_schema_extra["column"] - col_name_uei = project.__class__.model_fields[ - "Subrecipient_UEI__c" - ].json_schema_extra["column"] + col_name_tin = typing.cast( + Dict[str, str], + project.__class__.model_fields["Subrecipient_TIN__c"].json_schema_extra, + )["column"] + col_name_uei = typing.cast( + Dict[str, str], + project.__class__.model_fields["Subrecipient_UEI__c"].json_schema_extra, + )["column"] errors.append( WorkbookError( message="You must submit a subrecipient record with the same UEI & TIN numbers entered for this project", - row=project.row_num, + row=str(project.row_num), col=f"{col_name_uei}, {col_name_tin}", tab="Project", field_name="Subrecipient_TIN__c and Subrecipient_UEI__c", @@ -499,7 +515,7 @@ def validate_projects_subrecipients( file_path = sys.argv[1] with open(file_path, "rb") as f: - errors, project_use_code = validate(f) + errors, project_use_code, subrecipients, version_str = validate(f) if errors: print("Errors found:") for error in errors: diff --git a/python/src/schemas/project_types.py b/python/src/schemas/project_types.py index 7a5f21d7..2600cea9 100644 --- a/python/src/schemas/project_types.py +++ b/python/src/schemas/project_types.py @@ -15,7 +15,7 @@ def from_project_name(cls, project_name: str) -> "ProjectType": f"Project name '{project_name}' is not a recognized project type." ) - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/python/src/schemas/schema_V2024_04_01.py b/python/src/schemas/schema_V2024_04_01.py index 428d9c2e..dd3feb41 100644 --- a/python/src/schemas/schema_V2024_04_01.py +++ b/python/src/schemas/schema_V2024_04_01.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import ( BaseModel, @@ -11,27 +11,27 @@ field_validator, ) -from src.schemas.project_types import NAME_BY_PROJECT from src.schemas.custom_types import ( CustomDecimal_7Digits, CustomDecimal_12Digits, CustomDecimal_13Digits, - CustomInt_GE1, CustomInt_GE0_LELARGE, CustomInt_GE0_LELARGE2, CustomInt_GE0_LELARGE3, CustomInt_GE0_LELARGE4, + CustomInt_GE1, CustomStr_MIN1, - CustomStr_MIN1_MAX100, + CustomStr_MIN1_MAX5, CustomStr_MIN1_MAX10, CustomStr_MIN1_MAX20, - CustomStr_MIN12_MAX12, - CustomStr_MIN9_MAX9, - CustomStr_MIN1_MAX3000, - CustomStr_MIN1_MAX5, CustomStr_MIN1_MAX40, CustomStr_MIN1_MAX80, + CustomStr_MIN1_MAX100, + CustomStr_MIN1_MAX3000, + CustomStr_MIN9_MAX9, + CustomStr_MIN12_MAX12, ) +from src.schemas.project_types import NAME_BY_PROJECT, ProjectType class StateAbbreviation(str, Enum): @@ -346,7 +346,7 @@ class BaseProjectRow(BaseModel): "Actual_operations_date__c", ) @classmethod - def parse_mm_dd_yyyy_dates(cls, v): + def parse_mm_dd_yyyy_dates(cls, v: Union[str, datetime]) -> datetime: if isinstance(v, str): try: return datetime.strptime(v, "%m/%d/%Y") @@ -383,8 +383,10 @@ def serialize_mm_dd_yyyy_dates(self, value: datetime) -> str: "Project_Status__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -535,8 +537,10 @@ class Project1ARow(BaseProjectRow): "Affordable_Connectivity_Program_ACP__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -605,8 +609,10 @@ class AddressFields(BaseModel): "Zip_Code_Planned__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -757,8 +763,10 @@ class Project1BRow(BaseProjectRow, AddressFields): "Measurement_of_Effectiveness__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -860,8 +868,10 @@ class Project1CRow(BaseProjectRow, AddressFields): @field_validator("Access_to_Public_Transit__c") @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -939,8 +949,10 @@ class SubrecipientRow(BaseModel): "State_Abbreviated__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -982,15 +994,19 @@ class CoverSheetRow(BaseModel): @field_validator("project_use_code") @classmethod - def validate_code(cls, v: Any, info: ValidationInfo, **kwargs): + def validate_code( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: if v is None or v.strip() == "": raise ValueError("EC code must be set") return v @field_validator("project_use_name") @classmethod - def validate_code_name_pair(cls, v: Any, info: ValidationInfo, **kwargs): - project_use_code = info.data.get("project_use_code") + def validate_code_name_pair( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + project_use_code = ProjectType(info.data.get("project_use_code")) expected_name = NAME_BY_PROJECT.get(project_use_code) if not expected_name: diff --git a/python/src/schemas/schema_V2024_05_24.py b/python/src/schemas/schema_V2024_05_24.py index bcfdcf4e..7d1c532c 100644 --- a/python/src/schemas/schema_V2024_05_24.py +++ b/python/src/schemas/schema_V2024_05_24.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import ( BaseModel, @@ -11,28 +11,28 @@ field_validator, ) -from src.schemas.project_types import NAME_BY_PROJECT from src.schemas.custom_types import ( CustomDecimal_7Digits, CustomDecimal_12Digits, CustomDecimal_13Digits, CustomDecimal_15Digits, - CustomInt_GE1, CustomInt_GE0_LELARGE, CustomInt_GE0_LELARGE2, CustomInt_GE0_LELARGE3, CustomInt_GE0_LELARGE4, + CustomInt_GE1, CustomStr_MIN1, - CustomStr_MIN1_MAX100, + CustomStr_MIN1_MAX5, CustomStr_MIN1_MAX10, CustomStr_MIN1_MAX20, - CustomStr_MIN12_MAX12, - CustomStr_MIN9_MAX9, - CustomStr_MIN1_MAX3000, - CustomStr_MIN1_MAX5, CustomStr_MIN1_MAX40, CustomStr_MIN1_MAX80, + CustomStr_MIN1_MAX100, + CustomStr_MIN1_MAX3000, + CustomStr_MIN9_MAX9, + CustomStr_MIN12_MAX12, ) +from src.schemas.project_types import NAME_BY_PROJECT, ProjectType class StateAbbreviation(str, Enum): @@ -87,7 +87,7 @@ class StateAbbreviation(str, Enum): WI = "WI" WY = "WY" - def __str__(self): + def __str__(self) -> str: return self.value @@ -100,7 +100,7 @@ class CapitalAssetOwnershipType(str, Enum): COOPERATIVE = "6. Co-operative" OTHER = "7. Other" - def __str__(self): + def __str__(self) -> str: return self.value @@ -110,7 +110,7 @@ class ProjectStatusType(str, Enum): MORE_THAN_FIFTY_PERCENT_COMPLETE = "3. More than 50 percent complete" COMPLETED = "4. Completed" - def __str__(self): + def __str__(self) -> str: return self.value @@ -118,7 +118,7 @@ class YesNoType(str, Enum): YES = "Yes" NO = "No" - def __str__(self): + def __str__(self) -> str: return self.value @@ -128,7 +128,7 @@ class TechType(str, Enum): FIXED_WIRELESS = "3. Fixed Wireless" OTHER = "4. Other" - def __str__(self): + def __str__(self) -> str: return self.value @@ -136,7 +136,7 @@ class ProjectInvestmentType(str, Enum): NEW_CONSTRUCTION = "1. New Construction" RENOVATION = "2. Renovation" - def __str__(self): + def __str__(self) -> str: return self.value @@ -602,7 +602,7 @@ class BaseProjectRow(BaseModel): "Actual_operations_date__c", ) @classmethod - def parse_mm_dd_yyyy_dates(cls, v): + def parse_mm_dd_yyyy_dates(cls, v: Union[datetime, str]) -> datetime: if isinstance(v, str): try: return datetime.strptime(v, "%m/%d/%Y") @@ -639,8 +639,10 @@ def serialize_mm_dd_yyyy_dates(self, value: datetime) -> str: "Project_Status__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -796,8 +798,10 @@ class Project1ARow(BaseProjectRow): "Affordable_Connectivity_Program_ACP__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -914,8 +918,10 @@ class AddressFields(BaseModel): "Zip_Code_Planned__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -1102,8 +1108,10 @@ class Project1BRow(BaseProjectRow, AddressFields): "Measurement_of_Effectiveness__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -1237,8 +1245,10 @@ class Project1CRow(BaseProjectRow, AddressFields): @field_validator("Access_to_Public_Transit__c") @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -1333,8 +1343,10 @@ class SubrecipientRow(BaseModel): "State_Abbreviated__c", ) @classmethod - def validate_field(cls, v: Any, info: ValidationInfo, **kwargs): - if isinstance(v, str) and v.strip == "": + def validate_field( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + if isinstance(v, str) and v.strip() == "": raise ValueError(f"Value is required for {info.field_name}") return v @@ -1376,15 +1388,21 @@ class CoverSheetRow(BaseModel): @field_validator("expenditure_category_group") @classmethod - def validate_code(cls, v: Any, info: ValidationInfo, **kwargs): + def validate_code( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: if v is None or v.strip() == "": raise ValueError("EC code must be set") return v @field_validator("detailed_expenditure_category") @classmethod - def validate_code_name_pair(cls, v: Any, info: ValidationInfo, **kwargs): - expenditure_category_group = info.data.get("expenditure_category_group") + def validate_code_name_pair( + cls, v: Any, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Any: + expenditure_category_group = ProjectType( + info.data.get("expenditure_category_group") + ) expected_name = NAME_BY_PROJECT.get(expenditure_category_group) if not expected_name: diff --git a/python/src/schemas/schema_versions.py b/python/src/schemas/schema_versions.py index 12515a20..b2296df4 100644 --- a/python/src/schemas/schema_versions.py +++ b/python/src/schemas/schema_versions.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Type, Union +from typing import Any, List, Type, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -30,24 +30,18 @@ """ -class SubrecipientRow: - Type[Union[V2024_04_01_SubrecipientRow, V2024_05_24_SubrecipientRow]] +type SubrecipientRow = Union[V2024_04_01_SubrecipientRow, V2024_05_24_SubrecipientRow] -class CoverSheetRow: - Type[Union[V2024_04_01_CoverSheetRow, V2024_05_24_CoverSheetRow]] +type CoverSheetRow = Union[V2024_04_01_CoverSheetRow, V2024_05_24_CoverSheetRow] -class Project1ARow: - Type[Union[V2024_04_01_Project1ARow, V2024_05_24_Project1ARow]] - - -class Project1BRow: - Type[Union[V2024_04_01_Project1BRow, V2024_05_24_Project1BRow]] - - -class Project1CRow: - Type[Union[V2024_04_01_Project1CRow, V2024_05_24_Project1CRow]] +Project1ARow = Union[V2024_04_01_Project1ARow, V2024_05_24_Project1ARow] +Project1BRow = Union[V2024_04_01_Project1BRow, V2024_05_24_Project1BRow] +Project1CRow = Union[V2024_04_01_Project1CRow, V2024_05_24_Project1CRow] +V2024_05_24_ProjectRows = Union[ + V2024_05_24_Project1ARow, V2024_05_24_Project1BRow, V2024_05_24_Project1CRow +] class Version(Enum): @@ -61,11 +55,11 @@ def active_version(cls) -> "Version": return cls.V2024_05_24 @classmethod - def compatible_older_versions(cls): + def compatible_older_versions(cls) -> List["Version"]: return [cls.V2023_12_12, cls.V2024_01_07, cls.V2024_04_01] @classmethod - def compatible_newer_versions(cls): + def compatible_newer_versions(cls) -> List["Version"]: return [] @@ -74,7 +68,9 @@ class LogicSheetVersion(BaseModel): @field_validator("version") @classmethod - def validate_field(cls, version: Version, info: ValidationInfo, **kwargs): + def validate_field( + cls, version: Version, info: ValidationInfo, **kwargs: dict[str, Any] + ) -> Version: if version == Version.active_version(): return version elif version in Version.compatible_older_versions(): @@ -97,15 +93,15 @@ def getVersionFromString(version_string: str) -> Version: # validate_logic_sheet version = None try: - version = Version._value2member_map_[version_string] - except KeyError: + version = Version(version_string) + except Exception: # Handle the edge case of a bad version with the latest schema # We should have already collected a user-facing error for this in validate_logic_sheet version = Version.active_version() return version -def getSubrecipientRowClass(version_string: str) -> SubrecipientRow: +def getSubrecipientRowClass(version_string: str) -> Type[SubrecipientRow]: version = getVersionFromString(version_string) match version: case Version.V2024_05_24: @@ -114,7 +110,7 @@ def getSubrecipientRowClass(version_string: str) -> SubrecipientRow: return V2024_04_01_SubrecipientRow -def getCoverSheetRowClass(version_string: str) -> CoverSheetRow: +def getCoverSheetRowClass(version_string: str) -> Type[CoverSheetRow]: version = getVersionFromString(version_string) match version: case Version.V2024_05_24: @@ -124,9 +120,8 @@ def getCoverSheetRowClass(version_string: str) -> CoverSheetRow: def getSchemaByProject( - version_string: Version, project_type: ProjectType -) -> Union[Project1ARow, Project1BRow, Project1CRow]: - version = getVersionFromString(version_string) + version: Version, project_type: ProjectType +) -> Union[Type[Project1ARow], Type[Project1BRow], Type[Project1CRow]]: match project_type: case ProjectType._1A: return getProject1ARow(version) @@ -136,7 +131,7 @@ def getSchemaByProject( return getProject1CRow(version) -def getProject1ARow(version: Version) -> Project1ARow: +def getProject1ARow(version: Version) -> Type[Project1ARow]: match version: case Version.V2024_05_24: return V2024_05_24_Project1ARow @@ -144,7 +139,7 @@ def getProject1ARow(version: Version) -> Project1ARow: return V2024_04_01_Project1ARow -def getProject1BRow(version: Version) -> Project1BRow: +def getProject1BRow(version: Version) -> Type[Project1BRow]: match version: case Version.V2024_05_24: return V2024_05_24_Project1BRow @@ -152,7 +147,7 @@ def getProject1BRow(version: Version) -> Project1BRow: return V2024_04_01_Project1BRow -def getProject1CRow(version: Version) -> Project1CRow: +def getProject1CRow(version: Version) -> Type[Project1CRow]: match version: case Version.V2024_05_24: return V2024_05_24_Project1CRow @@ -160,7 +155,7 @@ def getProject1CRow(version: Version) -> Project1CRow: return V2024_04_01_Project1CRow -def getSchemaMetadata(version_string: str) -> dict: +def getSchemaMetadata(version_string: str) -> dict[str, dict[str, Any]]: version = getVersionFromString(version_string) match version: case Version.V2024_05_24: diff --git a/python/tests/conftest.py b/python/tests/conftest.py index df57b3f1..5bf845b4 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,9 +1,10 @@ import zipfile -from typing import BinaryIO +from typing import IO, Any, BinaryIO, Dict import openpyxl import pytest from aws_lambda_typing.context import Context +from openpyxl.worksheet.worksheet import Worksheet from src.functions.subrecipient_treasury_report_gen import SubrecipientLambdaPayload from src.lib.output_template_comparator import CPFFileArchive @@ -48,12 +49,16 @@ def valid_workbook_old_compatible_schema() -> openpyxl.Workbook: @pytest.fixture -def valid_coversheet(valid_workbook) -> openpyxl.worksheet.worksheet.Worksheet: +def valid_coversheet( + valid_workbook: openpyxl.Workbook, +) -> Worksheet: return valid_workbook["Cover"] @pytest.fixture -def valid_project_sheet(valid_workbook) -> openpyxl.worksheet.worksheet.Worksheet: +def valid_project_sheet( + valid_workbook: openpyxl.Workbook, +) -> Worksheet: return valid_workbook["Project"] @@ -63,12 +68,12 @@ def valid_workbook_1A() -> openpyxl.Workbook: @pytest.fixture -def valid_project_sheet_1A() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_project_sheet_1A() -> Worksheet: return openpyxl.load_workbook(_SAMPLE_TEMPLATE_1A)["Project"] @pytest.fixture -def valid_workbook_1A_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_workbook_1A_with_conflict() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1A) valid_project_sheet_1A = workbook["Project"] valid_project_sheet_1A["C13"] = "updated project 1a test" @@ -76,7 +81,7 @@ def valid_workbook_1A_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: @pytest.fixture -def valid_second_workbook_1A_sheet() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_second_workbook_1A_sheet() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1A) valid_project_sheet_1A = workbook["Project"] valid_project_sheet_1A["C13"] = "test 2" @@ -92,12 +97,12 @@ def valid_workbook_1B() -> openpyxl.Workbook: @pytest.fixture -def valid_project_sheet_1B() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_project_sheet_1B() -> Worksheet: return openpyxl.load_workbook(_SAMPLE_TEMPLATE_1B)["Project"] @pytest.fixture -def valid_workbook_1B_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_workbook_1B_with_conflict() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1B) valid_project_sheet_1B = workbook["Project"] valid_project_sheet_1B["C13"] = "updated project 1B test" @@ -105,7 +110,7 @@ def valid_workbook_1B_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: @pytest.fixture -def valid_second_workbook_1B_sheet() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_second_workbook_1B_sheet() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1B) valid_project_sheet_1B = workbook["Project"] valid_project_sheet_1B["C13"] = "test 2" @@ -121,12 +126,12 @@ def valid_workbook_1C() -> openpyxl.Workbook: @pytest.fixture -def valid_project_sheet_1C() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_project_sheet_1C() -> Worksheet: return openpyxl.load_workbook(_SAMPLE_TEMPLATE_1C)["Project"] @pytest.fixture -def valid_workbook_1C_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_workbook_1C_with_conflict() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1C) valid_project_sheet_1C = workbook["Project"] valid_project_sheet_1C["C13"] = "updated project 1c test" @@ -134,7 +139,7 @@ def valid_workbook_1C_with_conflict() -> openpyxl.worksheet.worksheet.Worksheet: @pytest.fixture -def valid_second_workbook_1C_sheet() -> openpyxl.worksheet.worksheet.Worksheet: +def valid_second_workbook_1C_sheet() -> openpyxl.Workbook: workbook = openpyxl.load_workbook(_SAMPLE_TEMPLATE_1C) valid_project_sheet_1C = workbook["Project"] valid_project_sheet_1C["C13"] = "test 2" @@ -145,109 +150,131 @@ def valid_second_workbook_1C_sheet() -> openpyxl.worksheet.worksheet.Worksheet: @pytest.fixture -def valid_subrecipientsheet(valid_workbook) -> openpyxl.worksheet.worksheet.Worksheet: +def valid_subrecipientsheet( + valid_workbook: openpyxl.Workbook, +) -> Worksheet: return valid_workbook["Subrecipients"] @pytest.fixture -def invalid_cover_sheet(valid_coversheet): +def invalid_cover_sheet(valid_coversheet: Worksheet) -> Worksheet: valid_coversheet["A2"] = "INVALID" return valid_coversheet @pytest.fixture -def invalid_cover_sheet_missing_code(valid_coversheet): - valid_coversheet["A2"] = None +def invalid_cover_sheet_missing_code( + valid_coversheet: Worksheet, +) -> Worksheet: + valid_coversheet["A2"] = "" return valid_coversheet @pytest.fixture -def invalid_cover_sheet_empty_code(valid_coversheet): +def invalid_cover_sheet_empty_code( + valid_coversheet: Worksheet, +) -> Worksheet: valid_coversheet["A2"] = " " return valid_coversheet @pytest.fixture -def invalid_cover_sheet_empty_desc(valid_coversheet): +def invalid_cover_sheet_empty_desc( + valid_coversheet: Worksheet, +) -> Worksheet: valid_coversheet["B2"] = " " return valid_coversheet @pytest.fixture -def invalid_project_sheet(valid_project_sheet): +def invalid_project_sheet(valid_project_sheet: Worksheet) -> Worksheet: valid_project_sheet["D13"] = "X" * 21 return valid_project_sheet @pytest.fixture -def invalid_project_sheet_missing_field(valid_project_sheet): - valid_project_sheet["D13"] = None +def invalid_project_sheet_missing_field( + valid_project_sheet: Worksheet, +) -> Worksheet: + valid_project_sheet["D13"] = "" return valid_project_sheet @pytest.fixture -def invalid_project_sheet_empty_field(valid_project_sheet): +def invalid_project_sheet_empty_field( + valid_project_sheet: Worksheet, +) -> Worksheet: valid_project_sheet["D13"] = " " return valid_project_sheet @pytest.fixture -def invalid_project_sheet_unmatching_subrecipient_tin_field(valid_workbook): +def invalid_project_sheet_unmatching_subrecipient_tin_field( + valid_workbook: openpyxl.Workbook, +) -> openpyxl.Workbook: valid_workbook["Subrecipients"]["E13"] = "123123124" return valid_workbook @pytest.fixture -def invalid_project_sheet_unmatching_subrecipient_uei_field(valid_workbook): +def invalid_project_sheet_unmatching_subrecipient_uei_field( + valid_workbook: openpyxl.Workbook, +) -> openpyxl.Workbook: valid_workbook["Subrecipients"]["F13"] = "123412341235" return valid_workbook @pytest.fixture -def invalid_project_sheet_unmatching_subrecipient_tin_uei_field(valid_workbook): +def invalid_project_sheet_unmatching_subrecipient_tin_uei_field( + valid_workbook: openpyxl.Workbook, +) -> openpyxl.Workbook: valid_workbook["Subrecipients"]["E13"] = "123123124" valid_workbook["Subrecipients"]["F13"] = "123412341235" return valid_workbook @pytest.fixture -def invalid_subrecipient_sheet(valid_subrecipientsheet): +def invalid_subrecipient_sheet( + valid_subrecipientsheet: Worksheet, +) -> Worksheet: valid_subrecipientsheet["E13"] = "INVALID" return valid_subrecipientsheet @pytest.fixture -def valid_subrecipient_sheet_blank_optional_fields(valid_subrecipientsheet): - valid_subrecipientsheet["K13"] = None - valid_subrecipientsheet["M13"] = None - valid_subrecipientsheet["N13"] = None +def valid_subrecipient_sheet_blank_optional_fields( + valid_subrecipientsheet: Worksheet, +) -> Worksheet: + valid_subrecipientsheet["K13"] = "" + valid_subrecipientsheet["M13"] = "" + valid_subrecipientsheet["N13"] = "" return valid_subrecipientsheet @pytest.fixture -def sample_template(): +def sample_template() -> IO[bytes]: return open(_SAMPLE_TEMPLATE, "rb") @pytest.fixture -def template_workbook(): +def template_workbook() -> openpyxl.Workbook: return openpyxl.load_workbook(_SAMPLE_TEMPLATE) @pytest.fixture -def template_workbook_two(): +def template_workbook_two() -> openpyxl.Workbook: return openpyxl.load_workbook(_SAMPLE_TEMPLATE_2) @pytest.fixture -def cpf_file_archive(sample_template): +def cpf_file_archive(sample_template: IO[bytes]) -> CPFFileArchive: cpf_archive_file = zipfile.ZipFile("_tmp_test.zip", "w") cpf_archive_file.writestr("2024-05-19/TestFile.xlsx", sample_template.read()) return CPFFileArchive(cpf_archive_file) @pytest.fixture -def cpf_file_archive_two(sample_template): +def cpf_file_archive_two(sample_template: IO[bytes]) -> CPFFileArchive: cpf_archive_file = zipfile.ZipFile("_tmp_test.zip", "w") cpf_archive_file.writestr("2024-05-19/TestFile.xlsx", sample_template.read()) cpf_archive_file.writestr("2024-05-19/TestFile2.xlsx", sample_template.read()) @@ -270,22 +297,22 @@ def output_1C_template() -> openpyxl.Workbook: @pytest.fixture -def valid_aws_typing_context(): - valid_context = Context - valid_context._aws_request_id = "dummy_aws_request_id" - valid_context._log_group_name = "dummy_log_group_name" - valid_context._log_stream_name = "dummy_log_stream_name" - valid_context._function_name = "dummy_function_name" - valid_context._memory_limit_in_mb = "128" - valid_context._function_version = "$LATEST" - valid_context._invoked_function_arn = ( +def valid_aws_typing_context() -> Context: + valid_context = Context() + valid_context.aws_request_id = "dummy_aws_request_id" + valid_context.log_group_name = "dummy_log_group_name" + valid_context.log_stream_name = "dummy_log_stream_name" + valid_context.function_name = "dummy_function_name" + valid_context.memory_limit_in_mb = "128" + valid_context.function_version = "$LATEST" + valid_context.invoked_function_arn = ( "arn:aws:lambda:dummy-region:123456789012:function:dummy_function_name" ) return valid_context @pytest.fixture -def valid_subrecipients_json_content(): +def valid_subrecipients_json_content() -> Dict[str, Any]: return { "subrecipients": [ { @@ -319,7 +346,7 @@ def valid_subrecipients_json_content(): }, "createdAt": "2024-07-11T03:15:13.054Z", "updatedAt": "2024-07-15T20:06:44.379Z", - "version": "V2024_05_24", + "version": "v:20240524", } ], } @@ -328,7 +355,7 @@ def valid_subrecipients_json_content(): @pytest.fixture -def sample_subrecipients_generation_event(): +def sample_subrecipients_generation_event() -> Dict[str, Any]: return { "organization": { "id": 12, @@ -340,7 +367,7 @@ def sample_subrecipients_generation_event(): @pytest.fixture -def sample_subrecipients_lambda_payload(): +def sample_subrecipients_lambda_payload() -> SubrecipientLambdaPayload: return SubrecipientLambdaPayload( organization=OrganizationObj( id=99, @@ -352,37 +379,37 @@ def sample_subrecipients_lambda_payload(): @pytest.fixture -def invalid_json_content(): +def invalid_json_content() -> str: return '{"subrecipients": [{"id": 1, "name": "subrecipient 3" invalid json' @pytest.fixture -def no_subrecipients_key_json_content(): +def no_subrecipients_key_json_content() -> Dict[str, Any]: return {"other_key": []} @pytest.fixture -def no_subrecipients_list_json_content(): +def no_subrecipients_list_json_content() -> Dict[str, Any]: return {"subrecipients": "not_a_list"} @pytest.fixture -def empty_subrecipients_list_json_content(): +def empty_subrecipients_list_json_content() -> Dict[str, Any]: return {"subrecipients": []} @pytest.fixture -def subrecipients_no_uploads(): +def subrecipients_no_uploads() -> dict[str, list[dict[str, Any]]]: return {"subrecipients": [{"id": 1, "Name": "Bob"}]} @pytest.fixture -def empty_subrecipient_treasury_template(): +def empty_subrecipient_treasury_template() -> openpyxl.Workbook: return openpyxl.load_workbook(_SAMPLE_SUBRECIPIENT_TEMPLATE_EMPTY) @pytest.fixture -def sample_subrecipient_uploads_with_dates(): +def sample_subrecipient_uploads_with_dates() -> Dict[str, Any]: return { "subrecipientUploads": [ {"id": 1, "updatedAt": "2023-07-01T12:00:00Z"}, diff --git a/python/tests/src/lib/test_create_archive.py b/python/tests/src/lib/test_create_archive.py index 1bf128e1..0f0d2e3e 100644 --- a/python/tests/src/lib/test_create_archive.py +++ b/python/tests/src/lib/test_create_archive.py @@ -6,7 +6,7 @@ from src.functions.create_archive import CreateArchiveLambdaPayload, create_archive -def test_create_archive_creates_zip(): +def test_create_archive_creates_zip() -> None: org_id = 1234 reporting_period_id = 5678 s3_client = MagicMock() @@ -37,7 +37,7 @@ def test_create_archive_creates_zip(): ) -def test_create_archive_lambda_payload(): +def test_create_archive_lambda_payload() -> None: organizationObj = { "organization": { "id": 1234, @@ -52,7 +52,7 @@ def test_create_archive_lambda_payload(): ) -def test_create_archive_lambda_payload_failed(): +def test_create_archive_lambda_payload_failed() -> None: organizationObj = { # "id": "1234", Missing a required field "preferences": {"current_reporting_period_id": 5678}, diff --git a/python/tests/src/lib/test_output_template_comparator.py b/python/tests/src/lib/test_output_template_comparator.py index 839cca5f..2eb85a41 100644 --- a/python/tests/src/lib/test_output_template_comparator.py +++ b/python/tests/src/lib/test_output_template_comparator.py @@ -8,7 +8,7 @@ ) -def test_compare_workbooks_same_values(cpf_file_archive: CPFFileArchive): +def test_compare_workbooks_same_values(cpf_file_archive: CPFFileArchive) -> None: latest_archive = cpf_file_archive previous_archive = cpf_file_archive common_files, new_files, removed_files = compare_workbooks( @@ -21,7 +21,7 @@ def test_compare_workbooks_same_values(cpf_file_archive: CPFFileArchive): def test_compare_workbooks_new_file_values( cpf_file_archive: CPFFileArchive, cpf_file_archive_two: CPFFileArchive -): +) -> None: common_files, new_files, removed_files = compare_workbooks( cpf_file_archive_two, cpf_file_archive ) @@ -30,7 +30,7 @@ def test_compare_workbooks_new_file_values( assert len(removed_files) == 0 -def test_compare_sheets_same_values(template_workbook: Workbook): +def test_compare_sheets_same_values(template_workbook: Workbook) -> None: valid_project_sheet = template_workbook["Baseline"] added_columns, removed_columns, header_map = compare_sheet_columns( valid_project_sheet, valid_project_sheet @@ -42,7 +42,7 @@ def test_compare_sheets_same_values(template_workbook: Workbook): def test_compare_sheets_different_values( template_workbook: Workbook, template_workbook_two: Workbook -): +) -> None: valid_project_sheet_1 = template_workbook["Baseline"] valid_project_sheet_2 = template_workbook_two["Baseline"] added_columns, removed_columns, header_map = compare_sheet_columns( @@ -55,7 +55,7 @@ def test_compare_sheets_different_values( def test_compare_cells_same_values( template_workbook: Workbook, template_workbook_two: Workbook -): +) -> None: valid_project_sheet_1 = template_workbook["Baseline"] valid_project_sheet_2 = template_workbook_two["Baseline"] _, _, header_map = compare_sheet_columns( @@ -67,7 +67,7 @@ def test_compare_cells_same_values( assert len(cell_value_differences) == 4 -def test_compare_cells_diff_values(template_workbook: Workbook): +def test_compare_cells_diff_values(template_workbook: Workbook) -> None: valid_project_sheet = template_workbook["Baseline"] _, _, header_map = compare_sheet_columns(valid_project_sheet, valid_project_sheet) cell_value_differences = compare_cell_values( diff --git a/python/tests/src/lib/test_subrecipient_treasury_report_gen.py b/python/tests/src/lib/test_subrecipient_treasury_report_gen.py index 805ce406..c83c33fe 100644 --- a/python/tests/src/lib/test_subrecipient_treasury_report_gen.py +++ b/python/tests/src/lib/test_subrecipient_treasury_report_gen.py @@ -1,6 +1,7 @@ import json from tempfile import NamedTemporaryFile -from unittest.mock import ANY, MagicMock, call, patch +from typing import Any, Dict, List +from unittest.mock import ANY, MagicMock, Mock, call, patch import pytest from aws_lambda_typing.context import Context @@ -9,6 +10,7 @@ from src.functions.subrecipient_treasury_report_gen import ( FIRST_BLANK_ROW_NUM, WORKSHEET_NAME, + SubrecipientLambdaPayload, get_most_recent_upload, handle, process_event, @@ -21,8 +23,8 @@ class TestHandleNoEventOrNoContext: @patch("src.functions.subrecipient_treasury_report_gen.get_logger") def test_handle_no_event_provided( - self, mock_get_logger, valid_aws_typing_context: Context - ): + self, mock_get_logger: Mock, valid_aws_typing_context: Context + ) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger @@ -30,10 +32,10 @@ def test_handle_no_event_provided( mock_logger.exception.assert_called_once_with("Missing event or context") @patch("src.functions.subrecipient_treasury_report_gen.get_logger") - def test_handle_no_context_provided(self, mock_get_logger): + def test_handle_no_context_provided(self, mock_get_logger: Mock) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger - event = {} + event: dict[str, str] = {} context = None handle(event, context) @@ -44,8 +46,8 @@ def test_handle_no_context_provided(self, mock_get_logger): class TestHandleIncompleteEventInput: @patch("src.functions.subrecipient_treasury_report_gen.get_logger") def test_handle_no_organization_id( - self, mock_get_logger, valid_aws_typing_context: Context - ): + self, mock_get_logger: Mock, valid_aws_typing_context: Context + ) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger @@ -66,8 +68,8 @@ def test_handle_no_organization_id( @patch("src.functions.subrecipient_treasury_report_gen.get_logger") def test_handle_no_preferences( - self, mock_get_logger, valid_aws_typing_context: Context - ): + self, mock_get_logger: Mock, valid_aws_typing_context: Context + ) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger @@ -79,8 +81,8 @@ def test_handle_no_preferences( @patch("src.functions.subrecipient_treasury_report_gen.get_logger") def test_handle_no_current_reporting_period_id( - self, mock_get_logger, valid_aws_typing_context: Context - ): + self, mock_get_logger: Mock, valid_aws_typing_context: Context + ) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger @@ -92,8 +94,8 @@ def test_handle_no_current_reporting_period_id( @patch("src.functions.subrecipient_treasury_report_gen.get_logger") def test_handle_no_output_template_id( - self, mock_get_logger, valid_aws_typing_context: Context - ): + self, mock_get_logger: Mock, valid_aws_typing_context: Context + ) -> None: mock_logger = MagicMock() mock_get_logger.return_value = mock_logger @@ -117,12 +119,12 @@ class TestInvalidSubrecipientsFile: @patch("src.functions.subrecipient_treasury_report_gen.tempfile.NamedTemporaryFile") def test_invalid_subrecipients_file( self, - mock_tempfile, - mock_boto_client, - invalid_json_content, - sample_subrecipients_lambda_payload, - monkeypatch, - ): + mock_tempfile: Mock, + mock_boto_client: Mock, + invalid_json_content: str, + sample_subrecipients_lambda_payload: SubrecipientLambdaPayload, + monkeypatch: pytest.MonkeyPatch, + ) -> None: monkeypatch.setenv("REPORTING_DATA_BUCKET_NAME", "test-cpf-reporter") mock_s3_client = MagicMock() mock_boto_client.return_value = mock_s3_client @@ -143,12 +145,12 @@ def test_invalid_subrecipients_file( @patch("src.functions.subrecipient_treasury_report_gen.tempfile.NamedTemporaryFile") def test_no_subrecipients_key( self, - mock_tempfile, - mock_boto_client, - no_subrecipients_key_json_content, - sample_subrecipients_lambda_payload, - monkeypatch, - ): + mock_tempfile: Mock, + mock_boto_client: Mock, + no_subrecipients_key_json_content: str, + sample_subrecipients_lambda_payload: SubrecipientLambdaPayload, + monkeypatch: pytest.MonkeyPatch, + ) -> None: monkeypatch.setenv("REPORTING_DATA_BUCKET_NAME", "test-cpf-reporter") mock_s3_client = MagicMock() mock_boto_client.return_value = mock_s3_client @@ -169,12 +171,12 @@ def test_no_subrecipients_key( @patch("src.functions.subrecipient_treasury_report_gen.tempfile.NamedTemporaryFile") def test_no_subrecipients_list( self, - mock_tempfile, - mock_boto_client, - no_subrecipients_list_json_content, - sample_subrecipients_lambda_payload, - monkeypatch, - ): + mock_tempfile: Mock, + mock_boto_client: Mock, + no_subrecipients_list_json_content: str, + sample_subrecipients_lambda_payload: SubrecipientLambdaPayload, + monkeypatch: pytest.MonkeyPatch, + ) -> None: monkeypatch.setenv("REPORTING_DATA_BUCKET_NAME", "test-cpf-reporter") mock_s3_client = MagicMock() mock_boto_client.return_value = mock_s3_client @@ -195,12 +197,12 @@ def test_no_subrecipients_list( @patch("src.functions.subrecipient_treasury_report_gen.tempfile.NamedTemporaryFile") def test_empty_subrecipients_list( self, - mock_tempfile, - mock_boto_client, - empty_subrecipients_list_json_content, - sample_subrecipients_lambda_payload, - monkeypatch, - ): + mock_tempfile: Mock, + mock_boto_client: Mock, + empty_subrecipients_list_json_content: str, + sample_subrecipients_lambda_payload: SubrecipientLambdaPayload, + monkeypatch: pytest.MonkeyPatch, + ) -> None: monkeypatch.setenv("REPORTING_DATA_BUCKET_NAME", "test-cpf-reporter") mock_s3_client = MagicMock() mock_boto_client.return_value = mock_s3_client @@ -220,8 +222,10 @@ def test_empty_subrecipients_list( class TestWriteSubrecipientsToWorkbook: def test_write_subrecipients_to_workbook_no_uploads( - self, subrecipients_no_uploads, empty_subrecipient_treasury_template - ): + self, + subrecipients_no_uploads: dict[str, list[dict[str, Any]]], + empty_subrecipient_treasury_template: Workbook, + ) -> None: mock_logger = MagicMock() write_subrecipients_to_workbook( @@ -233,8 +237,10 @@ def test_write_subrecipients_to_workbook_no_uploads( ) def test_write_subrecipients_to_workbook_empty_output_file_valid_subrecipients( - self, valid_subrecipients_json_content, empty_subrecipient_treasury_template - ): + self, + valid_subrecipients_json_content: dict[str, list[dict[str, Any]]], + empty_subrecipient_treasury_template: Workbook, + ) -> None: mock_logger = MagicMock() write_subrecipients_to_workbook( @@ -275,8 +281,11 @@ class TestUploadSubrecipientWorkbook: @patch("src.functions.subrecipient_treasury_report_gen.upload_generated_file_to_s3") @patch("src.functions.subrecipient_treasury_report_gen.convert_xlsx_to_csv") def test_upload_workbook( - self, mock_convert_xlsx_to_csv, mock_upload_generated_file_to_s3, monkeypatch - ): + self, + mock_convert_xlsx_to_csv: Mock, + mock_upload_generated_file_to_s3: Mock, + monkeypatch: pytest.MonkeyPatch, + ) -> None: bucket_name = "test-cpf-reporter" monkeypatch.setenv("REPORTING_DATA_BUCKET_NAME", bucket_name) mock_s3_client = MagicMock() @@ -306,18 +315,20 @@ def test_upload_workbook( class TestGetMostRecentUpload: - def test_get_most_recent_upload(self, sample_subrecipient_uploads_with_dates): + def test_get_most_recent_upload( + self, sample_subrecipient_uploads_with_dates: dict[str, list[dict[str, Any]]] + ) -> None: result = get_most_recent_upload(sample_subrecipient_uploads_with_dates) assert result["id"] == 2 - def test_get_most_recent_upload_with_single_entry(self): + def test_get_most_recent_upload_with_single_entry(self) -> None: single_entry_subrecipient = { "subrecipientUploads": [{"id": 1, "updatedAt": "2023-07-01T12:00:00Z"}] } result = get_most_recent_upload(single_entry_subrecipient) assert result["id"] == 1 - def test_get_most_recent_upload_with_empty_list(self): - empty_subrecipient = {"subrecipientUploads": []} + def test_get_most_recent_upload_with_empty_list(self) -> None: + empty_subrecipient: Dict[str, List[str]] = {"subrecipientUploads": []} with pytest.raises(IndexError): get_most_recent_upload(empty_subrecipient) diff --git a/python/tests/src/lib/test_treasury_report_1A.py b/python/tests/src/lib/test_treasury_report_1A.py index bf25c678..31c3d089 100644 --- a/python/tests/src/lib/test_treasury_report_1A.py +++ b/python/tests/src/lib/test_treasury_report_1A.py @@ -1,13 +1,15 @@ from datetime import datetime, timedelta -from src.schemas.schema_versions import getSchemaByProject +from openpyxl import Workbook +from openpyxl.worksheet.worksheet import Worksheet + from src.functions.generate_treasury_report import ( combine_project_rows, - update_project_agency_ids_to_row_map, get_projects_to_remove, + update_project_agency_ids_to_row_map, ) -from src.schemas.schema_versions import Version from src.schemas.project_types import ProjectType +from src.schemas.schema_versions import Version, getSchemaByProject OUTPUT_STARTING_ROW = 8 project_use_code = ProjectType._1A @@ -17,13 +19,13 @@ FIRST_ID = "123" SECOND_ID = "44" V2024_05_24_VERSION_STRING = Version.V2024_05_24.value -AGENCY_ID = 999 +AGENCY_ID = "999" PROJECT_AGENCY_ID = f"{FIRST_ID}_{AGENCY_ID}" PROJECT_AGENCY_ID_2 = f"{SECOND_ID}_{AGENCY_ID}" class TestGenerateOutput1A: - def test_get_projects_to_remove(self, valid_workbook_1A): + def test_get_projects_to_remove(self, valid_workbook_1A: Workbook) -> None: project_agency_ids_to_remove = get_projects_to_remove( workbook=valid_workbook_1A, ProjectRowSchema=ProjectRowSchema, @@ -32,7 +34,7 @@ def test_get_projects_to_remove(self, valid_workbook_1A): assert len(project_agency_ids_to_remove) == 1 assert PROJECT_AGENCY_ID in project_agency_ids_to_remove - def test_update_project_agency_ids_to_row_map_1(self): + def test_update_project_agency_ids_to_row_map_1(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -45,7 +47,7 @@ def test_update_project_agency_ids_to_row_map_1(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -55,7 +57,7 @@ def test_update_project_agency_ids_to_row_map_1(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_2(self): + def test_update_project_agency_ids_to_row_map_2(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -68,7 +70,7 @@ def test_update_project_agency_ids_to_row_map_2(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 5 @@ -78,7 +80,7 @@ def test_update_project_agency_ids_to_row_map_2(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 17 - def test_update_project_agency_ids_to_row_map_3(self): + def test_update_project_agency_ids_to_row_map_3(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -91,7 +93,7 @@ def test_update_project_agency_ids_to_row_map_3(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -101,7 +103,7 @@ def test_update_project_agency_ids_to_row_map_3(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_4(self): + def test_update_project_agency_ids_to_row_map_4(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -114,7 +116,7 @@ def test_update_project_agency_ids_to_row_map_4(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -124,7 +126,7 @@ def test_update_project_agency_ids_to_row_map_4(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") is None - def test_update_project_agency_ids_to_row_map_5(self): + def test_update_project_agency_ids_to_row_map_5(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -137,7 +139,7 @@ def test_update_project_agency_ids_to_row_map_5(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -147,7 +149,7 @@ def test_update_project_agency_ids_to_row_map_5(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") is None assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_update_project_agency_ids_to_row_map_6(self): + def test_update_project_agency_ids_to_row_map_6(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -160,7 +162,7 @@ def test_update_project_agency_ids_to_row_map_6(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -170,13 +172,13 @@ def test_update_project_agency_ids_to_row_map_6(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 14 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_combine_project_rows(self, valid_workbook_1A): - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + def test_combine_project_rows(self, valid_workbook_1A: Workbook) -> None: + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -190,15 +192,15 @@ def test_combine_project_rows(self, valid_workbook_1A): assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_1( - self, valid_workbook_1A, valid_workbook_1A_with_conflict - ): + self, valid_workbook_1A: Workbook, valid_workbook_1A_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -211,7 +213,7 @@ def test_combine_project_rows_with_conflicts_1( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -225,15 +227,15 @@ def test_combine_project_rows_with_conflicts_1( assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_2( - self, valid_workbook_1A, valid_workbook_1A_with_conflict - ): + self, valid_workbook_1A: Workbook, valid_workbook_1A_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -246,7 +248,7 @@ def test_combine_project_rows_with_conflicts_2( createdAt2 = datetime.now() - timedelta(days=1) new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -260,15 +262,15 @@ def test_combine_project_rows_with_conflicts_2( assert new_highest_row_num == 13 def test_combine_project_rows_multiple( - self, valid_workbook_1A, valid_second_workbook_1A_sheet - ): + self, valid_workbook_1A: Workbook, valid_second_workbook_1A_sheet: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1A, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -281,7 +283,7 @@ def test_combine_project_rows_multiple( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_second_workbook_1A_sheet, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, diff --git a/python/tests/src/lib/test_treasury_report_1B.py b/python/tests/src/lib/test_treasury_report_1B.py index f0beed6e..b05646f4 100644 --- a/python/tests/src/lib/test_treasury_report_1B.py +++ b/python/tests/src/lib/test_treasury_report_1B.py @@ -1,13 +1,15 @@ from datetime import datetime, timedelta -from src.schemas.schema_versions import getSchemaByProject +from openpyxl import Workbook +from openpyxl.worksheet.worksheet import Worksheet + from src.functions.generate_treasury_report import ( combine_project_rows, - update_project_agency_ids_to_row_map, get_projects_to_remove, + update_project_agency_ids_to_row_map, ) -from src.schemas.schema_versions import Version from src.schemas.project_types import ProjectType +from src.schemas.schema_versions import Version, getSchemaByProject OUTPUT_STARTING_ROW = 8 project_use_code = ProjectType._1B @@ -17,13 +19,13 @@ FIRST_ID = "222" SECOND_ID = "44" V2024_05_24_VERSION_STRING = Version.V2024_05_24.value -AGENCY_ID = 999 +AGENCY_ID = "999" PROJECT_AGENCY_ID = f"{FIRST_ID}_{AGENCY_ID}" PROJECT_AGENCY_ID_2 = f"{SECOND_ID}_{AGENCY_ID}" class TestGenerateOutput1B: - def test_get_projects_to_remove(self, valid_workbook_1B): + def test_get_projects_to_remove(self, valid_workbook_1B: Workbook) -> None: project_agency_ids_to_remove = get_projects_to_remove( workbook=valid_workbook_1B, ProjectRowSchema=ProjectRowSchema, @@ -32,7 +34,7 @@ def test_get_projects_to_remove(self, valid_workbook_1B): assert len(project_agency_ids_to_remove) == 1 assert PROJECT_AGENCY_ID in project_agency_ids_to_remove - def test_update_project_agency_ids_to_row_map_1(self): + def test_update_project_agency_ids_to_row_map_1(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -45,7 +47,7 @@ def test_update_project_agency_ids_to_row_map_1(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -55,7 +57,7 @@ def test_update_project_agency_ids_to_row_map_1(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_2(self): + def test_update_project_agency_ids_to_row_map_2(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -68,7 +70,7 @@ def test_update_project_agency_ids_to_row_map_2(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 5 @@ -78,7 +80,7 @@ def test_update_project_agency_ids_to_row_map_2(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 17 - def test_update_project_agency_ids_to_row_map_3(self): + def test_update_project_agency_ids_to_row_map_3(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -91,7 +93,7 @@ def test_update_project_agency_ids_to_row_map_3(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -101,7 +103,7 @@ def test_update_project_agency_ids_to_row_map_3(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_4(self): + def test_update_project_agency_ids_to_row_map_4(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -114,7 +116,7 @@ def test_update_project_agency_ids_to_row_map_4(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -124,7 +126,7 @@ def test_update_project_agency_ids_to_row_map_4(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") is None - def test_update_project_agency_ids_to_row_map_5(self): + def test_update_project_agency_ids_to_row_map_5(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -137,7 +139,7 @@ def test_update_project_agency_ids_to_row_map_5(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -147,7 +149,7 @@ def test_update_project_agency_ids_to_row_map_5(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") is None assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_update_project_agency_ids_to_row_map_6(self): + def test_update_project_agency_ids_to_row_map_6(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -160,7 +162,7 @@ def test_update_project_agency_ids_to_row_map_6(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -170,13 +172,13 @@ def test_update_project_agency_ids_to_row_map_6(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 14 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_combine_project_rows(self, valid_workbook_1B): - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + def test_combine_project_rows(self, valid_workbook_1B: Workbook) -> None: + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -190,15 +192,15 @@ def test_combine_project_rows(self, valid_workbook_1B): assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_1( - self, valid_workbook_1B, valid_workbook_1B_with_conflict - ): + self, valid_workbook_1B: Workbook, valid_workbook_1B_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -211,7 +213,7 @@ def test_combine_project_rows_with_conflicts_1( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -225,15 +227,15 @@ def test_combine_project_rows_with_conflicts_1( assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_2( - self, valid_workbook_1B, valid_workbook_1B_with_conflict - ): + self, valid_workbook_1B: Workbook, valid_workbook_1B_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -246,7 +248,7 @@ def test_combine_project_rows_with_conflicts_2( createdAt2 = datetime.now() - timedelta(days=1) new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -260,15 +262,15 @@ def test_combine_project_rows_with_conflicts_2( assert new_highest_row_num == 13 def test_combine_project_rows_multiple( - self, valid_workbook_1B, valid_second_workbook_1B_sheet - ): + self, valid_workbook_1B: Workbook, valid_second_workbook_1B_sheet: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1B, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -281,7 +283,7 @@ def test_combine_project_rows_multiple( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_second_workbook_1B_sheet, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, diff --git a/python/tests/src/lib/test_treasury_report_1C.py b/python/tests/src/lib/test_treasury_report_1C.py index d781f9c1..ed6c4199 100644 --- a/python/tests/src/lib/test_treasury_report_1C.py +++ b/python/tests/src/lib/test_treasury_report_1C.py @@ -1,13 +1,15 @@ from datetime import datetime, timedelta -from src.schemas.schema_versions import getSchemaByProject +from openpyxl import Workbook +from openpyxl.worksheet.worksheet import Worksheet + from src.functions.generate_treasury_report import ( combine_project_rows, - update_project_agency_ids_to_row_map, get_projects_to_remove, + update_project_agency_ids_to_row_map, ) -from src.schemas.schema_versions import Version from src.schemas.project_types import ProjectType +from src.schemas.schema_versions import Version, getSchemaByProject OUTPUT_STARTING_ROW = 8 project_use_code = ProjectType._1C @@ -17,13 +19,13 @@ FIRST_ID = "33" SECOND_ID = "44" V2024_05_24_VERSION_STRING = Version.V2024_05_24.value -AGENCY_ID = 999 +AGENCY_ID = "999" PROJECT_AGENCY_ID = f"{FIRST_ID}_{AGENCY_ID}" PROJECT_AGENCY_ID_2 = f"{SECOND_ID}_{AGENCY_ID}" class TestGenerateOutput1C: - def test_get_projects_to_remove(self, valid_workbook_1C): + def test_get_projects_to_remove(self, valid_workbook_1C: Workbook) -> None: project_agency_ids_to_remove = get_projects_to_remove( workbook=valid_workbook_1C, ProjectRowSchema=ProjectRowSchema, @@ -32,7 +34,7 @@ def test_get_projects_to_remove(self, valid_workbook_1C): assert len(project_agency_ids_to_remove) == 1 assert PROJECT_AGENCY_ID in project_agency_ids_to_remove - def test_update_project_agency_ids_to_row_map_1(self): + def test_update_project_agency_ids_to_row_map_1(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -45,7 +47,7 @@ def test_update_project_agency_ids_to_row_map_1(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -55,7 +57,7 @@ def test_update_project_agency_ids_to_row_map_1(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_2(self): + def test_update_project_agency_ids_to_row_map_2(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -68,7 +70,7 @@ def test_update_project_agency_ids_to_row_map_2(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 5 @@ -78,7 +80,7 @@ def test_update_project_agency_ids_to_row_map_2(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 17 - def test_update_project_agency_ids_to_row_map_3(self): + def test_update_project_agency_ids_to_row_map_3(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -91,7 +93,7 @@ def test_update_project_agency_ids_to_row_map_3(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -101,7 +103,7 @@ def test_update_project_agency_ids_to_row_map_3(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 15 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 16 - def test_update_project_agency_ids_to_row_map_4(self): + def test_update_project_agency_ids_to_row_map_4(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -114,7 +116,7 @@ def test_update_project_agency_ids_to_row_map_4(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 4 @@ -124,7 +126,7 @@ def test_update_project_agency_ids_to_row_map_4(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 16 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") is None - def test_update_project_agency_ids_to_row_map_5(self): + def test_update_project_agency_ids_to_row_map_5(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -137,7 +139,7 @@ def test_update_project_agency_ids_to_row_map_5(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -147,7 +149,7 @@ def test_update_project_agency_ids_to_row_map_5(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") is None assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_update_project_agency_ids_to_row_map_6(self): + def test_update_project_agency_ids_to_row_map_6(self) -> None: project_agency_id_to_row_map = { f"1_{AGENCY_ID}": 13, f"2_{AGENCY_ID}": 14, @@ -160,7 +162,7 @@ def test_update_project_agency_ids_to_row_map_6(self): update_project_agency_ids_to_row_map( project_agency_id_to_row_map=project_agency_id_to_row_map, project_agency_ids_to_remove=project_agency_ids_to_remove, - sheet=None, + sheet=Worksheet(None), highest_row_num=highest_row_num, ) assert len(project_agency_id_to_row_map.keys()) == 3 @@ -170,13 +172,13 @@ def test_update_project_agency_ids_to_row_map_6(self): assert project_agency_id_to_row_map.get(f"4_{AGENCY_ID}") == 14 assert project_agency_id_to_row_map.get(f"5_{AGENCY_ID}") == 15 - def test_combine_project_rows(self, valid_workbook_1C): - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + def test_combine_project_rows(self, valid_workbook_1C: Workbook) -> None: + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -190,15 +192,15 @@ def test_combine_project_rows(self, valid_workbook_1C): assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_1( - self, valid_workbook_1C, valid_workbook_1C_with_conflict - ): + self, valid_workbook_1C: Workbook, valid_workbook_1C_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -211,7 +213,7 @@ def test_combine_project_rows_with_conflicts_1( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -225,15 +227,15 @@ def test_combine_project_rows_with_conflicts_1( assert new_highest_row_num == 13 def test_combine_project_rows_with_conflicts_2( - self, valid_workbook_1C, valid_workbook_1C_with_conflict - ): + self, valid_workbook_1C: Workbook, valid_workbook_1C_with_conflict: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -246,7 +248,7 @@ def test_combine_project_rows_with_conflicts_2( createdAt2 = datetime.now() - timedelta(days=1) new_highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C_with_conflict, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, @@ -260,15 +262,15 @@ def test_combine_project_rows_with_conflicts_2( assert new_highest_row_num == 13 def test_combine_project_rows_multiple( - self, valid_workbook_1C, valid_second_workbook_1C_sheet - ): + self, valid_workbook_1C: Workbook, valid_second_workbook_1C_sheet: Workbook + ) -> None: """Choose the project with the later created at date""" - project_id_agency_id_to_upload_date = {} - project_id_agency_id_to_row_num = {} + project_id_agency_id_to_upload_date: dict[str, datetime] = {} + project_id_agency_id_to_row_num: dict[str, int] = {} createdAt1 = datetime.now() highest_row_num = combine_project_rows( project_workbook=valid_workbook_1C, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=12, ProjectRowSchema=ProjectRowSchema, @@ -281,7 +283,7 @@ def test_combine_project_rows_multiple( createdAt2 = datetime.now() new_highest_row_num = combine_project_rows( project_workbook=valid_second_workbook_1C_sheet, - output_sheet=None, + output_sheet=Worksheet(None), project_use_code=project_use_code, highest_row_num=highest_row_num, ProjectRowSchema=ProjectRowSchema, diff --git a/python/tests/src/lib/test_workbook_validator.py b/python/tests/src/lib/test_workbook_validator.py index 9a81f873..8ad084c4 100644 --- a/python/tests/src/lib/test_workbook_validator.py +++ b/python/tests/src/lib/test_workbook_validator.py @@ -1,4 +1,5 @@ -from typing import BinaryIO +import typing +from typing import Any, BinaryIO, Optional, Tuple import pytest from openpyxl import Workbook @@ -16,16 +17,21 @@ validate_workbook, validate_workbook_sheets, ) -from src.schemas.schema_versions import getSchemaByProject, Version +from src.schemas.project_types import ProjectType +from src.schemas.schema_versions import ( + V2024_05_24_ProjectRows, + Version, + getSchemaByProject, +) -SAMPLE_EXPENDITURE_CATEGORY_GROUP = "1A" -V2024_05_24_VERSION_STRING = Version.V2024_05_24.value +SAMPLE_EXPENDITURE_CATEGORY_GROUP = ProjectType._1A +V2024_05_24_VERSION = Version.V2024_05_24 class TestValidateWorkbookWithOldSchema: def test_valid_workbook_old_compatible_schema( self, valid_workbook_old_compatible_schema: Workbook - ): + ) -> None: errors, version_string = validate_logic_sheet( valid_workbook_old_compatible_schema["Logic"] ) @@ -43,7 +49,7 @@ class TestIsEmptyRow: ((None if i % 2 == 0 else "" for i in range(11)),), ], ) - def test_empty_row_data(self, row_data): + def test_empty_row_data(self, row_data: Tuple[Any]) -> None: assert is_empty_row(row_data) is True @pytest.mark.parametrize( @@ -53,19 +59,19 @@ def test_empty_row_data(self, row_data): ((i if i % 2 == 0 else str(i) if i % 3 == 0 else None for i in range(20)),), ], ) - def test_non_empty_row_data(self, row_data): + def test_non_empty_row_data(self, row_data: Tuple[Optional[Any], ...]) -> None: assert is_empty_row(row_data) is False class TestValidateWorkbook: - def test_valid_full_workbook(self, valid_file: BinaryIO): + def test_valid_full_workbook(self, valid_file: BinaryIO) -> None: result = validate(valid_file) errors = result[0] project_use_code = result[1] assert errors == [] assert project_use_code == SAMPLE_EXPENDITURE_CATEGORY_GROUP - def test_multiple_invalid_sheets(self, valid_workbook: Workbook): + def test_multiple_invalid_sheets(self, valid_workbook: Workbook) -> None: """ Tests that an error in the first sheet doesn't prevent the second sheet from being validated. @@ -85,11 +91,11 @@ def test_multiple_invalid_sheets(self, valid_workbook: Workbook): class TestWorkbookSheets: - def test_valid_set_of_sheets(self, valid_workbook: Workbook): + def test_valid_set_of_sheets(self, valid_workbook: Workbook) -> None: errors = validate_workbook_sheets(valid_workbook) assert errors == [] - def test_missing_sheets(self, valid_workbook: Workbook): + def test_missing_sheets(self, valid_workbook: Workbook) -> None: valid_workbook.remove(valid_workbook["Cover"]) errors = validate_workbook_sheets(valid_workbook) assert errors != [] @@ -100,21 +106,21 @@ def test_missing_sheets(self, valid_workbook: Workbook): class TestValidateCoverSheet: - def test_valid_cover_sheet(self, valid_coversheet: Worksheet): + def test_valid_cover_sheet(self, valid_coversheet: Worksheet) -> None: errors, schema, project_use_code = validate_cover_sheet( - valid_coversheet, V2024_05_24_VERSION_STRING + valid_coversheet, V2024_05_24_VERSION.value ) assert errors == [] assert project_use_code == SAMPLE_EXPENDITURE_CATEGORY_GROUP assert schema == getSchemaByProject( - V2024_05_24_VERSION_STRING, SAMPLE_EXPENDITURE_CATEGORY_GROUP + V2024_05_24_VERSION, SAMPLE_EXPENDITURE_CATEGORY_GROUP ) def test_invalid_cover_sheet_missing_code( self, invalid_cover_sheet_missing_code: Worksheet - ): + ) -> None: errors, schema, project_use_code = validate_cover_sheet( - invalid_cover_sheet_missing_code, V2024_05_24_VERSION_STRING + invalid_cover_sheet_missing_code, V2024_05_24_VERSION.value ) assert errors != [] error = errors[0] @@ -128,9 +134,9 @@ def test_invalid_cover_sheet_missing_code( def test_invalid_cover_sheet_empty_code( self, invalid_cover_sheet_empty_code: Worksheet - ): + ) -> None: errors, schema, project_use_code = validate_cover_sheet( - invalid_cover_sheet_empty_code, V2024_05_24_VERSION_STRING + invalid_cover_sheet_empty_code, V2024_05_24_VERSION.value ) assert errors != [] error = errors[0] @@ -144,9 +150,9 @@ def test_invalid_cover_sheet_empty_code( def test_invalid_cover_sheet_empty_desc( self, invalid_cover_sheet_empty_desc: Worksheet - ): + ) -> None: errors, schema, project_use_code = validate_cover_sheet( - invalid_cover_sheet_empty_desc, V2024_05_24_VERSION_STRING + invalid_cover_sheet_empty_desc, V2024_05_24_VERSION.value ) assert errors != [] error = errors[0] @@ -160,27 +166,26 @@ def test_invalid_cover_sheet_empty_desc( class TestValidateproject_sheet: - def test_valid_project_sheet(self, valid_project_sheet: Worksheet): + def test_valid_project_sheet(self, valid_project_sheet: Worksheet) -> None: errors, projects = validate_project_sheet( valid_project_sheet, - getSchemaByProject( - V2024_05_24_VERSION_STRING, SAMPLE_EXPENDITURE_CATEGORY_GROUP - ), - V2024_05_24_VERSION_STRING, + getSchemaByProject(V2024_05_24_VERSION, SAMPLE_EXPENDITURE_CATEGORY_GROUP), + V2024_05_24_VERSION.value, ) assert errors == [] assert len(projects) == 1 assert projects[0].row_num == 13 - assert projects[0].Subrecipient_UEI__c == "123412341234" - assert projects[0].Subrecipient_TIN__c == "123123123" + project: V2024_05_24_ProjectRows = typing.cast( + V2024_05_24_ProjectRows, projects[0] + ) + assert project.Subrecipient_UEI__c == "123412341234" + assert project.Subrecipient_TIN__c == "123123123" - def test_invalid_project_sheet(self, invalid_project_sheet: Worksheet): + def test_invalid_project_sheet(self, invalid_project_sheet: Worksheet) -> None: errors, _ = validate_project_sheet( invalid_project_sheet, - getSchemaByProject( - V2024_05_24_VERSION_STRING, SAMPLE_EXPENDITURE_CATEGORY_GROUP - ), - V2024_05_24_VERSION_STRING, + getSchemaByProject(V2024_05_24_VERSION, SAMPLE_EXPENDITURE_CATEGORY_GROUP), + V2024_05_24_VERSION.value, ) assert errors != [] error = errors[0] @@ -194,13 +199,11 @@ def test_invalid_project_sheet(self, invalid_project_sheet: Worksheet): def test_invalid_project_sheet_missing_field( self, invalid_project_sheet_missing_field: Worksheet - ): + ) -> None: errors, _ = validate_project_sheet( invalid_project_sheet_missing_field, - getSchemaByProject( - V2024_05_24_VERSION_STRING, SAMPLE_EXPENDITURE_CATEGORY_GROUP - ), - V2024_05_24_VERSION_STRING, + getSchemaByProject(V2024_05_24_VERSION, SAMPLE_EXPENDITURE_CATEGORY_GROUP), + V2024_05_24_VERSION.value, ) assert errors != [] error = errors[0] @@ -211,13 +214,11 @@ def test_invalid_project_sheet_missing_field( def test_invalid_project_sheet_empty_field( self, invalid_project_sheet_empty_field: Worksheet - ): + ) -> None: errors, _ = validate_project_sheet( invalid_project_sheet_empty_field, - getSchemaByProject( - V2024_05_24_VERSION_STRING, SAMPLE_EXPENDITURE_CATEGORY_GROUP - ), - V2024_05_24_VERSION_STRING, + getSchemaByProject(V2024_05_24_VERSION, SAMPLE_EXPENDITURE_CATEGORY_GROUP), + V2024_05_24_VERSION.value, ) assert errors != [] error = errors[0] @@ -228,18 +229,20 @@ def test_invalid_project_sheet_empty_field( class TestValidateSubrecipientSheet: - def test_valid_subrecipient_sheet(self, valid_subrecipientsheet: Worksheet): + def test_valid_subrecipient_sheet(self, valid_subrecipientsheet: Worksheet) -> None: errors, subrecipients = validate_subrecipient_sheet( - valid_subrecipientsheet, V2024_05_24_VERSION_STRING + valid_subrecipientsheet, V2024_05_24_VERSION.value ) assert errors == [] assert len(subrecipients) == 1 assert subrecipients[0].EIN__c == "123123123" assert subrecipients[0].Unique_Entity_Identifier__c == "123412341234" - def test_invalid_subrecipient_sheet(self, invalid_subrecipient_sheet: Worksheet): + def test_invalid_subrecipient_sheet( + self, invalid_subrecipient_sheet: Worksheet + ) -> None: errors, _ = validate_subrecipient_sheet( - invalid_subrecipient_sheet, V2024_05_24_VERSION_STRING + invalid_subrecipient_sheet, V2024_05_24_VERSION.value ) assert errors != [] error = errors[0] @@ -250,61 +253,63 @@ def test_invalid_subrecipient_sheet(self, invalid_subrecipient_sheet: Worksheet) def test_valid_subrecipient_sheet_blank_optional_fields( self, valid_subrecipient_sheet_blank_optional_fields: Worksheet - ): + ) -> None: errors, _ = validate_subrecipient_sheet( - valid_subrecipient_sheet_blank_optional_fields, V2024_05_24_VERSION_STRING + valid_subrecipient_sheet_blank_optional_fields, V2024_05_24_VERSION.value ) assert errors == [] class TestValidateMatchingSubrecipientSheet: def test_invalid_project_sheet_unmatching_subrecipient_tin_field( - self, invalid_project_sheet_unmatching_subrecipient_tin_field: Worksheet - ): + self, invalid_project_sheet_unmatching_subrecipient_tin_field: Workbook + ) -> None: errors, _, _, _ = validate_workbook( invalid_project_sheet_unmatching_subrecipient_tin_field ) assert errors != [] error = errors[0] assert "You must submit a subrecipient" in error.message - assert error.row == 13 + assert error.row == "13" assert error.col == "E, F" assert error.severity == ErrorLevel.ERR.name def test_invalid_project_sheet_unmatching_subrecipient_uei_field( - self, invalid_project_sheet_unmatching_subrecipient_uei_field: Worksheet - ): + self, invalid_project_sheet_unmatching_subrecipient_uei_field: Workbook + ) -> None: errors, _, _, _ = validate_workbook( invalid_project_sheet_unmatching_subrecipient_uei_field ) assert errors != [] error = errors[0] assert "You must submit a subrecipient" in error.message - assert error.row == 13 + assert error.row == "13" assert error.col == "E, F" assert error.severity == ErrorLevel.ERR.name def test_invalid_project_sheet_unmatching_subrecipient_tin_uei_field( - self, invalid_project_sheet_unmatching_subrecipient_tin_uei_field: Worksheet - ): + self, invalid_project_sheet_unmatching_subrecipient_tin_uei_field: Workbook + ) -> None: errors, _, _, _ = validate_workbook( invalid_project_sheet_unmatching_subrecipient_tin_uei_field ) assert errors != [] error = errors[0] assert "You must submit a subrecipient" in error.message - assert error.row == 13 + assert error.row == "13" assert error.col == "E, F" assert error.severity == ErrorLevel.ERR.name class TestGetProjectUseCode: - def test_get_project_use_code(self, valid_coversheet: Worksheet): + def test_get_project_use_code(self, valid_coversheet: Worksheet) -> None: expenditure_category_group = get_project_use_code( - valid_coversheet, V2024_05_24_VERSION_STRING + valid_coversheet, V2024_05_24_VERSION.value ) assert expenditure_category_group == SAMPLE_EXPENDITURE_CATEGORY_GROUP - def test_get_project_use_code_raises_error(self, invalid_cover_sheet: Worksheet): + def test_get_project_use_code_raises_error( + self, invalid_cover_sheet: Worksheet + ) -> None: with pytest.raises(ValueError): - get_project_use_code(invalid_cover_sheet, V2024_05_24_VERSION_STRING) + get_project_use_code(invalid_cover_sheet, V2024_05_24_VERSION.value) diff --git a/scripts/sample_lambda.py b/scripts/sample_lambda.py index 62b537e3..ea66f974 100644 --- a/scripts/sample_lambda.py +++ b/scripts/sample_lambda.py @@ -1,8 +1,13 @@ +from aws_lambda_typing import context as context_ +from aws_lambda_typing import events + """ To build this, zip it using zip scripts/function.zip scripts/sample_lambda.py """ -def lambda_handler(event, context): +def lambda_handler( + event: events.SQSEvent, _context: context_.Context +) -> dict[str, str]: print(event) - return {k: f"{v}_sample" for k, v in event["input"].items()} + return {k: f"{v}_sample" for k, v in event.items()}