axis should be int or None
rdyro committed Jan 2, 2025
1 parent 1e08bcc commit 9cc3125
Showing 1 changed file with 5 additions and 2 deletions.
optax/losses/_classification.py (7 changes: 5 additions & 2 deletions)
@@ -273,7 +273,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-    axis: Union[int, tuple[int, ...], None] = -1,
+    axis: Union[int, None] = -1,
     where: Union[chex.Array, None] = None,
 ) -> chex.Array:
   r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +297,7 @@ def softmax_cross_entropy_with_integer_labels(
     labels: Integers specifying the correct class for each input, with shape
       ``[batch_size]``. Class labels are assumed to be between 0 and
       ``num_classes - 1`` inclusive.
-    axis: Axis or axes along which to compute.
+    axis: Axis along which to compute.
     where: Elements to include in the computation.

   Returns:
@@ -329,6 +329,9 @@ def softmax_cross_entropy_with_integer_labels(
   """
   chex.assert_type([logits], float)
   chex.assert_type([labels], int)
+  if axis is not None and not isinstance(axis, int):
+    raise ValueError(f'axis = {axis} is unsupported. Provide an int or None.')
+
   # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
   # we avoid subtracting the normalizer from all values, just from the values
   # for the correct labels.
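For context, a minimal usage sketch (not part of the commit) showing the effect of the narrowed annotation. It assumes an optax release that includes this change and the public optax.losses.softmax_cross_entropy_with_integer_labels entry point:

import jax.numpy as jnp
import optax

# Two examples, three classes.
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 1.5, 0.3]])
labels = jnp.array([0, 1])

# axis defaults to -1 (an int): one loss value per example.
loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, labels)
print(loss.shape)  # (2,)

# After this commit, a tuple of axes is rejected up front instead of being
# silently permitted by the old type annotation.
try:
    optax.losses.softmax_cross_entropy_with_integer_labels(
        logits, labels, axis=(0, 1)
    )
except ValueError as err:
    print(err)  # axis = (0, 1) is unsupported. Provide an int or None.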
