diff --git a/sbi/utils/__init__.py b/sbi/utils/__init__.py index 93a301580..621758232 100644 --- a/sbi/utils/__init__.py +++ b/sbi/utils/__init__.py @@ -69,4 +69,3 @@ validate_theta_and_x, ) from sbi.utils.user_input_checks_utils import MultipleIndependent -from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py deleted file mode 100644 index 7267dfa10..000000000 --- a/sbi/utils/get_nn_models.py +++ /dev/null @@ -1,112 +0,0 @@ -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Apache License Version 2.0, see - -from typing import Any, Callable, Optional -from warnings import warn - -from torch import nn - -from sbi.neural_nets.factory import classifier_nn as classifier_nn_moved_to_neural_nets -from sbi.neural_nets.factory import likelihood_nn as likelihood_nn_moved_to_neural_nets -from sbi.neural_nets.factory import posterior_nn as posterior_nn_moved_to_neural_nets - - -def classifier_nn( - model: str, - z_score_theta: Optional[str] = "independent", - z_score_x: Optional[str] = "independent", - hidden_features: int = 50, - embedding_net_theta: nn.Module = nn.Identity(), - embedding_net_x: nn.Module = nn.Identity(), - **kwargs: Any, -) -> Callable: - r"""This method is deprecated and will be removed in a future release. - Please use `from sbi.neural_nets import classifier_nn` in the future. - """ - - warn( - "This method is deprecated and will be removed in a future release." - "Please use `from sbi.neural_nets import classifier_nn` in the future.", - DeprecationWarning, - stacklevel=2, - ) - - return classifier_nn_moved_to_neural_nets( - model, - z_score_theta, - z_score_x, - hidden_features, - embedding_net_theta, - embedding_net_x, - **kwargs, - ) - - -def likelihood_nn( - model: str, - z_score_theta: Optional[str] = "independent", - z_score_x: Optional[str] = "independent", - hidden_features: int = 50, - num_transforms: int = 5, - num_bins: int = 10, - embedding_net: nn.Module = nn.Identity(), - num_components: int = 10, - **kwargs: Any, -) -> Callable: - r"""This method is deprecated and will be removed in a future release. - Please use `from sbi.neural_nets import likelihood_nn` in the future. - """ - - warn( - "This method is deprecated and will be removed in a future release. " - "Please use `from sbi.neural_nets import likelihood_nn` in the future.", - DeprecationWarning, - stacklevel=2, - ) - - return likelihood_nn_moved_to_neural_nets( - model, - z_score_theta, - z_score_x, - hidden_features, - num_transforms, - num_bins, - embedding_net, - num_components, - **kwargs, - ) - - -def posterior_nn( - model: str, - z_score_theta: Optional[str] = "independent", - z_score_x: Optional[str] = "independent", - hidden_features: int = 50, - num_transforms: int = 5, - num_bins: int = 10, - embedding_net: nn.Module = nn.Identity(), - num_components: int = 10, - **kwargs: Any, -) -> Callable: - r"""This method is deprecated and will be removed in a future release. - Please use `from sbi.neural_nets import posterior_nn` in the future. - """ - - warn( - "This method is deprecated and will be removed in a future release." - "Please use `from sbi.neural_nets import posterior_nn` in the future.", - DeprecationWarning, - stacklevel=2, - ) - - return posterior_nn_moved_to_neural_nets( - model, - z_score_theta, - z_score_x, - hidden_features, - num_transforms, - num_bins, - embedding_net, - num_components, - **kwargs, - ) diff --git a/tests/neural_nets_factory.py b/tests/neural_nets_factory.py deleted file mode 100644 index af8c4d4a1..000000000 --- a/tests/neural_nets_factory.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn - -models_to_test = [ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", -] - - -@pytest.mark.parametrize( - "model", ["linear", "mlp", "resnet"], ids=["linear", "mlp", "resnet"] -) -def test_deprecated_import_classifier_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = classifier_nn(model) - assert callable(build_fcn) - - -@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) -def test_deprecated_import_likelihood_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = likelihood_nn(model) - assert callable(build_fcn) - - -@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) -def test_deprecated_import_posterior_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = posterior_nn(model) - assert callable(build_fcn)