diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index c08be31d5..602dae4db 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -184,7 +184,8 @@ def doremi_step(state: DoremiState, ref, batch, domains): "train/mean_excess_loss": mean_excess_loss, "train/mean_proxy_loss": mean_proxy_loss, **{f"alpha/{domain}": weight for domain, weight in alpha_dict.items()}, - **{f"train/{domain}/loss": loss for domain, loss in per_domain_dict.items()}, + # just skip domains with no excess loss + **{f"train/{domain}/excess_loss": loss for domain, loss in per_domain_dict.items() if loss > 0}, }, step=state._step, )