diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2279a2b8d..59c55cfc2 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -157,7 +157,7 @@ def add_floats(x, y): # what is the f here? logger.info(f"Interpolating between the two models with alpha={alpha}") merged_model = named_jit(lambda m1, m2: jax.tree_util.tree_map(add_floats, m1, m2), donate_args=True)(model, model_2) - state = dataclasses.replace(state, model=model) + state = dataclasses.replace(state, model=merged_model) else: logger.info("No checkpoint found. Starting from scratch.")