Different training losses when flash_attention is on/off #1918
Comments
I encountered the same issue while training Gemma-2-2b.
After cross-checking, I believe the loss with flash attention on is the correct one. The fine-tuned model's performance is also higher with flash attention on than with it off.
This is very interesting!
@zhangchen-xu Are you using liger by any chance too?
@zhangchen-xu I'm wondering if you can provide the configuration?
Thank you! I don't use that.
Sure, here is the configuration when flash attention is disabled:
base_model: google/gemma-2-2b
model_type: Gemma2ForCausalLM
tokenizer_type: AutoTokenizer
chat_template: gemma
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: flydust/Magpie-100k-Gemma2-9B
type: sharegpt
chat_template: gemma
dataset_prepared_path: last_run_prepared
val_set_size: 0.001
output_dir: axolotl_out/gemma-2-2b-magpie-gemma2-9b
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project: SynDa
wandb_entity:
wandb_watch:
wandb_name: gemma-2-2b-magpie-gemma2-9b
wandb_log_model:
hub_model_id: flydust/gemma-2-2b-magpie-gemma2-9b
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
# Disable flash attention
# flash_attention: false
# sdp_attention: false
eager_attention: true
warmup_ratio: 0.1
evals_per_epoch: 5
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
I also made these two models public for your reference. Thank you so much for your help!
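For reference, the flash-attention-enabled runs presumably only flip the attention keys already shown in the config above. A minimal sketch of that part of the config (my assumption, not posted in the thread) would be:

# Enable flash attention
flash_attention: true
# sdp_attention: false
# eager_attention: false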
Please check that this issue hasn't been reported before.
Expected Behavior
Enabling or disabling flash attention should not make the training losses differ significantly.
Current behaviour
I did preliminary experiments on Gemma 2b with different datasets. When flash attention is on, the loss is significantly lower than when flash attention is off.
Please see the figure below; wandb run names ending in -flash indicate that flash attention is on.
However, the validation losses are normal.
Steps to reproduce
Simply enable and disable flash attention in the configuration.
Config yaml
No response
Possible solution
Is there anything wrong with the loss calculation when flash_attention is off? Usually the training loss should be slightly lower than the validation loss.
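One way to narrow this down, as a minimal diagnostic sketch rather than anything from the original report: compare the loss on an identical batch under eager attention and flash attention 2 via the Hugging Face transformers attn_implementation switch (assumes a CUDA GPU with bf16 support and the flash-attn package installed):

# Sketch (not from the thread): compare the loss of google/gemma-2-2b on the
# same batch under eager attention vs. flash attention 2.
# Assumes a CUDA GPU with bf16 support and the flash-attn package installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"
tok = AutoTokenizer.from_pretrained(model_id)
batch = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt").to("cuda")
labels = batch["input_ids"].clone()

for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation=impl,
    ).to("cuda").eval()
    with torch.no_grad():
        loss = model(**batch, labels=labels).loss
    print(f"{impl}: loss = {loss.item():.4f}")
    del model
    torch.cuda.empty_cache()

If the two losses already differ noticeably on the same batch, that would suggest the gap comes from the attention implementation itself rather than from sample packing or the training loop.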
Which Operating Systems are you using?
Python Version
3.10
axolotl branch-commit
main/4d6490b
Acknowledgements