From 81ce2c9623bd3f5ee8b7e3acc7bc8f6646f6987b Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 10 Apr 2024 09:10:20 +0200 Subject: [PATCH] Fix circular imports --- sbi/analysis/conditional_density.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index 6d8d1ca44..f672ef3b0 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -10,7 +10,6 @@ from torch import Tensor from torch.distributions import Distribution -from sbi.neural_nets.density_estimators.nflows_flow import NFlowsFlow from sbi.sbi_types import Shape, TorchTransform from sbi.utils.conditional_density_utils import ( ConditionedPotential, @@ -186,7 +185,7 @@ def conditional_corrcoeff( class ConditionedMDN: def __init__( self, - mdn: NFlowsFlow, + mdn, x_o: Tensor, condition: Tensor, dims_to_sample: List[int], @@ -194,8 +193,9 @@ def __init__( r"""Class that can sample and evaluate a conditional mixture-of-gaussians. Args: - mdn: Mixture density network that models $p(\theta|x). We use the normflows - implementation of MDNs. + mdn Mixture density network that models $p(\theta|x). We use the normflows + implementation of MDNs. Type is `NFlowsFlow`, type hint removed to + avoid circular import, see #1140. x_o: The datapoint at which the `net` is evaluated. condition: Parameter set that all dimensions not specified in `dims_to_sample` will be fixed to. Should contain dim_theta elements,