From c70eedca05d5a4c235c80dc61f155f3128f421b8 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 13 Feb 2024 13:58:51 -0800
Subject: [PATCH] fix trainable_param_count invocations

---
 src/levanter/main/lora_lm.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index c92b61d6c..b111d098f 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -94,9 +94,7 @@ def compute_loss(model, example: LmExample, key=None):
 
         all_param_count = parameter_count(state.model)
         # TODO: remove this once we put this in trainer itself
-        just_lora_params = parameter_count(
-            levanter.trainer._partition_trainable_params(state.model, lora_param_filter)
-        )
+        just_lora_params = parameter_count(state.trainable_model)
 
         levanter.tracker.log_summary(
             {
@@ -108,7 +106,7 @@ def compute_loss(model, example: LmExample, key=None):
 
         logger.info(f"Total parameter count: {all_param_count}")
         logger.info(f"Trainable parameter count: {just_lora_params}")
-        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
+        logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")
 
         # data loaders
         eval_datasets = config.data.validation_sets(Pos.size)
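
Note (not part of the commit; `git am` ignores text after the diff): the second hunk fixes a subtle bug worth spelling out. The old line was syntactically valid Python because `%.3` inside the f-string braces parses as the modulo operator applied to the float literal `.3`, not as a format spec, so the log silently printed a remainder instead of a formatted fraction. A minimal sketch, with hypothetical parameter counts chosen for illustration:

    # Illustrative only: why the old f-string compiled but logged the wrong value.
    trainable, total = 4_500_000, 10_000_000
    frac = trainable * 1.0 / total  # 0.45

    # Old code: "%.3" is the modulo operator with the float literal .3,
    # i.e. (frac % 0.3), which prints the remainder (~0.15), not the fraction.
    print(f"Fraction of parameters that are trainable: {frac % .3}")

    # Patched code: ":.3e" is an actual format spec (scientific notation,
    # three digits after the decimal point).
    print(f"Fraction of parameters that are trainable: {frac:.3e}")  # 4.500e-01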