
fix trainable_param_count invocations
dlwh committed Feb 13, 2024
1 parent c6da233 commit c70eedc
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/levanter/main/lora_lm.py
@@ -94,9 +94,7 @@ def compute_loss(model, example: LmExample, key=None):
 
     all_param_count = parameter_count(state.model)
     # TODO: remove this once we put this in trainer itself
-    just_lora_params = parameter_count(
-        levanter.trainer._partition_trainable_params(state.model, lora_param_filter)
-    )
+    just_lora_params = parameter_count(state.trainable_model)
 
     levanter.tracker.log_summary(
         {
@@ -108,7 +106,7 @@ def compute_loss(model, example: LmExample, key=None):
 
     logger.info(f"Total parameter count: {all_param_count}")
     logger.info(f"Trainable parameter count: {just_lora_params}")
-    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")
+    logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")
 
     # data loaders
     eval_datasets = config.data.validation_sets(Pos.size)
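For context, here is a minimal, self-contained sketch of the two behaviors the diff touches: counting parameters over the already-partitioned trainable subtree rather than re-partitioning, and formatting the trainable fraction with a real format spec. The `parameter_count` helper and the toy pytree below are illustrative stand-ins, not Levanter's actual implementation.

    import jax
    import jax.numpy as jnp

    def parameter_count(tree) -> int:
        # Sum element counts over every array leaf in the pytree.
        return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

    # Toy "model": a frozen base matrix plus a small LoRA adapter pair.
    model = {
        "base": jnp.zeros((1024, 1024)),
        "lora_a": jnp.zeros((1024, 8)),
        "lora_b": jnp.zeros((8, 1024)),
    }
    # Stand-in for state.trainable_model: just the LoRA leaves.
    trainable = {"lora_a": model["lora_a"], "lora_b": model["lora_b"]}

    all_param_count = parameter_count(model)       # 1,064,960
    just_lora_params = parameter_count(trainable)  # 16,384

    # The deleted line ended the f-string with `all_param_count%.3`, which
    # Python parses as `... % 0.3` (a modulo expression), not a format spec:
    # the logged value was unformatted, and wrong whenever the fraction
    # exceeded 0.3. `:.3e` prints the fraction in scientific notation.
    print(f"Fraction of parameters that are trainable: "
          f"{just_lora_params * 1.0 / all_param_count:.3e}")  # 1.538e-02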
