Skip to content

Commit

Permalink
Fix eval vulnerability (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
yadavsahil197 authored Dec 6, 2024
1 parent da27704 commit f43a21d
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/autolabel/dataset/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data and Schema Validation"""

import ast
import json
import re
from functools import cached_property
Expand All @@ -16,7 +17,6 @@


class NERTaskValidate(BaseModel):

"""
Validate NER Task
Expand Down Expand Up @@ -48,7 +48,6 @@ def validate(self, value: str):


class ClassificationTaskValidate(BaseModel):

"""
Validate Classification Task
Expand All @@ -69,7 +68,7 @@ def validate(self, value: str):
# TODO: This can be made better
if value.startswith("[") and value.endswith("]"):
try:
seed_labels = eval(value)
seed_labels = ast.literal_eval(value)
if not isinstance(seed_labels, list):
raise
unmatched_label = set(seed_labels) - self.labels_set
Expand All @@ -86,7 +85,6 @@ def validate(self, value: str):


class EMTaskValidate(BaseModel):

"""
Validate Entity Matching Task
Expand All @@ -106,7 +104,6 @@ def validate(self, value: str):


class QATaskValidate(BaseModel):

"""
Validate Question Answering Task
Expand All @@ -123,7 +120,6 @@ def validate(self, value: str):


class MLCTaskValidate(BaseModel):

"""
Validate Multilabel Classification Task
Expand All @@ -140,7 +136,7 @@ class MLCTaskValidate(BaseModel):
def validate(self, value: str):
if value.startswith("[") and value.endswith("]"):
try:
seed_labels = eval(value)
seed_labels = ast.literal_eval(value)
if not isinstance(seed_labels, list):
raise ValueError(
f"value: '{value}' is not a list of labels as expected",
Expand Down Expand Up @@ -175,14 +171,14 @@ class DataValidationTasks(BaseModel):


class TaskDataValidation:

"""Task Validation"""

def __init__(self, config: AutolabelConfig):
"""
Task Validation
Args:
----
config: AutolabelConfig = User passed parsed configuration
"""
Expand All @@ -204,7 +200,8 @@ def __init__(self, config: AutolabelConfig):
self.__schema = {col: (StrictStr, ...) for col in self.expected_columns}

self.__validation_task = DataValidationTasks.__dict__[task_type](
label_column=label_column, labels_set=set(labels_list),
label_column=label_column,
labels_set=set(labels_list),
)
self.__data_validation = self.data_validation_and_schema_check(
self.__validation_task,
Expand Down Expand Up @@ -240,20 +237,22 @@ def data_validation_and_schema_check(self, validation_task: BaseModel):
Validate data format and datatype
Args:
----
validation_task (TaskTypeValidate): validation task
Raises:
------
e: Validation error if the inputs are not string
e: Validation error if validation_task fails
Returns:
-------
DataValidation: Pydantic Model for validation
"""
Model = create_model("Model", **self.__schema)

class DataValidation(BaseModel):

"""Data Validation"""

# We define validate as a classmethod such that a dynamic `data` can be passed
Expand Down

0 comments on commit f43a21d

Please sign in to comment.