diff --git a/doubleml/double_ml.py b/doubleml/double_ml.py
index b47a5ace..8ea773a4 100644
--- a/doubleml/double_ml.py
+++ b/doubleml/double_ml.py
@@ -1735,7 +1735,7 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
                                         fill=fill)
         return fig
 
-    def sensitivity_benchmark(self, benchmarking_set):
+    def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
         """
         Computes a benchmark for a given set of features.
         Returns a DataFrame containing the corresponding values for cf_y, cf_d, rho and the change in estimates.
@@ -1757,12 +1757,18 @@ def sensitivity_benchmark(self, benchmarking_set):
         if not set(benchmarking_set) <= set(x_list_long):
             raise ValueError(f"benchmarking_set must be a subset of features {str(self._dml_data.x_cols)}. "
                              f'{str(benchmarking_set)} was passed.')
+        if fit_args is not None and not isinstance(fit_args, dict):
+            raise TypeError('fit_args must be a dict. '
+                            f'{str(fit_args)} of type {type(fit_args)} was passed.')
 
         # refit short form of the model
         x_list_short = [x for x in x_list_long if x not in benchmarking_set]
         dml_short = copy.deepcopy(self)
         dml_short._dml_data.x_cols = x_list_short
-        dml_short.fit()
+        if fit_args is not None:
+            dml_short.fit(**fit_args)
+        else:
+            dml_short.fit()
 
         benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
         df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)
diff --git a/doubleml/tests/test_exceptions_ext_preds.py b/doubleml/tests/test_exceptions_ext_preds.py
index 395d8bf5..4a61361d 100644
--- a/doubleml/tests/test_exceptions_ext_preds.py
+++ b/doubleml/tests/test_exceptions_ext_preds.py
@@ -1,8 +1,10 @@
 import pytest
 
-from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
+from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLIRM, DoubleMLData
 from doubleml.datasets import make_irm_data
 from doubleml.utils import DMLDummyRegressor, DMLDummyClassifier
 
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+
 df_irm = make_irm_data(n_obs=10, dim_x=2, theta=0.5, return_type="DataFrame")
 ext_predictions = {"d": {}}
@@ -21,3 +23,14 @@ def test_qte_external_prediction_exception():
     with pytest.raises(NotImplementedError, match=msg):
         qte = DoubleMLQTE(DoubleMLData(df_irm, "y", "d"), DMLDummyClassifier(), DMLDummyClassifier())
         qte.fit(external_predictions=ext_predictions)
+
+
+@pytest.mark.ci
+def test_sensitivity_benchmark_external_prediction_exception():
+    msg = "fit_args must be a dict. "
+    with pytest.raises(TypeError, match=msg):
+        fit_args = []
+        irm = DoubleMLIRM(DoubleMLData(df_irm, "y", "d"), RandomForestRegressor(), RandomForestClassifier())
+        irm.fit()
+        irm.sensitivity_analysis()
+        irm.sensitivity_benchmark(benchmarking_set=["X1"], fit_args=fit_args)
diff --git a/doubleml/tests/test_sensitivity.py b/doubleml/tests/test_sensitivity.py
index b1277b78..9c9ca9f3 100644
--- a/doubleml/tests/test_sensitivity.py
+++ b/doubleml/tests/test_sensitivity.py
@@ -1,13 +1,21 @@
 import pytest
 import numpy as np
+import copy
 
 import doubleml as dml
-from sklearn.linear_model import LinearRegression
+from doubleml import DoubleMLIRM, DoubleMLData
+from doubleml.datasets import make_irm_data
+from sklearn.linear_model import LinearRegression, LogisticRegression
 
 from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
     doubleml_sensitivity_benchmark_manual
 
 
+@pytest.fixture(scope="module", params=[["X1"], ["X2"], ["X3"]])
+def benchmarking_set(request):
+    return request.param
+
+
 @pytest.fixture(scope='module',
                 params=[1, 3])
 def n_rep(request):
@@ -99,3 +107,56 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
     assert all(dml_sensitivity_multitreat_fixture['benchmark'].index ==
                dml_sensitivity_multitreat_fixture['d_cols'])
     assert dml_sensitivity_multitreat_fixture['benchmark'].equals(dml_sensitivity_multitreat_fixture['benchmark_manual'])
+
+
+@pytest.fixture(scope="module")
+def test_dml_benchmark_fixture(benchmarking_set, n_rep):
+    random_state = 42
+    x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")
+
+    classifier_class = LogisticRegression
+    regressor_class = LinearRegression
+
+    np.random.seed(3141)
+    dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
+    x_list_long = copy.deepcopy(dml_data.x_cols)
+    dml_int = DoubleMLIRM(dml_data,
+                          ml_m=classifier_class(random_state=random_state),
+                          ml_g=regressor_class(),
+                          n_folds=2,
+                          n_rep=n_rep)
+    dml_int.fit(store_predictions=True)
+    dml_int.sensitivity_analysis()
+    dml_ext = copy.deepcopy(dml_int)
+    df_bm = dml_int.sensitivity_benchmark(benchmarking_set=benchmarking_set)
+
+    np.random.seed(3141)
+    dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
+    dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
+    dml_short = DoubleMLIRM(dml_data_short,
+                            ml_m=classifier_class(random_state=random_state),
+                            ml_g=regressor_class(),
+                            n_folds=2,
+                            n_rep=n_rep)
+    dml_short.fit(store_predictions=True)
+    fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
+                                               "ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
+                                               "ml_g1": dml_short.predictions["ml_g1"][:, :, 0],
+                                               }
+                                         },
+                }
+    dml_ext.sensitivity_analysis()
+    df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
+
+    res_dict = {"default_benchmark": df_bm,
+                "external_benchmark": df_bm_ext}
+
+    return res_dict
+
+
+@pytest.mark.ci
+def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
+    assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
+                       test_dml_benchmark_fixture["external_benchmark"],
+                       rtol=1e-9,
+                       atol=1e-4)