diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index fe0cdfb88..d1f4841fc 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -172,7 +172,7 @@ def lr_scheduler(self, num_train_steps): if warmup_steps != 0: warmup = optax.linear_schedule(previous_end, self.learning_rate, warmup_steps) schedules.append(warmup) - boundaries.append(warmup_steps) + boundaries.append(start + warmup_steps) stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps) lr_decay_steps = cycle_steps - stable_steps - warmup_steps