Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed counting bug #273

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 58 additions & 85 deletions src/regtech_data_validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def validate(
pd.DataFrame containing validation results data
"""
findings_df: pl.DataFrame = pl.DataFrame()
error_counts = warning_counts = Counts()

try:
# since polars dataframes don't normally have an index column, add it, so that we can match
Expand All @@ -94,7 +93,6 @@ def validate(
# `list[dict[str,Any]]`, but it's actually of type `SchemaError`
schema_error: SchemaError

error_counts, warning_counts = get_scope_counts(err.schema_errors)
if process_errors:
for schema_error in err.schema_errors:
check = schema_error.check
Expand Down Expand Up @@ -136,14 +134,7 @@ def validate(
findings_df = pl.concat(check_findings)

updated_df = add_uid(findings_df, submission_df)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=updated_df,
phase=schema.name,
)
return results
return updated_df


# Add the uid for the record throwing the error/warning to the error dataframe
Expand Down Expand Up @@ -189,13 +180,21 @@ def validate_batch_csv(
if not has_syntax_errors:
register_schema = get_register_schema(context)
validation_results = validate(register_schema, pl.DataFrame({"uid": all_uids}), 0, True)
if not validation_results.findings.is_empty():
validation_results.findings = format_findings(
validation_results.findings,
if not validation_results.is_empty():
validation_results = format_findings(
validation_results,
ValidationPhase.LOGICAL.value,
[check for col_schema in register_schema.columns.values() for check in col_schema.checks],
)
yield validation_results
error_counts, warning_counts = get_scope_counts(validation_results)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=validation_results,
phase=register_schema.name,
)
yield results

for validation_results, _ in validate_chunks(
logic_schema, real_path, batch_size, batch_count, max_errors, logic_checks
Expand All @@ -220,21 +219,29 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
while batches:
df = pl.concat(batches)
validation_results = validate(schema, df, row_start, process_errors)
if not validation_results.findings.is_empty():
validation_results.findings = format_findings(
validation_results.findings, validation_results.phase.value, checks
)
if not validation_results.is_empty():

total_count += validation_results.findings.height
validation_results = format_findings(validation_results, schema.name.value, checks)

error_counts, warning_counts = get_scope_counts(validation_results)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=validation_results,
phase=schema.name,
)

total_count += results.findings.height

if total_count > max_errors and process_errors:
process_errors = False
head_count = validation_results.findings.height - (total_count - max_errors)
validation_results.findings = validation_results.findings.head(head_count)
head_count = results.findings.height - (total_count - max_errors)
results.findings = results.findings.head(head_count)

row_start += df.height
batches = reader.next_batches(batch_count)
yield validation_results, df["uid"].to_list()
yield results, df["uid"].to_list()


def get_real_file_path(path):
Expand All @@ -256,68 +263,34 @@ def gather_errors(schema_error: SchemaError):
return schema_error


def get_scope_counts(error_frame: pl.DataFrame):
    """Tally validation findings by severity and scope.

    Args:
        error_frame: Findings frame. When non-empty it is filtered on its
            ``validation_type`` and ``scope`` columns, one row per finding.

    Returns:
        A ``(error_counts, warning_counts)`` tuple of ``Counts``. Both are
        default-constructed ``Counts()`` when ``error_frame`` is empty.
    """
    # Guard first: the column filters below are only attempted when there are
    # findings. NOTE(review): presumably an empty frame can lack these columns
    # (validate() starts from a column-less pl.DataFrame()) -- confirm.
    if error_frame.is_empty():
        return Counts(), Counts()

    def count(severity, scope: str) -> int:
        # Number of rows matching one severity/scope combination.
        return error_frame.filter(
            (pl.col("validation_type") == severity) & (pl.col("scope") == scope)
        ).height

    single_errors = count(Severity.ERROR, "single-field")
    single_warnings = count(Severity.WARNING, "single-field")
    register_errors = count(Severity.ERROR, "register")
    multi_errors = count(Severity.ERROR, "multi-field")
    multi_warnings = count(Severity.WARNING, "multi-field")

    return Counts(
        single_field_count=single_errors,
        multi_field_count=multi_errors,
        register_count=register_errors,
        total_count=sum([single_errors, multi_errors, register_errors]),
    ), Counts(
        single_field_count=single_warnings,
        multi_field_count=multi_warnings,
        total_count=sum([single_warnings, multi_warnings]),  # There are no register-level warnings at this time
    )
Loading