
Different training losses when flash_attention is on/off #1918

Open
zhangchen-xu opened this issue Sep 18, 2024 · 6 comments
Labels: bug (Something isn't working)

@zhangchen-xu commented Sep 18, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

Flash attention should not make the training loss differ significantly.

Current behaviour

I ran preliminary experiments on Gemma 2 2B with different datasets. When flash attention is on, the training loss is significantly lower than when it is off.

Please see the figure below. The wandb run names ending in -flash indicate that flash attention is on.

[Figure: wandb training loss curves with and without flash attention]

However, the validation losses are normal.

Steps to reproduce

Simply enable and disable flash attention in the configuration.

Config yaml

No response

Possible solution

Is there anything wrong with the loss calculation when flash_attention is off? Usually the training loss should be slightly lower than the validation loss.

[Figure: wandb training vs. validation loss curves]
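
As a quick sanity check outside of axolotl, here is a minimal sketch (my assumption, not axolotl's code path; it presumes flash-attn is installed and a CUDA GPU is available) that compares the forward-pass loss of google/gemma-2-2b under the eager and flash_attention_2 implementations in transformers:

# Minimal sketch: compare the causal-LM loss of one batch under the "eager"
# and "flash_attention_2" attention implementations in transformers.
# Assumes flash-attn is installed and a CUDA GPU is available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
batch = tokenizer(
    "A short throwaway sentence used only to probe the loss computation.",
    return_tensors="pt",
).to("cuda")

losses = {}
for impl in ("eager", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=impl,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    model.eval()
    with torch.no_grad():
        # labels=input_ids makes transformers compute the shifted cross-entropy loss
        out = model(**batch, labels=batch["input_ids"])
    losses[impl] = out.loss.item()
    del model
    torch.cuda.empty_cache()

# A large gap here would point at the attention kernels rather than the loss code.
print(losses)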

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/4d6490b

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
zhangchen-xu added the bug label Sep 18, 2024
@chouyi-peng

I encountered the same issue while training Gemma-2-2b.
Have you found a solution to it?

@zhangchen-xu (Author)

After cross-checking, I believe the loss when flash attention is on is the correct one. The fine-tuned model's performance is also higher when flash attention is on than when it is off.

@ehartford (Collaborator)

This is very interesting!

bursteratom self-assigned this Nov 13, 2024
@winglian (Collaborator)

@zhangchen-xu Are you using liger by any chance too?

@bursteratom (Collaborator)

@zhangchen-xu I'm wondering if you can provide the configuration?

@zhangchen-xu (Author)

> @zhangchen-xu Are you using liger by any chance too?

Thank you! I don't use that.

> @zhangchen-xu I'm wondering if you can provide the configuration?

Sure, here is the configuration when flash attention is disabled:

base_model: google/gemma-2-2b
model_type: Gemma2ForCausalLM
tokenizer_type: AutoTokenizer
chat_template: gemma

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: flydust/Magpie-100k-Gemma2-9B
    type: sharegpt
    chat_template: gemma
dataset_prepared_path: last_run_prepared
val_set_size: 0.001
output_dir: axolotl_out/gemma-2-2b-magpie-gemma2-9b

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project: SynDa
wandb_entity:
wandb_watch:
wandb_name: gemma-2-2b-magpie-gemma2-9b
wandb_log_model:
hub_model_id: flydust/gemma-2-2b-magpie-gemma2-9b

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
# Disable flash attention
# flash_attention: false
# sdp_attention: false
eager_attention: true

warmup_ratio: 0.1
evals_per_epoch: 5
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

I have also made these two models public for your reference:
flydust/gemma-2-2b-magpie-gemma2-9b
flydust/gemma-2-2b-magpie-gemma2-9b-flash
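
A minimal sketch for generating from both checkpoints side by side (hypothetical, not from the training run; the prompt is an arbitrary example, and it assumes the saved tokenizers include the gemma chat template from the config, plus accelerate installed for device_map="auto"):

# Minimal sketch: greedy-generate from both published checkpoints on the same prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoints = [
    "flydust/gemma-2-2b-magpie-gemma2-9b",        # trained without flash attention
    "flydust/gemma-2-2b-magpie-gemma2-9b-flash",  # trained with flash attention
]
messages = [{"role": "user", "content": "Explain gradient checkpointing in two sentences."}]

for repo in checkpoints:
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(
        repo, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # Relies on the tokenizer shipping a gemma chat template, as set in the config.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=128, do_sample=False)
    print(f"=== {repo} ===")
    print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
    del model
    torch.cuda.empty_cache()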

Thank you so much for your help!
