Merge pull request #239 from DoubleML/s-ext-pred-benchmark
Enable external predictions for short model in benchmark
SvenKlaassen authored Apr 11, 2024
2 parents 8bc3d94 + 3769f81 commit ba9cc57
Showing 3 changed files with 84 additions and 4 deletions.
10 changes: 8 additions & 2 deletions doubleml/double_ml.py
@@ -1735,7 +1735,7 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
fill=fill)
return fig

def sensitivity_benchmark(self, benchmarking_set):
def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
"""
Computes a benchmark for a given set of features.
Returns a DataFrame containing the corresponding values for cf_y, cf_d, rho and the change in estimates.
@@ -1757,12 +1757,18 @@ def sensitivity_benchmark(self, benchmarking_set):
if not set(benchmarking_set) <= set(x_list_long):
raise ValueError(f"benchmarking_set must be a subset of features {str(self._dml_data.x_cols)}. "
f'{str(benchmarking_set)} was passed.')
if fit_args is not None and not isinstance(fit_args, dict):
raise TypeError('fit_args must be a dict. '
f'{str(fit_args)} of type {type(fit_args)} was passed.')

# refit short form of the model
x_list_short = [x for x in x_list_long if x not in benchmarking_set]
dml_short = copy.deepcopy(self)
dml_short._dml_data.x_cols = x_list_short
dml_short.fit()
if fit_args is not None:
dml_short.fit(**fit_args)
else:
dml_short.fit()

benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)
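For reference, a minimal usage sketch of the new fit_args argument (not part of the diff): it mirrors the pattern exercised in doubleml/tests/test_sensitivity.py below. The concrete learners and the nuisance names ml_m / ml_g0 / ml_g1 are illustrative assumptions for an IRM model, not the only supported setup.

from sklearn.linear_model import LinearRegression, LogisticRegression

from doubleml import DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data

# Long model: fit on all features, then run the sensitivity analysis.
x, y, d = make_irm_data(n_obs=500, dim_x=5, theta=0.5, return_type="np.array")
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_long = DoubleMLIRM(dml_data, ml_g=LinearRegression(), ml_m=LogisticRegression(), n_folds=2)
dml_long.fit(store_predictions=True)
dml_long.sensitivity_analysis()

# Short model: drop the benchmarking feature(s) and fit separately to obtain
# the nuisance predictions that will be supplied externally.
benchmarking_set = ["X1"]
dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_data_short.x_cols = [col for col in dml_data.x_cols if col not in benchmarking_set]
dml_short = DoubleMLIRM(dml_data_short, ml_g=LinearRegression(), ml_m=LogisticRegression(), n_folds=2)
dml_short.fit(store_predictions=True)

# fit_args is forwarded to the internal refit of the short model, so the
# precomputed predictions are reused instead of refitting the learners.
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
                                           "ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
                                           "ml_g1": dml_short.predictions["ml_g1"][:, :, 0]}}}
df_bm = dml_long.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
print(df_bm)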
15 changes: 14 additions & 1 deletion doubleml/tests/test_exceptions_ext_preds.py
@@ -1,8 +1,10 @@
import pytest
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import DMLDummyRegressor, DMLDummyClassifier

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

df_irm = make_irm_data(n_obs=10, dim_x=2, theta=0.5, return_type="DataFrame")
ext_predictions = {"d": {}}

@@ -21,3 +23,14 @@ def test_qte_external_prediction_exception():
with pytest.raises(NotImplementedError, match=msg):
qte = DoubleMLQTE(DoubleMLData(df_irm, "y", "d"), DMLDummyClassifier(), DMLDummyClassifier())
qte.fit(external_predictions=ext_predictions)


@pytest.mark.ci
def test_sensitivity_benchmark_external_prediction_exception():
msg = "fit_args must be a dict. "
with pytest.raises(TypeError, match=msg):
fit_args = []
irm = DoubleMLIRM(DoubleMLData(df_irm, "y", "d"), RandomForestRegressor(), RandomForestClassifier())
irm.fit()
irm.sensitivity_analysis()
irm.sensitivity_benchmark(benchmarking_set=["X1"], fit_args=fit_args)
63 changes: 62 additions & 1 deletion doubleml/tests/test_sensitivity.py
@@ -1,13 +1,21 @@
import pytest
import numpy as np
import copy

import doubleml as dml
from sklearn.linear_model import LinearRegression
from doubleml import DoubleMLIRM, DoubleMLData
from doubleml.datasets import make_irm_data
from sklearn.linear_model import LinearRegression, LogisticRegression

from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
doubleml_sensitivity_benchmark_manual


@pytest.fixture(scope="module", params=[["X1"], ["X2"], ["X3"]])
def benchmarking_set(request):
return request.param


@pytest.fixture(scope='module',
params=[1, 3])
def n_rep(request):
@@ -99,3 +107,56 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
assert all(dml_sensitivity_multitreat_fixture['benchmark'].index ==
dml_sensitivity_multitreat_fixture['d_cols'])
assert dml_sensitivity_multitreat_fixture['benchmark'].equals(dml_sensitivity_multitreat_fixture['benchmark_manual'])


@pytest.fixture(scope="module")
def test_dml_benchmark_fixture(benchmarking_set, n_rep):
random_state = 42
x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")

classifier_class = LogisticRegression
regressor_class = LinearRegression

np.random.seed(3141)
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
x_list_long = copy.deepcopy(dml_data.x_cols)
dml_int = DoubleMLIRM(dml_data,
ml_m=classifier_class(random_state=random_state),
ml_g=regressor_class(),
n_folds=2,
n_rep=n_rep)
dml_int.fit(store_predictions=True)
dml_int.sensitivity_analysis()
dml_ext = copy.deepcopy(dml_int)
df_bm = dml_int.sensitivity_benchmark(benchmarking_set=benchmarking_set)

np.random.seed(3141)
dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
dml_short = DoubleMLIRM(dml_data_short,
ml_m=classifier_class(random_state=random_state),
ml_g=regressor_class(),
n_folds=2,
n_rep=n_rep)
dml_short.fit(store_predictions=True)
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
"ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
"ml_g1": dml_short.predictions["ml_g1"][:, :, 0],
}
},
}
dml_ext.sensitivity_analysis()
df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)

res_dict = {"default_benchmark": df_bm,
"external_benchmark": df_bm_ext}

return res_dict


@pytest.mark.ci
def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
test_dml_benchmark_fixture["external_benchmark"],
rtol=1e-9,
atol=1e-4)
