diff --git a/medmodels/treatment_effect/builder.py b/medmodels/treatment_effect/builder.py index 15ceb4c7..9d5f25e9 100644 --- a/medmodels/treatment_effect/builder.py +++ b/medmodels/treatment_effect/builder.py @@ -38,7 +38,7 @@ class TreatmentEffectBuilder: matching_one_hot_covariates: Optional[MedRecordAttributeInputList] matching_model: Optional[Model] matching_number_of_neighbors: Optional[int] - matching_hyperparam: Optional[Dict[str, Any]] + matching_hyperparameters: Optional[Dict[str, Any]] def with_treatment(self, treatment: Group) -> TreatmentEffectBuilder: """Sets the treatment group for the treatment effect estimation. @@ -218,27 +218,25 @@ def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder: def with_propensity_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, model: Model = "logit", number_of_neighbors: int = 1, - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ) -> TreatmentEffectBuilder: """Adjust the treatment effect estimate using propensity score matching. Args: - essential_covariates (MedRecordAttributeInputList, optional): - Covariates that are essential for matching. Defaults to - ["gender", "age"]. - one_hot_covariates (MedRecordAttributeInputList, optional): - Covariates that are one-hot encoded for matching. Defaults to - ["gender"]. + essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. model (Model, optional): Model to choose for the matching. Defaults to "logit". 
number_of_neighbors (int, optional): Number of neighbors to consider for the matching. Defaults to 1. - hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for the - matching model. Defaults to None. + hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for + the matching model. Defaults to None. Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder @@ -249,27 +247,25 @@ def with_propensity_matching( self.matching_one_hot_covariates = one_hot_covariates self.matching_model = model self.matching_number_of_neighbors = number_of_neighbors - self.matching_hyperparam = hyperparam + self.matching_hyperparameters = hyperparameters return self def with_nearest_neighbors_matching( self, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, number_of_neighbors: int = 1, ) -> TreatmentEffectBuilder: """Adjust the treatment effect estimate using nearest neighbors matching. Args: - essential_covariates (MedRecordAttributeInputList, optional): - Covariates that are essential for matching. Defaults to - ["gender", "age"]. - one_hot_covariates (MedRecordAttributeInputList, optional): - Covariates that are one-hot encoded for matching. Defaults to - ["gender"]. - number_of_neighbors (int, optional): Number of neighbors to consider for the - matching. Defaults to 1. + essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. + number_of_neighbors (int, optional): Number of neighbors to consider for + the matching. Defaults to 1. 
Returns: TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder diff --git a/medmodels/treatment_effect/estimate.py b/medmodels/treatment_effect/estimate.py index 708b6d0b..d6e0c52c 100644 --- a/medmodels/treatment_effect/estimate.py +++ b/medmodels/treatment_effect/estimate.py @@ -181,7 +181,7 @@ def _sort_subjects_in_groups( else PropensityMatching( number_of_neighbors=self._treatment_effect._matching_number_of_neighbors, model=self._treatment_effect._matching_model, - hyperparam=self._treatment_effect._matching_hyperparam, + hyperparameters=self._treatment_effect._matching_hyperparameters, ) ) @@ -191,6 +191,7 @@ def _sort_subjects_in_groups( medrecord=medrecord, treated_set=treated_set, control_set=control_set, + patients_group=self._treatment_effect._patients_group, essential_covariates=self._treatment_effect._matching_essential_covariates, one_hot_covariates=self._treatment_effect._matching_one_hot_covariates, ) diff --git a/medmodels/treatment_effect/matching/algorithms/propensity_score.py b/medmodels/treatment_effect/matching/algorithms/propensity_score.py index 92ca7c4b..a5ce6bfe 100644 --- a/medmodels/treatment_effect/matching/algorithms/propensity_score.py +++ b/medmodels/treatment_effect/matching/algorithms/propensity_score.py @@ -32,7 +32,7 @@ def calculate_propensity( treated_test: NDArray[Union[np.int64, np.float64]], control_test: NDArray[Union[np.int64, np.float64]], model: Model = "logit", - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ) -> Tuple[NDArray[np.float64], NDArray[np.float64]]: """Trains a classification algorithm on training data, predicts the probability of being in the last class for treated and control test datasets, and returns these probabilities. @@ -49,8 +49,8 @@ def calculate_propensity( control group to predict probabilities. model (Model, optional): Classification algorithm to use. Options: "logit", "dec_tree", "forest". 
- hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings. - Uses default if None. + hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter + settings. Uses default hyperparameters if None. Returns: Tuple[NDArray[np.float64], NDArray[np.float64]: Probabilities of the positive @@ -61,7 +61,7 @@ class for treated and control groups. last class for treated and control sets, e.g., ([0.], [0.]). """ propensity_model = PROP_MODEL[model] - pm = propensity_model(**hyperparam) if hyperparam else propensity_model() + pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model() pm.fit(x_train, y_train) # Predict the probability of the treated and control groups @@ -76,7 +76,7 @@ def run_propensity_score( control_set: pl.DataFrame, model: Model = "logit", number_of_neighbors: int = 1, - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, covariates: Optional[MedRecordAttributeInputList] = None, ) -> pl.DataFrame: """Executes Propensity Score matching using a specified classification algorithm. @@ -95,7 +95,7 @@ def run_propensity_score( Options include "logit", "dec_tree", "forest". number_of_neighbors (int, optional): Number of nearest neighbors to find for each treated unit. Defaults to 1. - hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model + hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model tuning. Increases computation time if set. Uses default if None. covariates (Optional[MedRecordAttributeInputList], optional): Features for matching. Uses all if None. 
@@ -119,7 +119,7 @@ def run_propensity_score( y_train, treated_array, control_array, - hyperparam=hyperparam, + hyperparameters=hyperparameters, model=model, ) diff --git a/medmodels/treatment_effect/matching/matching.py b/medmodels/treatment_effect/matching/matching.py index 3a911786..c9726575 100644 --- a/medmodels/treatment_effect/matching/matching.py +++ b/medmodels/treatment_effect/matching/matching.py @@ -1,12 +1,14 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Literal, Set, Tuple +from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple import polars as pl +from medmodels.medrecord._overview import extract_attribute_summary from medmodels.medrecord.medrecord import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex +from medmodels.medrecord.querying import NodeOperand +from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex if TYPE_CHECKING: import sys @@ -22,14 +24,25 @@ class Matching(metaclass=ABCMeta): """The Base Class for matching.""" + number_of_neighbors: int + + def __init__(self, number_of_neighbors: int) -> None: + """Initializes the matching class. + + Args: + number_of_neighbors (int): Number of nearest neighbors to find for each treated unit. + """ + self.number_of_neighbors = number_of_neighbors + def _preprocess_data( self, *, medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList, - one_hot_covariates: MedRecordAttributeInputList, + patients_group: Group, + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Tuple[pl.DataFrame, pl.DataFrame]: """Prepared the data for the matching algorithms. @@ -37,21 +50,42 @@ def _preprocess_data( medrecord (MedRecord): MedRecord object containing the data. 
control_set (Set[NodeIndex]): Set of treated subjects. treated_set (Set[NodeIndex]): Set of control subjects. - essential_covariates (MedRecordAttributeInputList): Covariates - that are essential for matching - one_hot_covariates (MedRecordAttributeInputList): Covariates that - are one-hot encoded for matching + patients_group (Group): The group of patients. + essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. Returns: Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their preprocessed covariates + + Raises: + ValueError: If not enough control subjects to match the treated subjects. + ValueError: If some treated nodes do not have all the essential covariates. + AssertionError: If the one-hot covariates are not in the essential covariates. """ - essential_covariates = [str(covariate) for covariate in essential_covariates] + if essential_covariates is None: + # If no essential covariates are provided, use all the attributes of the patients + essential_covariates = list( + extract_attribute_summary( + medrecord.node[medrecord.nodes_in_group(patients_group)] + ) + ) + else: + essential_covariates = [covariate for covariate in essential_covariates] + + control_set = self._check_nodes( + medrecord=medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=essential_covariates, + ) if "id" not in essential_covariates: essential_covariates.append("id") - # Dataframe + # Dataframe with the essential covariates data = pl.DataFrame( data=[ {"id": k, **v} @@ -60,6 +94,24 @@ def _preprocess_data( ) original_columns = data.columns + if one_hot_covariates is None: + # If no one-hot covariates are provided, use all the categorical attributes of the patients + attributes = extract_attribute_summary( + 
medrecord.node[medrecord.nodes_in_group(patients_group)] + ) + one_hot_covariates = [ + covariate + for covariate, values in attributes.items() + if "values" in values + ] + + if not all( + covariate in essential_covariates for covariate in one_hot_covariates + ): + raise AssertionError( + "One-hot covariates must be in the essential covariates" + ) + # One-hot encode the categorical variables data = data.to_dummies( columns=[str(covariate) for covariate in one_hot_covariates], @@ -79,6 +131,59 @@ def _preprocess_data( return data_treated, data_control + def _check_nodes( + self, + medrecord: MedRecord, + treated_set: Set[NodeIndex], + control_set: Set[NodeIndex], + essential_covariates: MedRecordAttributeInputList, + ) -> Set[NodeIndex]: + """Check that treated and control nodes have all essential covariates. + + Args: + medrecord (MedRecord): MedRecord object containing the data. + treated_set (Set[NodeIndex]): Set of treated subjects. + control_set (Set[NodeIndex]): Set of control subjects. + essential_covariates (MedRecordAttributeInputList): Covariates that are + essential for matching. + + Returns: + Set[NodeIndex]: The control nodes that have all essential covariates. + + Raises: + ValueError: If not enough control subjects to match the treated subjects. 
+ """ + + def query_essential_covariates( + node: NodeOperand, patients_set: Set[NodeIndex] + ) -> None: + """Query the nodes that have all the essential covariates.""" + for attribute in essential_covariates: + node.has_attribute(attribute) + + node.index().is_in(list(patients_set)) + + control_set = set( + medrecord.select_nodes( + lambda node: query_essential_covariates(node, control_set) + ) + ) + if len(control_set) < self.number_of_neighbors * len(treated_set): + raise ValueError( + "Not enough control subjects to match the treated subjects" + ) + + if len(treated_set) != len( + medrecord.select_nodes( + lambda node: query_essential_covariates(node, treated_set) + ) + ): + raise ValueError( + "Some treated nodes do not have all the essential covariates" + ) + + return control_set + @abstractmethod def match_controls( self, @@ -86,6 +191,6 @@ def match_controls( control_set: Set[NodeIndex], treated_set: Set[NodeIndex], medrecord: MedRecord, - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Set[NodeIndex]: ... 
diff --git a/medmodels/treatment_effect/matching/neighbors.py b/medmodels/treatment_effect/matching/neighbors.py index 13ab9451..d66f79d6 100644 --- a/medmodels/treatment_effect/matching/neighbors.py +++ b/medmodels/treatment_effect/matching/neighbors.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Set +from typing import Optional, Set from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex +from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -32,7 +32,7 @@ def __init__( number_of_neighbors (int, optional): Number of nearest neighbors to find for each treated unit. Defaults to 1. """ - self.number_of_neighbors = number_of_neighbors + super().__init__(number_of_neighbors) def match_controls( self, @@ -40,8 +40,9 @@ def match_controls( medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + patients_group: Group, + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Set[NodeIndex]: """Matches the controls based on the nearest neighbor algorithm. @@ -49,10 +50,11 @@ def match_controls( medrecord (MedRecord): MedRecord object containing the data. treated_set (Set[NodeIndex]): Set of treated subjects. control_set (Set[NodeIndex]): Set of control subjects. - essential_covariates (MedRecordAttributeInputList, optional): Covariates - that are essential for matching - one_hot_covariates (MedRecordAttributeInputList, optional): Covariates that - are one-hot encoded for matching + patients_group (Group): The group of patients. 
+ essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. Returns: Set[NodeIndex]: Node Ids of the matched controls. @@ -61,6 +63,7 @@ def match_controls( medrecord=medrecord, control_set=control_set, treated_set=treated_set, + patients_group=patients_group, essential_covariates=essential_covariates, one_hot_covariates=one_hot_covariates, ) diff --git a/medmodels/treatment_effect/matching/propensity.py b/medmodels/treatment_effect/matching/propensity.py index 053c0b2a..aa8af125 100644 --- a/medmodels/treatment_effect/matching/propensity.py +++ b/medmodels/treatment_effect/matching/propensity.py @@ -6,7 +6,7 @@ import polars as pl from medmodels import MedRecord -from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex +from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex from medmodels.treatment_effect.matching.algorithms.classic_distance_models import ( nearest_neighbor, ) @@ -36,23 +36,22 @@ def __init__( *, model: Model = "logit", number_of_neighbors: int = 1, - hyperparam: Optional[Dict[str, Any]] = None, + hyperparameters: Optional[Dict[str, Any]] = None, ): """Initializes the propensity score class. Args: model (Model, optional): classification method to be used, default: "logit". Can be chosen from ["logit", "dec_tree", "forest"]. - nearest_neighbors_algorithm (NNAlgorithm, optional): algorithm used to - compute nearest neighbors. Defaults to "auto". number_of_neighbors (int, optional): number of neighbors to be matched per treated subject. Defaults to 1. - hyperparam (Optional[Dict[str, Any]], optional): hyperparameters for the - classification model, default: None. + hyperparameters (Optional[Dict[str, Any]], optional): hyperparameters for + the classification model. Defaults to None. 
""" + super().__init__(number_of_neighbors) self.model = model self.number_of_neighbors = number_of_neighbors - self.hyperparam = hyperparam + self.hyperparameters = hyperparameters def match_controls( self, @@ -60,8 +59,9 @@ def match_controls( medrecord: MedRecord, control_set: Set[NodeIndex], treated_set: Set[NodeIndex], - essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - one_hot_covariates: MedRecordAttributeInputList = ["gender"], + patients_group: Group, + essential_covariates: Optional[MedRecordAttributeInputList] = None, + one_hot_covariates: Optional[MedRecordAttributeInputList] = None, ) -> Set[NodeIndex]: """Matches the controls based on propensity score matching. @@ -69,10 +69,11 @@ def match_controls( medrecord (MedRecord): medrecord object containing the data. treated_set (Set[NodeIndex]): Set of treated subjects. control_set (Set[NodeIndex]): Set of control subjects. - essential_covariates (MedRecordAttributeInputList, optional): Covariates - that are essential for matching. Defaults to ["gender", "age"]. - one_hot_covariates (MedRecordAttributeInputList, optional): Covariates that - are one-hot encoded for matching. Defaults to ["gender"]. + patients_group (Group): The group of patients. + essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. Returns: Set[NodeIndex]: Node Ids of the matched controls. 
@@ -82,6 +83,7 @@ def match_controls( medrecord=medrecord, treated_set=treated_set, control_set=control_set, + patients_group=patients_group, essential_covariates=essential_covariates, one_hot_covariates=one_hot_covariates, ) @@ -100,7 +102,7 @@ def match_controls( y_train=y_train, treated_test=treated_array, control_test=control_array, - hyperparam=self.hyperparam, + hyperparameters=self.hyperparameters, model=self.model, ) diff --git a/medmodels/treatment_effect/matching/tests/test_matching.py b/medmodels/treatment_effect/matching/tests/test_matching.py new file mode 100644 index 00000000..ff49fe57 --- /dev/null +++ b/medmodels/treatment_effect/matching/tests/test_matching.py @@ -0,0 +1,241 @@ +"""Tests for the NeighborsMatching class in the matching module.""" + +import unittest +from typing import List, Set + +import pandas as pd + +from medmodels import MedRecord +from medmodels.medrecord.types import NodeIndex +from medmodels.treatment_effect.matching.neighbors import NeighborsMatching + + +def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame: + """Creates a patients dataframe. + + Returns: + pd.DataFrame: A patients dataframe. + """ + patients = pd.DataFrame( + { + "index": ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"], + "age": [20, 30, 40, 30, 40, 50, 60, 70, 80], + "gender": [ + "male", + "female", + "male", + "female", + "male", + "female", + "male", + "female", + "male", + ], + } + ) + + patients = patients.loc[patients["index"].isin(patient_list)] + return patients + + +def create_medrecord( + patient_list: List[NodeIndex] = [ + "P1", + "P2", + "P3", + "P4", + "P5", + "P6", + "P7", + "P8", + "P9", + ], +) -> MedRecord: + """Creates a MedRecord object. + + Returns: + MedRecord: A MedRecord object. 
+ """ + patients = create_patients(patient_list=patient_list) + medrecord = MedRecord.from_pandas(nodes=[(patients, "index")]) + medrecord.add_group(group="patients", nodes=patients["index"].to_list()) + return medrecord + + +class TestNeighborsMatching(unittest.TestCase): + """Class to test the NeighborsMatching class in the matching module.""" + + def setUp(self): + self.medrecord = create_medrecord() + + def test_preprocess_data(self): + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + data_treated, data_control = neighbors_matching._preprocess_data( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age", "gender"], + one_hot_covariates=["gender"], + ) + + # Assert that the treated and control dataframes have the correct columns + self.assertIn("age", data_treated.columns) + self.assertIn("age", data_control.columns) + self.assertTrue( + "gender_female" in data_treated.columns + or "gender_male" in data_treated.columns + ) + self.assertTrue( + "gender_female" in data_control.columns + or "gender_male" in data_control.columns + ) + + # Assert that the treated and control dataframes have the correct number of rows + self.assertEqual(len(data_treated), len(treated_set)) + self.assertEqual(len(data_control), len(control_set)) + + # Try automatic detection of attributes + data_treated, data_control = neighbors_matching._preprocess_data( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + ) + + # Assert that the treated and control dataframes have the correct columns + self.assertIn("age", data_treated.columns) + self.assertIn("age", data_control.columns) + self.assertTrue( + "gender_female" in data_treated.columns + or "gender_male" in data_treated.columns + ) + self.assertTrue( + 
"gender_female" in data_control.columns + or "gender_male" in data_control.columns + ) + + # Assert that the treated and control dataframes have the correct number of rows + self.assertEqual(len(data_treated), len(treated_set)) + self.assertEqual(len(data_control), len(control_set)) + + def test_match_controls(self): + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + matched_controls = neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age", "gender"], + one_hot_covariates=["gender"], + ) + + # Assert that the matched controls are a subset of the control set + self.assertTrue(matched_controls.issubset(control_set)) + + # Assert that the correct number of controls were matched + self.assertEqual(len(matched_controls), len(treated_set)) + + # Assert it works equally if no covariates are given (automatically assigned) + matched_controls_no_covariates_specified = neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + ) + + self.assertTrue(matched_controls_no_covariates_specified.issubset(control_set)) + self.assertEqual( + len(matched_controls_no_covariates_specified), len(treated_set) + ) + + def test_check_nodes(self): + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6", "P8"} + + # Test valid case + valid_control_set = neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + self.assertEqual(valid_control_set, control_set) + + def test_invalid_check_nodes(self): + neighbors_matching = 
NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + # Test insufficient control subjects + with self.assertRaises(ValueError) as context: + neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set={"P1"}, + essential_covariates=["age", "gender"], + ) + self.assertEqual( + str(context.exception), + "Not enough control subjects to match the treated subjects", + ) + + with self.assertRaises(ValueError) as context: + neighbors_matching = NeighborsMatching(number_of_neighbors=2) + neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set=treated_set, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + self.assertEqual( + str(context.exception), + "Not enough control subjects to match the treated subjects", + ) + + # Test missing essential covariates in treated set + with self.assertRaises(ValueError) as context: + neighbors_matching._check_nodes( + medrecord=self.medrecord, + treated_set={"P2", "P10"}, + control_set=control_set, + essential_covariates=["age", "gender"], + ) + self.assertEqual( + str(context.exception), + "Some treated nodes do not have all the essential covariates", + ) + + def test_invalid_match_controls(self): + neighbors_matching = NeighborsMatching(number_of_neighbors=1) + + control_set: Set[NodeIndex] = {"P1", "P3", "P5", "P7", "P9"} + treated_set: Set[NodeIndex] = {"P2", "P4", "P6"} + + with self.assertRaisesRegex( + AssertionError, "One-hot covariates must be in the essential covariates" + ): + neighbors_matching.match_controls( + medrecord=self.medrecord, + control_set=control_set, + treated_set=treated_set, + patients_group="patients", + essential_covariates=["age"], + one_hot_covariates=["gender"], + ) + + +if __name__ == "__main__": + run_test = unittest.TestLoader().loadTestsFromTestCase(TestNeighborsMatching) + 
unittest.TextTestRunner(verbosity=2).run(run_test) diff --git a/medmodels/treatment_effect/matching/tests/test_propensity_score.py b/medmodels/treatment_effect/matching/tests/test_propensity_score.py index e505d6ab..f14f156c 100644 --- a/medmodels/treatment_effect/matching/tests/test_propensity_score.py +++ b/medmodels/treatment_effect/matching/tests/test_propensity_score.py @@ -12,14 +12,18 @@ def test_calculate_propensity(self): x, y = load_iris(return_X_y=True) # Set random state by each propensity estimator: - hyperparam = {"random_state": 1} - hyperparam_logit = {"random_state": 1, "max_iter": 200} + hyperparameters = {"random_state": 1} + hyperparameters_logit = {"random_state": 1, "max_iter": 200} x = np.array(x) y = np.array(y) # Logistic Regression model: result_1, result_2 = ps.calculate_propensity( - x, y, np.array([x[0, :]]), np.array([x[1, :]]), hyperparam=hyperparam_logit + x, + y, + np.array([x[0, :]]), + np.array([x[1, :]]), + hyperparameters=hyperparameters_logit, ) self.assertAlmostEqual(result_1[0], 1.43580537e-08, places=9) self.assertAlmostEqual(result_2[0], 3.00353141e-08, places=9) @@ -31,7 +35,7 @@ def test_calculate_propensity(self): np.array([x[0, :]]), np.array([x[1, :]]), model="dec_tree", - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) self.assertAlmostEqual(result_1[0], 0, places=2) self.assertAlmostEqual(result_2[0], 0, places=2) @@ -43,15 +47,15 @@ def test_calculate_propensity(self): np.array([x[0, :]]), np.array([x[1, :]]), model="forest", - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) self.assertAlmostEqual(result_1[0], 0, places=2) self.assertAlmostEqual(result_2[0], 0, places=2) def test_run_propensity_score(self): # Set random state by each propensity estimator: - hyperparam = {"random_state": 1} - hyperparam_logit = {"random_state": 1, "max_iter": 200} + hyperparameters = {"random_state": 1} + hyperparameters_logit = {"random_state": 1, "max_iter": 200} ########################################### # 
1D example @@ -61,21 +65,21 @@ def test_run_propensity_score(self): # logit model expected_logit = pl.DataFrame({"a": [1.0, 3.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, hyperparam=hyperparam_logit + treated_set, control_set, hyperparameters=hyperparameters_logit ) self.assertTrue(result_logit.equals(expected_logit)) # dec_tree metric expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, model="dec_tree", hyperparam=hyperparam + treated_set, control_set, model="dec_tree", hyperparameters=hyperparameters ) self.assertTrue(result_logit.equals(expected_logit)) # forest model expected_logit = pl.DataFrame({"a": [1.0, 1.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, model="forest", hyperparam=hyperparam + treated_set, control_set, model="forest", hyperparameters=hyperparameters ) self.assertTrue(result_logit.equals(expected_logit)) @@ -90,7 +94,10 @@ def test_run_propensity_score(self): # logit model expected_logit = pl.DataFrame({"a": [1.0], "b": [3.0], "c": [5.0]}) result_logit = ps.run_propensity_score( - treated_set, control_set, covariates=covs, hyperparam=hyperparam_logit + treated_set, + control_set, + covariates=covs, + hyperparameters=hyperparameters_logit, ) self.assertTrue(result_logit.equals(expected_logit)) @@ -101,7 +108,7 @@ def test_run_propensity_score(self): control_set, model="dec_tree", covariates=covs, - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) self.assertTrue(result_logit.equals(expected_logit)) @@ -112,7 +119,7 @@ def test_run_propensity_score(self): control_set, model="forest", covariates=covs, - hyperparam=hyperparam, + hyperparameters=hyperparameters, ) self.assertTrue(result_logit.equals(expected_logit)) diff --git a/medmodels/treatment_effect/tests/test_treatment_effect.py b/medmodels/treatment_effect/tests/test_treatment_effect.py index 1d1d8539..5e72e0ba 100644 --- 
a/medmodels/treatment_effect/tests/test_treatment_effect.py +++ b/medmodels/treatment_effect/tests/test_treatment_effect.py @@ -268,7 +268,8 @@ def assert_treatment_effects_equal( treatment_effect2._matching_number_of_neighbors, ) test_case.assertEqual( - treatment_effect1._matching_hyperparam, treatment_effect2._matching_hyperparam + treatment_effect1._matching_hyperparameters, + treatment_effect2._matching_hyperparameters, ) diff --git a/medmodels/treatment_effect/treatment_effect.py b/medmodels/treatment_effect/treatment_effect.py index 169dc7b9..52ed8f24 100644 --- a/medmodels/treatment_effect/treatment_effect.py +++ b/medmodels/treatment_effect/treatment_effect.py @@ -52,11 +52,11 @@ class TreatmentEffect: _filter_controls_query: Optional[NodeQuery] _matching_method: Optional[MatchingMethod] - _matching_essential_covariates: MedRecordAttributeInputList - _matching_one_hot_covariates: MedRecordAttributeInputList + _matching_essential_covariates: Optional[MedRecordAttributeInputList] + _matching_one_hot_covariates: Optional[MedRecordAttributeInputList] _matching_model: Model _matching_number_of_neighbors: int - _matching_hyperparam: Optional[Dict[str, Any]] + _matching_hyperparameters: Optional[Dict[str, Any]] def __init__( self, @@ -93,11 +93,11 @@ def _set_configuration( outcome_before_treatment_days: Optional[int] = None, filter_controls_query: Optional[NodeQuery] = None, matching_method: Optional[MatchingMethod] = None, - matching_essential_covariates: MedRecordAttributeInputList = ["gender", "age"], - matching_one_hot_covariates: MedRecordAttributeInputList = ["gender"], + matching_essential_covariates: Optional[MedRecordAttributeInputList] = None, + matching_one_hot_covariates: Optional[MedRecordAttributeInputList] = None, matching_model: Model = "logit", matching_number_of_neighbors: int = 1, - matching_hyperparam: Optional[Dict[str, Any]] = None, + matching_hyperparameters: Optional[Dict[str, Any]] = None, ) -> None: """Initializes a Treatment Effect 
analysis setup with specified treatments and outcomes within a medical record dataset. @@ -131,17 +131,15 @@ def _set_configuration( Defaults to None. matching_method (Optional[MatchingMethod]): The method to match treatment and control groups. Defaults to None. - matching_essential_covariates (MedRecordAttributeInputList, optional): - The essential covariates to use for matching. Defaults to - ["gender", "age"]. - matching_one_hot_covariates (MedRecordAttributeInputList, optional): - The one-hot covariates to use for matching. Defaults to - ["gender"]. + matching_essential_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are essential for matching. Defaults to None. + matching_one_hot_covariates (Optional[MedRecordAttributeInputList]): + Covariates that are one-hot encoded for matching. Defaults to None. matching_model (Model, optional): The model to use for matching. Defaults to "logit". matching_number_of_neighbors (int, optional): The number of neighbors to match for each treated subject. Defaults to 1. - matching_hyperparam (Optional[Dict[str, Any]], optional): The + matching_hyperparameters (Optional[Dict[str, Any]], optional): The hyperparameters for the matching model. Defaults to None. Raises: @@ -172,7 +170,7 @@ def _set_configuration( treatment_effect._matching_one_hot_covariates = matching_one_hot_covariates treatment_effect._matching_model = matching_model treatment_effect._matching_number_of_neighbors = matching_number_of_neighbors - treatment_effect._matching_hyperparam = matching_hyperparam + treatment_effect._matching_hyperparameters = matching_hyperparameters def _find_groups( self, medrecord: MedRecord