
Commit 5de63af

Remove multiplication by 2 in mfu calculations. More accurately compute softmax computational intensity. (#111)

Co-authored-by: Juan Acevedo <[email protected]>
entrpn and jfacevedo-google authored Sep 25, 2024
1 parent a73fbcf commit 5de63af
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions src/maxdiffusion/max_utils.py
@@ -517,7 +517,7 @@ def calculate_model_tflops(
elif isinstance(c.module, AttentionOp):
qk_einsum = 2 * (reduce(lambda x, y: x * y, inputs[0].shape)) * inputs[1].shape[1]
scaling = inputs[0].shape[0] * inputs[0].shape[1] * inputs[1].shape[1]
- softmax = reduce(lambda x, y: x * y, inputs[0].shape) * np.log(inputs[1].shape[1])
+ softmax = inputs[0].shape[0] * inputs[0].shape[1] * np.log(inputs[1].shape[1])
att_v = 2 * (reduce(lambda x, y: x * y, inputs[0].shape)) * inputs[2].shape[1]
# When seq_length_1 == seq_length_2 then,
# qk_einsum + scaling + softmax + att_v == 4 * batch_size * hidden_dim * seq_length ^ 2
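
The counting in this hunk can be restated as a standalone formula. Below is a minimal sketch under the assumption that inputs[0], inputs[1], and inputs[2] are the query, key, and value activations with shapes (batch, seq_q, head_dim), (batch, seq_kv, head_dim), and (batch, seq_kv, head_dim); the names are illustrative, not part of the module's API.

import numpy as np

def attention_flops(batch, seq_q, seq_kv, head_dim):
    # Q @ K^T: one multiply-accumulate (2 FLOPs) per (batch, seq_q, seq_kv, head_dim) element.
    qk_einsum = 2 * batch * seq_q * head_dim * seq_kv
    # One scaling multiply per attention logit.
    scaling = batch * seq_q * seq_kv
    # Softmax modeled, as in the new line above, as ~log(seq_kv) work per query position.
    softmax = batch * seq_q * np.log(seq_kv)
    # attention_weights @ V: again 2 FLOPs per multiply-accumulate.
    att_v = 2 * batch * seq_q * seq_kv * head_dim
    return qk_einsum + scaling + softmax + att_v

With seq_q == seq_kv == s, the two einsum terms dominate and sum to 4 * batch * head_dim * s**2, which is the identity quoted in the comment. The old softmax term also multiplied in head_dim, overstating that (already minor) contribution.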
@@ -528,7 +528,6 @@ def calculate_model_tflops(
* c.module.features
* reduce(lambda x, y: x * y, c.module.kernel_size)) / reduce(lambda x, y: x * y, c.module.strides)
visited_paths.add(c.path)
-
total_flops = (total_flops * 3 if train else total_flops) / 10**12
return total_flops
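
The closing lines apply the usual training-cost heuristic: the backward pass is commonly taken to cost roughly twice the forward pass, so a training step is counted as 3x the forward FLOPs, and dividing by 10**12 converts FLOPs to TFLOPs. A quick numeric check with a made-up forward count:

forward_flops = 5.0e14                     # hypothetical forward-pass FLOPs per step
train_tflops = forward_flops * 3 / 10**12  # forward + ~2x backward, expressed in TFLOPs
print(train_tflops)                        # 1500.0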

2 changes: 1 addition & 1 deletion src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -74,7 +74,7 @@ def create_scheduler(self, pipeline, params):
def calculate_tflops(self, pipeline, params):
per_device_tflops = maxdiffusion_utils.calculate_unet_tflops(
self.config, pipeline,
- (2 * self.config.per_device_batch_size * jax.local_device_count()),
+ (self.config.per_device_batch_size * jax.local_device_count()),
self.rng,
train=True
)
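
Downstream, this per-device TFLOPs figure is the numerator of an MFU-style efficiency ratio, so doubling the batch term here would directly inflate the reported utilization. A rough sketch of that relationship, with step_time_s and peak_tflops_per_s as hypothetical placeholders not defined in this diff:

per_device_tflops = 120.0    # hypothetical result of calculate_unet_tflops, per training step
step_time_s = 0.8            # hypothetical measured step time, seconds
peak_tflops_per_s = 275.0    # hypothetical per-chip peak throughput

mfu = (per_device_tflops / step_time_s) / peak_tflops_per_s
print(f"MFU: {mfu:.2%}")     # 54.55%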
