Support tuple of axis in softmax_cross_entropy_with_integer_labels #1165

Open · wants to merge 2 commits into main
44 changes: 42 additions & 2 deletions optax/losses/_classification.py
@@ -20,8 +20,14 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax import projections

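# `normalize_axis_index` and `normalize_axis_tuple` moved between NumPy
# releases: NumPy 2.x exposes them publicly in `numpy.lib.array_utils`,
# while NumPy 1.x only ships them under `numpy.core`.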
if np.__version__.startswith('1.'):
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
else:
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple


def sigmoid_binary_cross_entropy(
logits,
@@ -273,7 +279,7 @@ def softmax_cross_entropy(
def softmax_cross_entropy_with_integer_labels(
logits: chex.Array,
labels: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
axis: Union[int, tuple[int, ...]] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes softmax cross entropy between the logits and integer labels.
@@ -297,7 +303,10 @@ def softmax_cross_entropy_with_integer_labels(
labels: Integers specifying the correct class for each input, with shape
``[batch_size]``. Class labels are assumed to be between 0 and
``num_classes - 1`` inclusive.
axis: Axis or axes along which to compute.
axis: Axis or axes along which to compute. If a tuple of axes is passed,
``num_classes`` must equal the product of the ``logits`` dimensions
indexed by ``axis``, and each label is interpreted as a flat index into
the ``logits`` slice spanned by those axes.
where: Elements to include in the computation.

Returns:
@@ -313,6 +322,21 @@ def softmax_cross_entropy_with_integer_labels(
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761297 2.951799 ]

>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> # Example: batch shape = (1, 2), num_classes = 12 (i.e. 3 * 4).
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # Indices of elements within a slice of shape (3, 4).
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
... logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.458669 0.45866907]]

References:
`Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
Wikipedia
@@ -329,6 +353,22 @@ def softmax_cross_entropy_with_integer_labels(
"""
chex.assert_type([logits], float)
chex.assert_type([labels], int)
if isinstance(axis, int):
axis = normalize_axis_index(axis, logits.ndim)
elif isinstance(axis, tuple):
# Move all "feature" dimensions to the end preserving axis ordering and
# subsequent flattening "feature" dimensions to a single one.
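# For example (hypothetical shapes), logits of shape (2, 4, 3, 5) with
# axis=(1, 3) gives logit_axis=(1, 3) and batch_axis=(0, 2); transposing
# yields shape (2, 3, 4, 5) and the reshape yields (2, 3, 20), i.e. a
# (2, 3) batch with 20 classes.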
logit_axis = normalize_axis_tuple(axis, logits.ndim, argname='logits')
batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis)
axis = len(batch_axis)
logits = logits.transpose(batch_axis + logit_axis)
logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1, ))
if where is not None:
where = where.transpose(batch_axis + logit_axis)
where = where.reshape(where.shape[:len(batch_axis)] + (-1, ))
else:
raise ValueError('Keyword argument \'axis\' must be of type \'int\' or '
f'\'tuple[int, ...]\' but actual type is {type(axis)}.')
# This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
# we avoid subtracting the normalizer from all values, just from the values
# for the correct labels.
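For reference, the equivalence this change relies on can be checked directly against the existing single-axis path: flattening the feature axes by hand and calling the current API should produce the same loss as the new tuple-axis call. A minimal sketch follows; the shapes and random data are illustrative only, and the tuple-axis call assumes a build of optax that includes this change.

import jax
import numpy as np
import optax

# Hypothetical shapes: batch axes (0, 2) of sizes (2, 3), feature axes (1, 3)
# of sizes (4, 5), i.e. 20 classes.
key = jax.random.key(0)
k_logits, k_labels = jax.random.split(key)
logits = jax.random.normal(k_logits, (2, 4, 3, 5))
labels = jax.random.randint(k_labels, (2, 3), 0, 20)

# Reference: move the feature axes to the end, flatten them, and use the
# existing single-axis behaviour.
flat_logits = logits.transpose((0, 2, 1, 3)).reshape(2, 3, -1)
expected = optax.softmax_cross_entropy_with_integer_labels(flat_logits, labels)

# Tuple-axis path added by this pull request.
actual = optax.softmax_cross_entropy_with_integer_labels(
    logits, labels, axis=(1, 3))
np.testing.assert_allclose(actual, expected, atol=1e-6)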
34 changes: 34 additions & 0 deletions optax/losses/_classification_test.py
@@ -247,6 +247,40 @@ def test_axis(self, shape, axis):
)
np.testing.assert_allclose(x, y, atol=1e-4)

@parameterized.parameters(
{'axis': (1, 3), 'shape': (2, 3, 4, 5)},
{'axis': (3, 2), 'shape': (2, 3, 4, 5)},
{'axis': (2, 3), 'shape': (2, 3, 4, 5)},
{'axis': (-3, -1), 'shape': (2, 3, 4, 5)},
{'axis': (-1, -2), 'shape': (2, 3, 4, 5)},
{'axis': (-2, -1), 'shape': (2, 3, 4, 5)},
)
def test_axes(self, shape: tuple[int, ...], axis: tuple[int, ...]):
# Canonicalize axis and calculate shapes.
ndim = len(shape)
logits_axis = tuple((x + ndim) % ndim for x in axis)
labels_axis = tuple(x for x in range(ndim) if x not in logits_axis)
# Obtain shapes of batch and logits subspaces.
logits_shape = tuple(shape[x] for x in logits_axis)
labels_shape = tuple(shape[x] for x in labels_axis)
num_classes: int = np.prod(logits_shape).item()

key = jax.random.key(42)
keys = jax.random.split(key, 2)
logits = jax.random.uniform(keys[0], labels_shape + (num_classes, ))
labels = jax.random.randint(keys[1], labels_shape, 0, num_classes)  # maxval is exclusive

fn = _classification.softmax_cross_entropy_with_integer_labels
desired = fn(logits, labels)

# Apply the inverse axes permutation to obtain an array of shape `shape`
# (transposing by the forward permutation is only correct when it is an
# involution, so compute the inverse explicitly).
inv_perm = tuple(np.argsort(labels_axis + logits_axis))
logits = logits.reshape(labels_shape + logits_shape).transpose(inv_perm)
assert logits.shape == shape
actual = fn(logits, labels, axis)
np.testing.assert_allclose(actual, desired)


class SigmoidCrossEntropyTest(parameterized.TestCase):
