Skip to content

Commit

Permalink
Fix CI
Browse files: browse the repository at this point in the history
  • Loading branch information
Aphoh committed Dec 18, 2024
1 parent 3ae3014 commit e8a78d3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 1 addition & 3 deletions src/levanter/main/train_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def compute_loss(
example: AudioTextExample,
*,
key=None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
) -> jax.numpy.ndarray | hax.NamedArray:
return m.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis)
return m.compute_loss(example, key=key)

# Using the trainer as a context manager does 3 things:
# 1. Sets the device mesh
Expand Down
2 changes: 1 addition & 1 deletion src/levanter/main/viz_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(config: VizGpt2Config):
def compute_log_probs(model: LmHeadModel, example: LmExample):
model = inference_mode(model, True)
model = mp.cast_to_compute(model)
logprobs = compute_next_token_loss(model, example, reduction=None)
logprobs, where, _ = compute_next_token_loss(model, example)
# roll forward to get the loss for each predicted token
logprobs = hax.roll(logprobs, 1, Pos)
return logprobs.rearrange((EvalBatch, Pos)).array
Expand Down
12 changes: 7 additions & 5 deletions src/levanter/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmConfig
from levanter.utils.types import Extras


class AudioTextExample(eqx.Module):
Expand Down Expand Up @@ -97,9 +98,7 @@ def compute_loss(
example: AudioTextExample,
*,
key=None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
) -> jnp.ndarray | NamedArray:
) -> tuple[jnp.ndarray | NamedArray, NamedArray, Extras]:
"""
Computes the cross-entropy loss for predicted ASR tokens. If reduction is not None, the loss is reduced
across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
Expand All @@ -110,10 +109,13 @@ def compute_loss(
targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
loss = cross_entropy_loss(
logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
logits,
self.Vocab,
target_y,
reduction=None,
)

return loss
return loss, example.loss_mask, {}

@property
def vocab_size(self) -> int:
Expand Down

0 comments on commit e8a78d3

Please sign in to comment.