From 5de63afe41f9ee5b25276cbbbb85b460f293e6d7 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 25 Sep 2024 11:59:56 -0700 Subject: [PATCH] Remove multiplication by 2 in mfu calculations. More accurately compute softmax computational intensity. (#111) Co-authored-by: Juan Acevedo --- src/maxdiffusion/max_utils.py | 3 +-- src/maxdiffusion/trainers/base_stable_diffusion_trainer.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 0a60d0dc..d2af454d 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -517,7 +517,7 @@ def calculate_model_tflops( elif isinstance(c.module, AttentionOp): qk_einsum = 2 * (reduce(lambda x, y: x * y, inputs[0].shape)) * inputs[1].shape[1] scaling = inputs[0].shape[0] * inputs[0].shape[1] * inputs[1].shape[1] - softmax = reduce(lambda x, y: x * y, inputs[0].shape) * np.log(inputs[1].shape[1]) + softmax = inputs[0].shape[0] * inputs[0].shape[1] * np.log(inputs[1].shape[1]) att_v = 2 * (reduce(lambda x, y: x * y, inputs[0].shape)) * inputs[2].shape[1] # When seq_length_1 == seq_length_2 then, # qk_einsum + scaling + softmax + att_v == 4 * batch_size * hidden_dim * seq_length ^ 2 @@ -528,7 +528,6 @@ def calculate_model_tflops( * c.module.features * reduce(lambda x, y: x * y, c.module.kernel_size)) / reduce(lambda x, y: x * y, c.module.strides) visited_paths.add(c.path) - total_flops = (total_flops * 3 if train else total_flops) / 10**12 return total_flops diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index df61f4ec..560a15dd 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -74,7 +74,7 @@ def create_scheduler(self, pipeline, params): def calculate_tflops(self, pipeline, params): per_device_tflops = maxdiffusion_utils.calculate_unet_tflops( self.config, pipeline, - (2 * self.config.per_device_batch_size * jax.local_device_count()), + (self.config.per_device_batch_size * jax.local_device_count()), self.rng, train=True )