diff --git a/src/regtech_data_validator/validator.py b/src/regtech_data_validator/validator.py index 2bd1851..917f62a 100644 --- a/src/regtech_data_validator/validator.py +++ b/src/regtech_data_validator/validator.py @@ -81,7 +81,6 @@ def validate( pd.DataFrame containing validation results data """ findings_df: pl.DataFrame = pl.DataFrame() - error_counts = warning_counts = Counts() try: # since polars dataframes don't normally have an index column, add it, so that we can match @@ -94,7 +93,6 @@ def validate( # `list[dict[str,Any]]`, but it's actually of type `SchemaError` schema_error: SchemaError - error_counts, warning_counts = get_scope_counts(err.schema_errors) if process_errors: for schema_error in err.schema_errors: check = schema_error.check @@ -136,14 +134,7 @@ def validate( findings_df = pl.concat(check_findings) updated_df = add_uid(findings_df, submission_df) - results = ValidationResults( - error_counts=error_counts, - warning_counts=warning_counts, - is_valid=((error_counts.total_count + warning_counts.total_count) == 0), - findings=updated_df, - phase=schema.name, - ) - return results + return updated_df # Add the uid for the record throwing the error/warning to the error dataframe @@ -189,13 +180,21 @@ def validate_batch_csv( if not has_syntax_errors: register_schema = get_register_schema(context) validation_results = validate(register_schema, pl.DataFrame({"uid": all_uids}), 0, True) - if not validation_results.findings.is_empty(): - validation_results.findings = format_findings( - validation_results.findings, + if not validation_results.is_empty(): + validation_results = format_findings( + validation_results, ValidationPhase.LOGICAL.value, [check for col_schema in register_schema.columns.values() for check in col_schema.checks], ) - yield validation_results + error_counts, warning_counts = get_scope_counts(validation_results) + results = ValidationResults( + error_counts=error_counts, + warning_counts=warning_counts, + is_valid=((error_counts.total_count + warning_counts.total_count) == 0), + findings=validation_results, + phase=register_schema.name, + ) + yield results for validation_results, _ in validate_chunks( logic_schema, real_path, batch_size, batch_count, max_errors, logic_checks @@ -220,21 +219,29 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks): while batches: df = pl.concat(batches) validation_results = validate(schema, df, row_start, process_errors) - if not validation_results.findings.is_empty(): - validation_results.findings = format_findings( - validation_results.findings, validation_results.phase.value, checks - ) + if not validation_results.is_empty(): - total_count += validation_results.findings.height + validation_results = format_findings(validation_results, schema.name.value, checks) + + error_counts, warning_counts = get_scope_counts(validation_results) + results = ValidationResults( + error_counts=error_counts, + warning_counts=warning_counts, + is_valid=((error_counts.total_count + warning_counts.total_count) == 0), + findings=validation_results, + phase=schema.name, + ) + + total_count += results.findings.height if total_count > max_errors and process_errors: process_errors = False - head_count = validation_results.findings.height - (total_count - max_errors) - validation_results.findings = validation_results.findings.head(head_count) + head_count = results.findings.height - (total_count - max_errors) + results.findings = results.findings.head(head_count) row_start += df.height batches = reader.next_batches(batch_count) - yield validation_results, df["uid"].to_list() + yield results, df["uid"].to_list() def get_real_file_path(path): @@ -256,68 +263,34 @@ def gather_errors(schema_error: SchemaError): return schema_error -def get_scope_counts(schema_errors: list[SchemaError]): - singles = [ - error for error in schema_errors if isinstance(error.check, SBLCheck) and error.check.scope == 'single-field' - ] - - single_errors = int( - sum( - [ - (error.check_output.filter(~pl.col("check_output"))).height - for error in singles - if error.check.severity == Severity.ERROR - ] +def get_scope_counts(error_frame: pl.DataFrame): + if not error_frame.is_empty(): + single_errors = error_frame.filter( + (pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "single-field") + ).height + single_warnings = error_frame.filter( + (pl.col("validation_type") == Severity.WARNING) & (pl.col("scope") == "single-field") + ).height + register_errors = error_frame.filter( + (pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "register") + ).height + multi_errors = error_frame.filter( + (pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "multi-field") + ).height + multi_warnings = error_frame.filter( + (pl.col("validation_type") == Severity.WARNING) & (pl.col("scope") == "multi-field") + ).height + + return Counts( + single_field_count=single_errors, + multi_field_count=multi_errors, + register_count=register_errors, + total_count=sum([single_errors, multi_errors, register_errors]), + ), Counts( + single_field_count=single_warnings, + multi_field_count=multi_warnings, + total_count=sum([single_warnings, multi_warnings]), # There are no register-level warnings at this time ) - ) - single_warnings = int( - sum( - [ - (error.check_output.filter(~pl.col("check_output"))).height - for error in singles - if error.check.severity == Severity.WARNING - ] - ) - ) - multi = [ - error for error in schema_errors if isinstance(error.check, SBLCheck) and error.check.scope == 'multi-field' - ] - multi_errors = int( - sum( - [ - (error.check_output.filter(~pl.col("check_output"))).height - for error in multi - if error.check.severity == Severity.ERROR - ] - ) - ) - multi_warnings = int( - sum( - [ - (error.check_output.filter(~pl.col("check_output"))).height - for error in multi - if error.check.severity == Severity.WARNING - ] - ) - ) - register_errors = int( - sum( - [ - (error.check_output.filter(~pl.col("check_output"))).height - for error in schema_errors - if isinstance(error.check, SBLCheck) and error.check.scope == 'register' - ] - ) - ) - - return Counts( - single_field_count=single_errors, - multi_field_count=multi_errors, - register_count=register_errors, - total_count=sum([single_errors, multi_errors, register_errors]), - ), Counts( - single_field_count=single_warnings, - multi_field_count=multi_warnings, - total_count=sum([single_warnings, multi_warnings]), # There are no register-level warnings at this time - ) + else: + return Counts(), Counts()