Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: the way essential covariates and one-hot-encoded ones are processed #256

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions medmodels/treatment_effect/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TreatmentEffectBuilder:
matching_one_hot_covariates: Optional[MedRecordAttributeInputList]
matching_model: Optional[Model]
matching_number_of_neighbors: Optional[int]
matching_hyperparam: Optional[Dict[str, Any]]
matching_hyperparameters: Optional[Dict[str, Any]]

def with_treatment(self, treatment: Group) -> TreatmentEffectBuilder:
"""Sets the treatment group for the treatment effect estimation.
Expand Down Expand Up @@ -218,27 +218,25 @@ def filter_controls(self, query: NodeQuery) -> TreatmentEffectBuilder:

def with_propensity_matching(
self,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> TreatmentEffectBuilder:
"""Adjust the treatment effect estimate using propensity score matching.

Args:
essential_covariates (MedRecordAttributeInputList, optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (MedRecordAttributeInputList, optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
model (Model, optional): Model to choose for the matching. Defaults to
"logit".
number_of_neighbors (int, optional): Number of neighbors to consider
for the matching. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for the
matching model. Defaults to None.
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for
the matching model. Defaults to None.

Returns:
TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
Expand All @@ -249,27 +247,25 @@ def with_propensity_matching(
self.matching_one_hot_covariates = one_hot_covariates
self.matching_model = model
self.matching_number_of_neighbors = number_of_neighbors
self.matching_hyperparam = hyperparam
self.matching_hyperparameters = hyperparameters

return self

def with_nearest_neighbors_matching(
self,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
number_of_neighbors: int = 1,
) -> TreatmentEffectBuilder:
"""Adjust the treatment effect estimate using nearest neighbors matching.

Args:
essential_covariates (MedRecordAttributeInputList, optional):
Covariates that are essential for matching. Defaults to
["gender", "age"].
one_hot_covariates (MedRecordAttributeInputList, optional):
Covariates that are one-hot encoded for matching. Defaults to
["gender"].
number_of_neighbors (int, optional): Number of neighbors to consider for the
matching. Defaults to 1.
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.
number_of_neighbors (int, optional): Number of neighbors to consider for
the matching. Defaults to 1.

Returns:
TreatmentEffectBuilder: The current instance of the TreatmentEffectBuilder
Expand Down
3 changes: 2 additions & 1 deletion medmodels/treatment_effect/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _sort_subjects_in_groups(
else PropensityMatching(
number_of_neighbors=self._treatment_effect._matching_number_of_neighbors,
model=self._treatment_effect._matching_model,
hyperparam=self._treatment_effect._matching_hyperparam,
hyperparameters=self._treatment_effect._matching_hyperparameters,
)
)

Expand All @@ -191,6 +191,7 @@ def _sort_subjects_in_groups(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
patients_group=self._treatment_effect._patients_group,
essential_covariates=self._treatment_effect._matching_essential_covariates,
one_hot_covariates=self._treatment_effect._matching_one_hot_covariates,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def calculate_propensity(
treated_test: NDArray[Union[np.int64, np.float64]],
control_test: NDArray[Union[np.int64, np.float64]],
model: Model = "logit",
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Trains a classification algorithm on training data, predicts the probability of being in the last class for treated and control test datasets, and returns these probabilities.

Expand All @@ -49,8 +49,8 @@ def calculate_propensity(
control group to predict probabilities.
model (Model, optional): Classification algorithm to use. Options: "logit",
"dec_tree", "forest".
hyperparam (Optional[Dict[str, Any]], optional): Manual hyperparameter settings.
Uses default if None.
hyperparameters (Optional[Dict[str, Any]], optional): Manual hyperparameter
settings. Uses default hyperparameters if None.

Returns:
Tuple[NDArray[np.float64], NDArray[np.float64]: Probabilities of the positive
Expand All @@ -61,7 +61,7 @@ class for treated and control groups.
last class for treated and control sets, e.g., ([0.], [0.]).
"""
propensity_model = PROP_MODEL[model]
pm = propensity_model(**hyperparam) if hyperparam else propensity_model()
pm = propensity_model(**hyperparameters) if hyperparameters else propensity_model()
pm.fit(x_train, y_train)

# Predict the probability of the treated and control groups
Expand All @@ -76,7 +76,7 @@ def run_propensity_score(
control_set: pl.DataFrame,
model: Model = "logit",
number_of_neighbors: int = 1,
hyperparam: Optional[Dict[str, Any]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
covariates: Optional[MedRecordAttributeInputList] = None,
) -> pl.DataFrame:
"""Executes Propensity Score matching using a specified classification algorithm.
Expand All @@ -95,7 +95,7 @@ def run_propensity_score(
Options include "logit", "dec_tree", "forest".
number_of_neighbors (int, optional): Number of nearest neighbors to find for
each treated unit. Defaults to 1.
hyperparam (Optional[Dict[str, Any]], optional): Hyperparameters for model
hyperparameters (Optional[Dict[str, Any]], optional): Hyperparameters for model
tuning. Increases computation time if set. Uses default if None.
covariates (Optional[MedRecordAttributeInputList], optional): Features for
matching. Uses all if None.
Expand All @@ -119,7 +119,7 @@ def run_propensity_score(
y_train,
treated_array,
control_array,
hyperparam=hyperparam,
hyperparameters=hyperparameters,
model=model,
)

Expand Down
129 changes: 117 additions & 12 deletions medmodels/treatment_effect/matching/matching.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Literal, Set, Tuple
from typing import TYPE_CHECKING, Literal, Optional, Set, Tuple

import polars as pl

from medmodels.medrecord._overview import extract_attribute_summary
from medmodels.medrecord.medrecord import MedRecord
from medmodels.medrecord.types import MedRecordAttributeInputList, NodeIndex
from medmodels.medrecord.querying import NodeOperand
from medmodels.medrecord.types import Group, MedRecordAttributeInputList, NodeIndex

if TYPE_CHECKING:
import sys
Expand All @@ -22,36 +24,68 @@
class Matching(metaclass=ABCMeta):
"""The Base Class for matching."""

number_of_neighbors: int

def __init__(self, number_of_neighbors: int) -> None:
"""Initializes the matching class.

Args:
number_of_neighbors (int): Number of nearest neighbors to find for each treated unit.
"""
self.number_of_neighbors = number_of_neighbors

def _preprocess_data(
self,
*,
medrecord: MedRecord,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
one_hot_covariates: MedRecordAttributeInputList,
patients_group: Group,
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Prepared the data for the matching algorithms.

Args:
medrecord (MedRecord): MedRecord object containing the data.
control_set (Set[NodeIndex]): Set of treated subjects.
treated_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (MedRecordAttributeInputList): Covariates
that are essential for matching
one_hot_covariates (MedRecordAttributeInputList): Covariates that
are one-hot encoded for matching
patients_group (Group): The group of patients.
essential_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are essential for matching. Defaults to None.
one_hot_covariates (Optional[MedRecordAttributeInputList]):
Covariates that are one-hot encoded for matching. Defaults to None.

Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Treated and control groups with their
preprocessed covariates

Raises:
ValueError: If not enough control subjects to match the treated subjects.
ValueError: If some treated nodes do not have all the essential covariates.
AssertionError: If the one-hot covariates are not in the essential covariates.
"""
essential_covariates = [str(covariate) for covariate in essential_covariates]
if essential_covariates is None:
# If no essential covariates are provided, use all the attributes of the patients
essential_covariates = list(
extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
)
else:
essential_covariates = [covariate for covariate in essential_covariates]

control_set = self._check_nodes(
medrecord=medrecord,
treated_set=treated_set,
control_set=control_set,
essential_covariates=essential_covariates,
)

if "id" not in essential_covariates:
essential_covariates.append("id")

# Dataframe
# Dataframe wth the essential covariates
data = pl.DataFrame(
data=[
{"id": k, **v}
Expand All @@ -60,6 +94,24 @@ def _preprocess_data(
)
original_columns = data.columns

if one_hot_covariates is None:
# If no one-hot covariates are provided, use all the categorical attributes of the patients
attributes = extract_attribute_summary(
medrecord.node[medrecord.nodes_in_group(patients_group)]
)
one_hot_covariates = [
covariate
for covariate, values in attributes.items()
if "values" in values
]

if not all(
covariate in essential_covariates for covariate in one_hot_covariates
):
raise AssertionError(
"One-hot covariates must be in the essential covariates"
)

# One-hot encode the categorical variables
data = data.to_dummies(
columns=[str(covariate) for covariate in one_hot_covariates],
Expand All @@ -79,13 +131,66 @@ def _preprocess_data(

return data_treated, data_control

def _check_nodes(
self,
medrecord: MedRecord,
treated_set: Set[NodeIndex],
control_set: Set[NodeIndex],
essential_covariates: MedRecordAttributeInputList,
) -> Set[NodeIndex]:
"""Check if the treated and control sets are disjoint.

Args:
medrecord (MedRecord): MedRecord object containing the data.
treated_set (Set[NodeIndex]): Set of treated subjects.
control_set (Set[NodeIndex]): Set of control subjects.
essential_covariates (MedRecordAttributeInputList): Covariates that are
essential for matching.

Returns:
Set[NodeIndex]: The control set.

Raises:
ValueError: If not enough control subjects to match the treated subjects.
"""

def query_essential_covariates(
node: NodeOperand, patients_set: Set[NodeIndex]
) -> None:
"""Query the nodes that have all the essential covariates."""
for attribute in essential_covariates:
node.has_attribute(attribute)

node.index().is_in(list(patients_set))

control_set = set(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, control_set)
)
)
if len(control_set) < self.number_of_neighbors * len(treated_set):
raise ValueError(
"Not enough control subjects to match the treated subjects"
)

if len(treated_set) != len(
medrecord.select_nodes(
lambda node: query_essential_covariates(node, treated_set)
)
):
raise ValueError(
"Some treated nodes do not have all the essential covariates"
)

return control_set

@abstractmethod
def match_controls(
self,
*,
control_set: Set[NodeIndex],
treated_set: Set[NodeIndex],
medrecord: MedRecord,
essential_covariates: MedRecordAttributeInputList = ["gender", "age"],
one_hot_covariates: MedRecordAttributeInputList = ["gender"],
essential_covariates: Optional[MedRecordAttributeInputList] = None,
one_hot_covariates: Optional[MedRecordAttributeInputList] = None,
) -> Set[NodeIndex]: ...
Loading