diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py
index c92b61d6c..b111d098f 100644
--- a/src/levanter/main/lora_lm.py
+++ b/src/levanter/main/lora_lm.py
@@ -94,9 +94,7 @@ def compute_loss(model, example: LmExample, key=None):
 
     all_param_count = parameter_count(state.model)
     # TODO: remove this once we put this in trainer itself
-    just_lora_params = parameter_count(
-        levanter.trainer._partition_trainable_params(state.model, lora_param_filter)
-    )
+    just_lora_params = parameter_count(state.trainable_model)
 
     levanter.tracker.log_summary(
         {
@@ -108,7 +106,7 @@ def compute_loss(model, example: LmExample, key=None):
 
     logger.info(f"Total parameter count: {all_param_count}")
     logger.info(f"Trainable parameter count: {just_lora_params}")
-    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
+    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")
 
     # data loaders
     eval_datasets = config.data.validation_sets(Pos.size)
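
Not part of the patch: a minimal sketch of why the logging line above was buggy. The variable names come from the diff; the parameter counts are made-up values for illustration only.

```python
# Hypothetical counts, chosen only to demonstrate the f-string fix above.
just_lora_params = 4_194_304      # e.g. LoRA adapter parameters
all_param_count = 6_738_415_616   # e.g. full model parameters

# Old line: "%.3" is parsed as part of the *expression*, not a format spec,
# so Python silently computes (ratio) % 0.3 instead of formatting the ratio.
print(f"{just_lora_params * 1.0 / all_param_count%.3}")   # 0.0006224459...

# New line: ":.3e" is an actual format spec (scientific notation,
# 3 digits after the decimal point).
print(f"{just_lora_params * 1.0 / all_param_count:.3e}")  # 6.224e-04
```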