-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create an Evaluator/Comparer #157
Comments
MarIniOnz
added
python
Related to the python medmodels
p-medium
Priority: Medium
feat
Related to new feature or enhancement
labels
Aug 5, 2024
Previous MTGAN evaluator: """Class for evaluating the results of the MTGAN model."""
import json
import logging
from pathlib import Path
from typing import Any, Dict
import numpy as np
import torch
from medmodels.data_synthesis.mtgan.model.generator.generator import Generator
from medmodels.data_synthesis.mtgan.model.loaders import MTGANDataset
from medmodels.data_synthesis.mtgan.modules.evaluator_functions import (
calculate_distance,
get_basic_statistics,
get_top_k_disease,
)
from medmodels.data_synthesis.mtgan.modules.mtgan_preprocessor import MTGANPreprocessor
from medmodels.medrecord.medrecord import MedRecord
class MTGANEvaluator(torch.nn.Module):
    """Class for evaluating the results of the MTGAN model.

    Converts a real and a synthetic MedRecord into dense numpy arrays via
    `MTGANDataset` and compares them: basic statistics, the most frequent
    codes and distances between several distributions. It also reports how
    many samples the loaded generator needs (via
    `Generator.get_required_number`).
    """

    def __init__(self):
        super().__init__()
        # Set by `_load_processed_data`; `evaluate` raises while either of
        # these two is still None.
        self.number_admissions_distribution = None
        self.generator = None
        self.preprocessor = MTGANPreprocessor()
        # Default generation batch size; `_load_hyperparameters` overrides it.
        self.batch_size = 32

    def _load_hyperparameters(self, hyperparameter_path: Path) -> None:
        """Loads the hyperparameters from a JSON file.

        Only `hyperparameters["generation"]["GENERATE_BATCH_SIZE"]` is read; it
        becomes `self.batch_size`.

        Args:
            hyperparameter_path (Path): Path to the hyperparameters JSON file.
                A plain string path is accepted as well.

        Raises:
            FileNotFoundError: If the hyperparameters file is not found.
            ValueError: If the hyperparameters file does not contain a
                'generation' key.
            KeyError: If the 'generation' section has no 'GENERATE_BATCH_SIZE'
                entry.
        """
        # Normalise so callers may pass either a str or a Path.
        hyperparameter_path = Path(hyperparameter_path)
        if not hyperparameter_path.exists():
            msg = f"The hyperparameters file '{hyperparameter_path}' was not found."
            logging.error(msg)
            raise FileNotFoundError(msg)

        with open(hyperparameter_path, "r", encoding="utf-8") as f:
            hyperparameters = json.load(f)

        if "generation" not in hyperparameters:
            msg = "Hyperparameters file must contain a 'generation' key."
            # Log like the other error path so configuration failures show up
            # in the logs and not only in the traceback.
            logging.error(msg)
            raise ValueError(msg)

        self.batch_size = hyperparameters["generation"]["GENERATE_BATCH_SIZE"]

    def _load_processed_data(
        self,
        generator: Generator,
        preprocessor: MTGANPreprocessor,
        number_admissions_distribution: torch.Tensor,
    ) -> None:
        """Load the processed data from the generator and preprocessor.

        Must be called before `evaluate`, which needs both the generator and the
        admissions distribution.

        Args:
            generator (Generator): Generator model.
            preprocessor (MTGANPreprocessor): Preprocessor class.
            number_admissions_distribution (torch.Tensor): Number of admissions
                distribution.
        """
        self.generator = generator
        self.preprocessor = preprocessor
        self.number_admissions_distribution = number_admissions_distribution

    def _medrecord_to_arrays(
        self, medrecord: MedRecord
    ) -> "tuple[np.ndarray, np.ndarray]":
        """Convert all patients of a MedRecord into MTGAN's numpy representation.

        Args:
            medrecord (MedRecord): MedRecord to convert.

        Returns:
            tuple[np.ndarray, np.ndarray]: the patient data returned by
                `MTGANDataset`, cast to boolean, and the number of admissions
                per patient, cast to int16.
        """
        num_patients = len(
            medrecord.nodes_in_group(self.preprocessor.patients_group)
        )
        dataset = MTGANDataset(self.preprocessor, medrecord=medrecord)
        data, number_admissions = dataset[np.arange(num_patients)]
        # `np.bool_` rather than the `np.bool` alias, which does not exist in
        # NumPy 1.24-1.26 (removed in 1.24, re-added in 2.0).
        return (
            data.cpu().numpy().astype(np.bool_),
            number_admissions.cpu().numpy().astype(np.int16),
        )

    def evaluate(
        self, real_medrecord: MedRecord, synthetic_medrecord: MedRecord
    ) -> Dict[str, Any]:
        """Evaluate the results of the MTGAN model.

        Args:
            real_medrecord (MedRecord): Real MedRecord object.
            synthetic_medrecord (MedRecord): Synthetic MedRecord object.

        Returns:
            Dict[str, Any]: Dictionary containing evaluation statistics under the
                keys "real_stats", "synthetic_stats", "top_real_diseases",
                "top_synthetic_diseases", "distances" and "needed_samples".

        Raises:
            ValueError: If no generator or admissions distribution has been
                loaded yet (see `_load_processed_data`).
        """
        if self.generator is None or self.number_admissions_distribution is None:
            msg = "Training needs to be done before evaluation."
            logging.error(msg)
            raise ValueError(msg)

        # Computing real and synthetic data
        real_data, real_number_admissions = self._medrecord_to_arrays(
            real_medrecord
        )
        synthetic_data, synthetic_number_admissions = self._medrecord_to_arrays(
            synthetic_medrecord
        )

        # Get statistics from real and synthetic data
        real_stats = get_basic_statistics(real_data, real_number_admissions)
        synthetic_stats = get_basic_statistics(
            synthetic_data, synthetic_number_admissions
        )

        logging.info("Top 10 Real diseases")
        top_real_diseases = get_top_k_disease(
            real_data, self.preprocessor.index_to_concept_dict, top_k=10
        )
        logging.info("Top 10 Synthetic diseases")
        top_synthetic_diseases = get_top_k_disease(
            synthetic_data, self.preprocessor.index_to_concept_dict, top_k=10
        )

        # Calculate distances between real and synthetic data
        distances = calculate_distance(real_data, synthetic_data)

        # Get the required number of samples to generate all diseases at least once
        needed_samples = self.generator.get_required_number(
            self.number_admissions_distribution, self.batch_size
        )

        return {
            "real_stats": real_stats,
            "synthetic_stats": synthetic_stats,
            "top_real_diseases": top_real_diseases,
            "top_synthetic_diseases": top_synthetic_diseases,
            "distances": distances,
            "needed_samples": needed_samples,
        }
Evaluator functions (could be used or not, specific to MTGAN formatting) """This module contains functions for computing statistics of the data and
distances between real and synthetic data.
Functions:
- get_basic_statistics: get basic statistics: number of true codes, number of all
codes, number of all visits, average number of codes per visit, average number
of visits per patient
- code_count: count the number of occurrences of each code in the data
- get_top_k_disease: get the top k diseases (codes with their number of
occurrences) in order of occurrence and log them
- normalized_distance: compute normalized distance between two distributions
- pad_distribution: pad the shorter distribution with zeros
- calculate_distance: compute the distances between real and synthetic data based
on the following stats: Jensen-Shannon-Divergence and normalized distance for
intervals, codes and patients
- real_data_distribution: compute the interval, code and patient distributions
of a dataset
"""
import logging
from typing import Dict, List, Tuple, Union
import numpy as np
from numpy.typing import NDArray
from scipy.spatial.distance import jensenshannon as jsd
from medmodels.medrecord.types import NodeIndex
def get_basic_statistics(
    data: NDArray[np.bool_], number_admissions: NDArray[np.int64]
) -> Dict[str, Union[int, float]]:
    """Get basic statistics: number of code types, number of all codes, number of
    all visits, average number of codes per visit, average number of visits per
    patient.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_number_admissions, num_codes)
        number_admissions (NDArray[np.int64]): number of visits per patient

    Returns:
        Dict[str, Union[int, float]]: Dictionary (plain Python scalars) with:
            - "total_num_code_types": number of code types occurring at least
              once in the data,
            - "total_num_codes": total number of code occurrences,
            - "total_num_visits": total number of visits,
            - "mean_codes_per_visit": average codes per visit (0 if no visits),
            - "mean_visit_per_patient": average visits per patient (0 if no
              patients).
    """
    # Occurrences of every code, summed over admissions and patients.
    code_counts = data.sum(axis=1).sum(axis=0)
    num_types = int((code_counts > 0).sum())
    num_codes = int(code_counts.sum())
    num_visits = int(number_admissions.sum())
    num_patients = data.shape[0]

    # Guard both ratios so empty inputs yield 0 instead of NaN/inf warnings.
    mean_code_num = num_codes / num_visits if num_visits > 0 else 0
    mean_visit_num = num_visits / num_patients if num_patients > 0 else 0

    return {
        "total_num_code_types": num_types,  # number of unique codes
        "total_num_codes": num_codes,
        "total_num_visits": num_visits,
        "mean_codes_per_visit": mean_code_num,
        "mean_visit_per_patient": mean_visit_num,
    }
def code_count(
    data: NDArray[np.bool_], index_to_concept_map: Dict[int, NodeIndex]
) -> List[Tuple[NodeIndex, int]]:
    """Count the number of occurrences of each code in the data.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_time_window, num_codes)
        index_to_concept_map (Dict[int, NodeIndex]): mapping of code indices to
            code values; must contain every index in ``range(num_codes)``

    Returns:
        List[Tuple[NodeIndex, int]]: (code, count) pairs sorted by count in
            descending order; codes with equal counts keep code-index order.
    """
    # One vectorised reduction instead of a separate sum per code column.
    per_code_totals = data.sum(axis=(0, 1))
    count = {
        index_to_concept_map[index]: int(total)
        for index, total in enumerate(per_code_totals)
    }
    return sorted(count.items(), key=lambda item: item[1], reverse=True)
def get_top_k_disease(
    data: NDArray[np.bool_],
    index_to_concept_map: Dict[int, NodeIndex],
    top_k: int = 10,
) -> List[Tuple[NodeIndex, int]]:
    """Get the top k diseases (codes with their number of occurrences), ordered
    by number of occurrences, and log them.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_time_window, num_codes)
        index_to_concept_map (Dict[int, NodeIndex]): mapping of code indices to
            code values
        top_k (int, optional): number of top diseases to return, defaults to 10

    Returns:
        List[Tuple[NodeIndex, int]]: the first ``top_k`` entries of the sorted
            code counts; exactly these entries are logged.
    """
    # Slice first and log exactly what is returned; a manual counter would log
    # every code when top_k <= 0.
    top_diseases = code_count(data, index_to_concept_map)[:top_k]
    logging.info("--------------------------------------------------")
    for code_id, num in top_diseases:
        logging.info("%s; %d", code_id, num)
    logging.info("--------------------------------------------------")
    return top_diseases
def normalized_distance(
    distribution1: NDArray[np.float32],
    distribution2: NDArray[np.float32],
    epsilon: float = 1e-12,
) -> float:
    """Compute normalized distance between two distributions.

    For every bin, the squared difference of the two probabilities is divided
    by their average. The result is the mean over all bins:

        dist = mean((p1 - p2)^2 / ((p1 + p2) / 2 + epsilon))

    where p1 and p2 are the two probability distributions.

    Args:
        distribution1 (NDArray[np.float32]): first distribution
        distribution2 (NDArray[np.float32]): second distribution
        epsilon (float, optional): added to the denominator only, to avoid
            division by zero; defaults to 1e-12

    Returns:
        float: normalized distance (0.0 for identical distributions)
    """
    # epsilon belongs only in the denominator. Also adding it to the numerator
    # turns every bin that is empty in both distributions into eps/eps == 1,
    # which gives identical distributions a nonzero distance.
    distance = (distribution1 - distribution2) ** 2 / (
        (distribution1 + distribution2) / 2 + epsilon
    )
    return float(distance.mean())
def pad_distribution(
    distribution1: NDArray[np.float32], distribution2: NDArray[np.float32]
) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
    """Right-pad the shorter of two distributions with zeros.

    After the call, both arrays have the length of the longer input. An input
    that is already long enough is returned as the same object, without a copy.

    Args:
        distribution1 (NDArray[np.float32]): first distribution
        distribution2 (NDArray[np.float32]): second distribution

    Returns:
        Tuple[NDArray[np.float32], NDArray[np.float32]]: the two distributions,
            now of equal length
    """
    target_length = max(len(distribution1), len(distribution2))

    def _extend(values: NDArray[np.float32]) -> NDArray[np.float32]:
        shortfall = target_length - len(values)
        if not shortfall:
            return values
        return np.pad(values, (0, shortfall), "constant")

    return _extend(distribution1), _extend(distribution2)
def _compare_distributions(
    real: NDArray[np.float64], synthetic: NDArray[np.float64]
) -> Tuple[float, float]:
    """Return (Jensen-Shannon value, normalized distance) for two equal-length distributions."""
    js_value = float(jsd(real, synthetic))
    return js_value, float(normalized_distance(real, synthetic))


def calculate_distance(
    real_data: NDArray[np.bool_],
    synthetic_data: NDArray[np.bool_],
) -> Dict[str, float]:
    """Compute the distances between real and synthetic data based on the
    following stats: Jensen-Shannon value and normalized distance for intervals,
    codes and patients.

    Jensen-Shannon divergence is a symmetrized and smoothed version of the
    Kullback-Leibler divergence and measures the similarity between two
    probability distributions. NOTE(review): `scipy.spatial.distance.jensenshannon`
    returns the Jensen-Shannon *distance*, i.e. the square root of the
    divergence. The "JS_divergence_*" entries hold that distance; the key names
    are kept for backward compatibility.

    The normalized distance is the squared difference between two probability
    distributions divided by their average, averaged over bins.

    Args:
        real_data (NDArray[np.bool_]): real data, boolean matrix of shape
            (number_patients, max_number_admissions, number_codes)
        synthetic_data (NDArray[np.bool_]): synthetic data, boolean matrix of
            shape (number_patients, max_number_admissions, number_codes)

    Returns:
        Dict[str, float]: Dictionary with the following stats:
            - JS value for the intervals' distributions,
            - normalized distance for the intervals' distributions,
            - JS value for the codes' distributions,
            - normalized distance for the codes' distributions,
            - JS value for the patients' distributions,
            - normalized distance for the patients' distributions.
    """
    (
        real_intervals_distribution,
        real_codes_interval_distribution,
        real_patients_code_distribution,
    ) = real_data_distribution(real_data)
    (
        synthetic_intervals_distribution,
        synthetic_codes_interval_distribution,
        synthetic_patients_code_distribution,
    ) = real_data_distribution(synthetic_data)

    # The interval histograms come from np.bincount, so their lengths depend on
    # the largest interval count present in each dataset; align them first.
    real_intervals_distribution, synthetic_intervals_distribution = pad_distribution(
        real_intervals_distribution, synthetic_intervals_distribution
    )
    js_intervals, normalized_distance_intervals = _compare_distributions(
        real_intervals_distribution, synthetic_intervals_distribution
    )

    # Calculate distances for codes' distributions
    js_codes_interval, normalized_distance_codes_interval = _compare_distributions(
        real_codes_interval_distribution, synthetic_codes_interval_distribution
    )

    # Calculate distances for patients' distributions
    js_patients_code, normalized_distance_patients_code = _compare_distributions(
        real_patients_code_distribution, synthetic_patients_code_distribution
    )

    return {
        "JS_divergence_intervals": js_intervals,
        "Normalized_distance_intervals": normalized_distance_intervals,
        "JS_divergence_codes_interval": js_codes_interval,
        "Normalized_distance_codes_interval": normalized_distance_codes_interval,
        "JS_divergence_patients_code": js_patients_code,
        "Normalized_distance_patients_code": normalized_distance_patients_code,
    }
def real_data_distribution(
    real_data: NDArray[np.bool_],
) -> Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
    """Analyse the data and return different distributions.

    Used for both real and synthetic data.

    Args:
        real_data (NDArray[np.bool_]): 3D boolean array containing the patients
            data. Shape: (num_patients, number_admissions, num_codes).
            1 if the patient has the code in the interval, 0 otherwise.

    Returns:
        Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
            - distribution of the number of time intervals with codes per
              patient (index i = share of patients with i such intervals),
            - distribution of the number of times each code appears in all
              intervals,
            - distribution of the number of patients to whom each code is
              assigned at least once.
            If the data contains no codes at all, the last two are all zeros
            rather than NaN.
    """
    # 1. Distribution of how many time intervals (with at least one code) each
    #    patient has.
    intervals_per_patient = (real_data.sum(axis=2) > 0).sum(axis=1)
    intervals_distribution = np.bincount(intervals_per_patient) / max(
        len(intervals_per_patient), 1
    )

    # 2. Number of times a code appears in the whole dataset
    code_interval_distribution = real_data.sum(axis=(0, 1)).astype(float)
    code_total = code_interval_distribution.sum()
    if code_total > 0:  # avoid 0/0 -> NaN (and a RuntimeWarning) for empty data
        code_interval_distribution /= code_total

    # 3. Number of patients that have that code at least once
    patient_code_distribution = (real_data.sum(axis=1) > 0).sum(axis=0).astype(float)
    patient_total = patient_code_distribution.sum()
    if patient_total > 0:
        patient_code_distribution /= patient_total

    return intervals_distribution, code_interval_distribution, patient_code_distribution
JabobKrauskopf
added
p-high
Priority: High
and removed
p-medium
Priority: Medium
labels
Aug 20, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Create an Evaluator class that takes two MedRecords and evaluates the differences and similarities between them.
Tasks:
The text was updated successfully, but these errors were encountered: