Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create an Evaluator/Comparer #157

Open
5 tasks
MarIniOnz opened this issue Aug 5, 2024 · 2 comments
Open
5 tasks

Create an Evaluator/Comparer #157

MarIniOnz opened this issue Aug 5, 2024 · 2 comments
Assignees
Labels
epic Tracks multiple related issues as a high-level task feat Related to new feature or enhancement p-high Priority: High python Related to the python medmodels

Comments

@MarIniOnz
Copy link
Contributor

MarIniOnz commented Aug 5, 2024

Create an Evaluator class that takes two MedRecords and evaluates the differences and similarities between them.

Tasks:

@MarIniOnz MarIniOnz added python Related to the python medmodels p-medium Priority: Medium feat Related to new feature or enhancement labels Aug 5, 2024
@MarIniOnz
Copy link
Contributor Author

Previous MTGAN evaluator:

"""Class for evaluating the results of the MTGAN model."""

import json
import logging
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import torch

from medmodels.data_synthesis.mtgan.model.generator.generator import Generator
from medmodels.data_synthesis.mtgan.model.loaders import MTGANDataset
from medmodels.data_synthesis.mtgan.modules.evaluator_functions import (
    calculate_distance,
    get_basic_statistics,
    get_top_k_disease,
)
from medmodels.data_synthesis.mtgan.modules.mtgan_preprocessor import MTGANPreprocessor
from medmodels.medrecord.medrecord import MedRecord


class MTGANEvaluator(torch.nn.Module):
    """Class for evaluating the results of the MTGAN model.

    Compares a real and a synthetic MedRecord by basic statistics, most frequent
    codes and distribution distances.

    Attributes:
        number_admissions_distribution: Set by ``_load_processed_data``; ``None``
            until then (``evaluate`` refuses to run without it).
        generator: Generator model set by ``_load_processed_data``; ``None`` until
            then.
        preprocessor: Preprocessor used to turn MedRecords into datasets.
        batch_size: Batch size passed to ``generator.get_required_number``;
            defaults to 32 and may be overridden by ``_load_hyperparameters``.
    """

    def __init__(self):
        super().__init__()
        self.number_admissions_distribution = None
        self.generator = None
        self.preprocessor = MTGANPreprocessor()
        self.batch_size = 32

    def _load_hyperparameters(self, hyperparameter_path: Path) -> None:
        """Loads the generation batch size from a hyperparameters JSON file.

        Args:
            hyperparameter_path (Path): Path to the hyperparameters JSON file.

        Raises:
            FileNotFoundError: If the hyperparameters file is not found.
            ValueError: If the hyperparameters file does not contain a 'generation'
                key, or that section has no 'GENERATE_BATCH_SIZE' entry.
        """
        if not hyperparameter_path.exists():
            msg = f"The hyperparameters file '{hyperparameter_path}' was not found."
            logging.error(msg)
            raise FileNotFoundError(msg)

        with open(hyperparameter_path, "r", encoding="utf-8") as f:
            hyperparameters = json.load(f)

        # Validate after the file is closed; only the batch size is read here.
        if "generation" not in hyperparameters:
            msg = "Hyperparameters file must contain a 'generation' key."
            logging.error(msg)
            raise ValueError(msg)
        if "GENERATE_BATCH_SIZE" not in hyperparameters["generation"]:
            msg = (
                "Hyperparameters 'generation' section must contain a "
                "'GENERATE_BATCH_SIZE' key."
            )
            logging.error(msg)
            raise ValueError(msg)
        self.batch_size = hyperparameters["generation"]["GENERATE_BATCH_SIZE"]

    def _load_processed_data(
        self,
        generator: Generator,
        preprocessor: MTGANPreprocessor,
        number_admissions_distribution: torch.Tensor,
    ) -> None:
        """Load the processed data from the generator and preprocessor.

        Args:
            generator (Generator): Generator model.
            preprocessor (MTGANPreprocessor): Preprocessor class.
            number_admissions_distribution (torch.Tensor): Number of admissions distribution.
        """
        self.generator = generator
        self.preprocessor = preprocessor
        self.number_admissions_distribution = number_admissions_distribution

    def _medrecord_to_arrays(self, medrecord: MedRecord) -> Tuple[np.ndarray, np.ndarray]:
        """Convert every patient of a MedRecord into numpy arrays via MTGANDataset.

        Args:
            medrecord (MedRecord): MedRecord to convert.

        Returns:
            Tuple[np.ndarray, np.ndarray]: the patient data cast to bool (presumably
            of shape (num_patients, max_number_admissions, num_codes) — confirm
            against MTGANDataset) and the number of admissions per patient cast to
            int16.
        """
        num_patients = len(medrecord.nodes_in_group(self.preprocessor.patients_group))
        dataset = MTGANDataset(self.preprocessor, medrecord=medrecord)
        data, number_admissions = dataset[np.arange(num_patients)]
        # Builtin `bool` instead of the `np.bool` alias, which does not exist on
        # numpy 1.24–1.26.
        return (
            data.cpu().numpy().astype(bool),
            number_admissions.cpu().numpy().astype(np.int16),
        )

    def evaluate(
        self, real_medrecord: MedRecord, synthetic_medrecord: MedRecord
    ) -> Dict[str, Any]:
        """Evaluate the results of the MTGAN model.

        Args:
            real_medrecord (MedRecord): Real MedRecord object.
            synthetic_medrecord (MedRecord): Synthetic MedRecord object.

        Returns:
            Dict[str, Any]: Dictionary containing evaluation statistics.

        Raises:
            ValueError: If the generator or the number of admissions distribution
                has not been loaded yet.
        """
        if self.generator is None or self.number_admissions_distribution is None:
            msg = "Training needs to be done before evaluation."
            logging.error(msg)
            raise ValueError(msg)

        # Computing real and synthetic data
        real_data, real_number_admissions = self._medrecord_to_arrays(real_medrecord)
        synthetic_data, synthetic_number_admissions = self._medrecord_to_arrays(
            synthetic_medrecord
        )

        # Get statistics from real and synthetic data
        real_stats = get_basic_statistics(real_data, real_number_admissions)
        synthetic_stats = get_basic_statistics(
            synthetic_data, synthetic_number_admissions
        )

        logging.info("Top 10 Real diseases")
        top_real_diseases = get_top_k_disease(
            real_data, self.preprocessor.index_to_concept_dict, top_k=10
        )
        logging.info("Top 10 Synthetic diseases")
        top_synthetic_diseases = get_top_k_disease(
            synthetic_data, self.preprocessor.index_to_concept_dict, top_k=10
        )

        # Calculate distances between real and synthetic data
        distances = calculate_distance(real_data, synthetic_data)

        # Get the required number of samples to generate all diseases at least once
        needed_samples = self.generator.get_required_number(
            self.number_admissions_distribution, self.batch_size
        )

        return {
            "real_stats": real_stats,
            "synthetic_stats": synthetic_stats,
            "top_real_diseases": top_real_diseases,
            "top_synthetic_diseases": top_synthetic_diseases,
            "distances": distances,
            "needed_samples": needed_samples,
        }

@MarIniOnz
Copy link
Contributor Author

Evaluator functions (optional to reuse; they are specific to the MTGAN data format):

"""This module contains functions for computing statistics of the data and
distances between real and synthetic data.

Functions:
    - get_basic_statistics: get basic statistics: number of true codes, number of all
        codes, number of all visits, average number of codes per visit, average number
        of visits per patient
    - code_count: count the number of occurrences of each code in the data
    - get_top_k_disease: get the diseases with their codes and number of occurrences in
        order of occurrences and save the top k diseases into a file if file is not None
    - normalized_distance: compute normalized distance between two distributions
    - pad_distribution: pad the shorter distribution with zeros
    - calculate_distance: compute the distances btw. real and synthetic data based on
        the following stats: Jensen-Shannon-Divergence and normalized distance for
        patients (p) and visits (v)
"""

import logging
from typing import Dict, List, Tuple, Union

import numpy as np
from numpy.typing import NDArray
from scipy.spatial.distance import jensenshannon as jsd

from medmodels.medrecord.types import NodeIndex


def get_basic_statistics(
    data: NDArray[np.bool_], number_admissions: NDArray[np.int64]
) -> Dict[str, Union[int, float]]:
    """Get basic statistics: number of true codes, number of all codes, number of all
    visits, average number of codes per visit, average number of visits per patient.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_number_admissions, num_codes)
        number_admissions (NDArray[np.int64]): number of visits per patient

    Returns:
        Dict[str, Union[int, float]]: Dictionary with the following stats:
            - number of code types (happening to at least one patient in the data),
            - number of all codes (how many diagnoses there are in total in the data),
            - number of visits,
            - avg codes per visit (0 if there are no visits),
            - avg visits per patient (0 if there are no patients)
    """
    num_patients = data.shape[0]
    # Occurrences of each code summed over patients and visits -> (num_codes,)
    code_occurrences = data.sum(axis=(0, 1))

    num_types = int((code_occurrences > 0).sum())
    num_codes = int(code_occurrences.sum())
    num_visits = int(number_admissions.sum())

    # Guard both means so empty inputs yield 0 instead of nan + RuntimeWarning.
    mean_code_num = num_codes / num_visits if num_visits > 0 else 0
    mean_visit_num = num_visits / num_patients if num_patients > 0 else 0

    return {
        "total_num_code_types": num_types,  # number of unique codes
        "total_num_codes": num_codes,
        "total_num_visits": num_visits,
        "mean_codes_per_visit": mean_code_num,
        "mean_visit_per_patient": mean_visit_num,
    }


def code_count(
    data: NDArray[np.bool_], index_to_concept_map: Dict[int, NodeIndex]
) -> List[Tuple[NodeIndex, int]]:
    """Count the number of occurrences of each code in the data.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_time_window, num_codes)
        index_to_concept_map (Dict[int, NodeIndex]): mapping of code indices
            (last axis of ``data``) to code values

    Returns:
        List[Tuple[NodeIndex, int]]: (code, count) pairs sorted by count in
        descending order; codes with equal counts keep their index order.
    """
    # One vectorised sum over patients and time windows -> (num_codes,)
    per_code = data.sum(axis=(0, 1))
    count = {index_to_concept_map[i]: int(n) for i, n in enumerate(per_code)}
    return sorted(count.items(), key=lambda item: item[1], reverse=True)


def get_top_k_disease(
    data: NDArray[np.bool_],
    index_to_concept_map: Dict[int, NodeIndex],
    top_k: int = 10,
) -> List[Tuple[NodeIndex, int]]:
    """Get the ``top_k`` most frequent codes with their number of occurrences and
    log them.

    Args:
        data (NDArray[np.bool_]): input data, 3D boolean matrix of shape
            (num_patients, max_time_window, num_codes)
        index_to_concept_map (Dict[int, NodeIndex]): mapping of code indices to code values
        top_k (int, optional): number of top diseases to return, defaults to 10

    Returns:
        List[Tuple[NodeIndex, int]]: the first ``top_k`` entries of the sorted code
        count (most frequent first). Exactly these entries are logged.
    """
    top_codes = code_count(data, index_to_concept_map)[:top_k]
    logging.info("--------------------------------------------------")
    for code_id, num in top_codes:
        logging.info("%s; %d", code_id, num)
    logging.info("--------------------------------------------------")

    return top_codes


def normalized_distance(
    distribution1: NDArray[np.float32],
    distribution2: NDArray[np.float32],
    epsilon: float = 1e-12,
) -> float:
    """Compute the mean normalized distance between two distributions.

    For every bin i, the squared difference of the two probabilities is divided
    by their average:
        dist_i = (p1_i - p2_i)^2 / ((p1_i + p2_i) / 2 + epsilon)
    and the mean over all bins is returned. ``epsilon`` only guards the
    denominator, so bins that are empty in both distributions contribute 0 and
    identical distributions have distance 0.

    Args:
        distribution1 (NDArray[np.float32]): first distribution
        distribution2 (NDArray[np.float32]): second distribution (same length)
        epsilon (float, optional): denominator smoothing, defaults to 1e-12

    Returns:
        float: normalized distance
    """
    squared_difference = (distribution1 - distribution2) ** 2
    average = (distribution1 + distribution2) / 2
    # Previously epsilon was also added to the numerator, making every bin that
    # is zero in both inputs contribute eps/eps == 1.
    distance = squared_difference / (average + epsilon)
    return float(distance.mean())


def pad_distribution(
    distribution1: NDArray[np.float32], distribution2: NDArray[np.float32]
) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
    """Right-pad the shorter of two 1D distributions with zeros.

    Both returned arrays have the length of the longer input. An input that is
    already long enough is returned as-is (the same object).

    Args:
        distribution1 (NDArray[np.float32]): first distribution
        distribution2 (NDArray[np.float32]): second distribution

    Returns:
        Tuple[NDArray[np.float32], NDArray[np.float32]]: padded distributions
    """
    target_length = max(len(distribution1), len(distribution2))

    def _pad_to_target(values):
        # Number of trailing zeros needed to reach the common length.
        missing = target_length - len(values)
        if missing <= 0:
            return values
        return np.pad(values, (0, missing), "constant")

    return _pad_to_target(distribution1), _pad_to_target(distribution2)


def calculate_distance(
    real_data: NDArray[np.bool_],
    synthetic_data: NDArray[np.bool_],
) -> Dict[str, float]:
    """Compute the distances between real and synthetic data for three
    distributions: time intervals per patient, code occurrences, and patients
    per code.

    For each distribution two measures are reported:
    - the Jensen-Shannon value from ``scipy.spatial.distance.jensenshannon``.
      Note that scipy returns the JS *distance*, i.e. the square root of the
      JS divergence. The ``JS_divergence_*`` keys keep their historical names.
    - the normalized distance: mean over bins of the squared difference of the
      two probabilities divided by their average.

    Args:
        real_data (NDArray[np.bool_]): real data, boolean matrix of shape
            (number_patients, max_number_admissions, number_codes)
        synthetic_data (NDArray[np.bool_]): synthetic data, boolean matrix of shape
            (number_patients, max_number_admissions, number_codes)

    Returns:
        Dict[str, float]: Dictionary with the following stats:
            - JS distance for the intervals' distributions,
            - normalized distance for the intervals' distributions,
            - JS distance for the codes' distributions,
            - normalized distance for the codes' distributions,
            - JS distance for the patients' distributions,
            - normalized distance for the patients' distributions.
    """

    def _js_and_normalized(real_dist, synthetic_dist) -> Tuple[float, float]:
        # Both measures for one pair of equally long distributions.
        return (
            float(jsd(real_dist, synthetic_dist)),
            float(normalized_distance(real_dist, synthetic_dist)),
        )

    real_intervals, real_codes, real_patients = real_data_distribution(real_data)
    (
        synthetic_intervals,
        synthetic_codes,
        synthetic_patients,
    ) = real_data_distribution(synthetic_data)

    # Interval histograms come from np.bincount, so their length is
    # (max intervals + 1) per dataset and can differ. Align them before comparing.
    # Code distributions are indexed by code and must already have equal length.
    real_intervals, synthetic_intervals = pad_distribution(
        real_intervals, synthetic_intervals
    )

    js_intervals, nd_intervals = _js_and_normalized(real_intervals, synthetic_intervals)
    js_codes, nd_codes = _js_and_normalized(real_codes, synthetic_codes)
    js_patients, nd_patients = _js_and_normalized(real_patients, synthetic_patients)

    return {
        "JS_divergence_intervals": js_intervals,
        "Normalized_distance_intervals": nd_intervals,
        "JS_divergence_codes_interval": js_codes,
        "Normalized_distance_codes_interval": nd_codes,
        "JS_divergence_patients_code": js_patients,
        "Normalized_distance_patients_code": nd_patients,
    }


def real_data_distribution(
    real_data: NDArray[np.bool_],
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """Analyse patient data (real or synthetic) and return three distributions.

    Args:
        real_data (NDArray[np.bool_]): 3D boolean array containing the
            patients data. Shape: (num_patients, number_admissions, num_codes).
            True if the patient has the code in the interval, False otherwise.

    Returns:
        Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
        - share of patients with exactly k time intervals containing codes
          (index k),
        - share of all code occurrences falling on each code,
        - per code, the number of patients having it at least once, normalized
          to sum to 1.
        The two code distributions are all zeros when the data contains no codes
        (instead of nan from a 0/0 division).
    """

    def _normalize(counts):
        # Turn counts into probabilities; leave an all-zero vector untouched.
        counts = counts.astype(float)
        total = counts.sum()
        return counts / total if total > 0 else counts

    # 1. Distribution of how many time intervals (with any code) each patient has
    intervals_per_patient = (real_data.sum(axis=2) > 0).sum(axis=1)
    intervals_distribution = np.bincount(intervals_per_patient) / len(
        intervals_per_patient
    )

    # 2. Number of times a code appears in the whole dataset
    code_interval_distribution = _normalize(real_data.sum(axis=(0, 1)))

    # 3. Number of patients that have that code at least once
    patient_code_distribution = _normalize((real_data.sum(axis=1) > 0).sum(axis=0))

    return intervals_distribution, code_interval_distribution, patient_code_distribution

@MarIniOnz MarIniOnz changed the title Create an Evaluator Create an Evaluator/Comparer Aug 15, 2024
@MarIniOnz MarIniOnz added the epic Tracks multiple related issues as a high-level task label Aug 15, 2024
@JabobKrauskopf JabobKrauskopf added p-high Priority: High and removed p-medium Priority: Medium labels Aug 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
epic Tracks multiple related issues as a high-level task feat Related to new feature or enhancement p-high Priority: High python Related to the python medmodels
Projects
None yet
Development

When branches are created from issues, their pull requests are automatically linked.

3 participants