Skip to content

Commit

Permalink
updates to write_json_files to work better with prediction models
Browse files Browse the repository at this point in the history
  • Loading branch information
djm21 committed Jul 17, 2024
1 parent f279234 commit a7b49ac
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/sasctl/pzmm/write_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ def calculate_model_statistics(
if not partition:
continue

data = cls.stat_dataset_to_dataframe(data, target_value)
data = cls.stat_dataset_to_dataframe(data, target_value, target_type)

conn.upload(
data,
Expand Down Expand Up @@ -1392,6 +1392,7 @@ def check_for_data(
def stat_dataset_to_dataframe(
data: Union[DataFrame, List[list], Type["numpy.array"]],
target_value: Union[str, int, float] = None,
target_type: str = 'classification'
) -> DataFrame:
"""
Convert the user supplied statistical dataset from either a pandas DataFrame,
Expand Down Expand Up @@ -1439,13 +1440,15 @@ def stat_dataset_to_dataframe(
if isinstance(data, pd.DataFrame):
if len(data.columns) == 2:
data.columns = ["actual", "predict"]
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
if target_type == 'classification':
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data.columns) == 3:
data.columns = ["actual", "predict", "predict_proba"]
elif isinstance(data, list):
if len(data) == 2:
data = pd.DataFrame({"actual": data[0], "predict": data[1]})
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
if target_type == 'classification':
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data) == 3:
data = pd.DataFrame(
{
Expand All @@ -1457,7 +1460,8 @@ def stat_dataset_to_dataframe(
elif isinstance(data, np.ndarray):
if len(data) == 2:
data = pd.DataFrame({"actual": data[0, :], "predict": data[1, :]})
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
if target_type == 'classification':
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
elif len(data) == 3:
data = pd.DataFrame(
{"actual": data[0], "predict": data[1], "predict_proba": data[2]}
Expand Down Expand Up @@ -2366,7 +2370,8 @@ def generate_model_card(
)

# Generates dmcas_misc.json file
cls.generate_misc(model_files)
if target_type == 'classification':
cls.generate_misc(model_files)

@staticmethod
def upload_training_data(
Expand Down Expand Up @@ -2617,7 +2622,7 @@ def generate_variable_importance(
if target_type == "classification":
method = "DTREE"
treeCrit = "Entropy"
elif target_type == "interval":
elif target_type == "prediction":
method = "RTREE"
treeCrit = "RSS"
else:
Expand Down Expand Up @@ -2743,14 +2748,14 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
if isinstance(model_files, dict):
if ROC not in model_files:
raise RuntimeError(
"The ModelProperties.json file must be generated before the model card data "
"The dmcas_roc.json file must be generated before the model card data "
"can be generated."
)
roc_table = model_files[ROC]
else:
if not Path.exists(Path(model_files) / ROC):
raise RuntimeError(
"The ModelProperties.json file must be generated before the model card data "
"The dmcas_roc.json file must be generated before the model card data "
"can be generated."
)
with open(Path(model_files) / ROC, "r") as roc_file:
Expand Down

0 comments on commit a7b49ac

Please sign in to comment.