From 7b7e4a45c1d5ef1c8b832ed4a6dfd6233351967d Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Fri, 16 Feb 2024 13:08:40 -0800 Subject: [PATCH] fix --- src/levanter/main/train_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 2279a2b8d..59c55cfc2 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -157,7 +157,7 @@ def add_floats(x, y): # what is the f here? logger.info(f"Interpolating between the two models with alpha={alpha}") merged_model = named_jit(lambda m1, m2: jax.tree_util.tree_map(add_floats, m1, m2), donate_args=True)(model, model_2) - state = dataclasses.replace(state, model=model) + state = dataclasses.replace(state, model=merged_model) else: logger.info("No checkpoint found. Starting from scratch.")