Skip to content

Commit

Permalink
fix docstrings about condition sample_dim.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Dec 20, 2024
1 parent d3f22b5 commit c56056f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
24 changes: 12 additions & 12 deletions sbi/neural_nets/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ def _check_input_shape(self, input: Tensor):
class ConditionalDensityEstimator(ConditionalEstimator):
r"""Base class for density estimators.
The density estimator class is a wrapper around neural networks that
allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$
pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
The density estimator class is a wrapper around neural networks that allows to
evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ pairs. Here
$\theta$ would be the `input` and $x$ would be the `condition`.
Note:
We assume that the input to the density estimator is a tensor of shape
(batch_size, input_size), where input_size is the dimensionality of the input.
The condition is a tensor of shape (batch_size, *condition_shape), where
condition_shape is the shape of the condition tensor.
(sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality
of the input. The condition is a tensor of shape (batch_size, *condition_shape),
where condition_shape is the shape of the condition tensor.
"""

Expand Down Expand Up @@ -226,15 +226,15 @@ def sample_and_log_prob(
class ConditionalVectorFieldEstimator(ConditionalEstimator):
r"""Base class for vector field (e.g., score and ODE flow) estimators.
The density estimator class is a wrapper around neural networks that
allows to evaluate the `vector_field`, and provide the `loss` of $\theta,x$
pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.
The vector field estimator class is a wrapper around neural networks that allows to
evaluate the `vector_field`, and provide the `loss` of $\theta,x$ pairs. Here
$\theta$ would be the `input` and $x$ would be the `condition`.
Note:
We assume that the input to the density estimator is a tensor of shape
(batch_size, input_size), where input_size is the dimensionality of the input.
The condition is a tensor of shape (batch_size, *condition_shape), where
condition_shape is the shape of the condition tensor.
(sample_dim, batch_dim, *input_shape), where input_shape is the dimensionality
of the input. The condition is a tensor of shape (batch_dim, *condition_shape),
where condition_shape is the shape of the condition tensor.
"""

Expand Down
6 changes: 3 additions & 3 deletions sbi/neural_nets/estimators/nflows_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
Args:
input: Inputs to evaluate the log probability on. Of shape
`(sample_dim, batch_dim, *event_shape)`.
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
condition: Conditions of shape `(batch_dim, *event_shape)`.
Raises:
AssertionError: If `input_batch_dim != condition_batch_dim`.
Expand Down Expand Up @@ -126,7 +126,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
Args:
sample_shape: Shape of the samples to return.
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
condition: Conditions of shape `(batch_dim, *event_shape)`.
Returns:
Samples of shape `(*sample_shape, condition_batch_dim)`.
Expand All @@ -147,7 +147,7 @@ def sample_and_log_prob(
Args:
sample_shape: Shape of the samples to return.
condition: Conditions of shape (sample_dim, batch_dim, *event_shape).
condition: Conditions of shape (batch_dim, *event_shape).
Returns:
Samples of shape `(*sample_shape, condition_batch_dim, *input_event_shape)`
Expand Down
4 changes: 1 addition & 3 deletions sbi/neural_nets/estimators/zuko_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
Args:
input: Inputs to evaluate the log probability on. Of shape
`(sample_dim, batch_dim, *event_shape)`.
# TODO: the docstring is not correct here. in the code it seems we
do not have a sample_dim for the condition.
condition: Conditions of shape `(sample_dim, batch_dim, *event_shape)`.
condition: Conditions of shape `(batch_dim, *event_shape)`.
Raises:
AssertionError: If `input_batch_dim != condition_batch_dim`.
Expand Down

0 comments on commit c56056f

Please sign in to comment.