# flop_utils.py (forked from meta-llama/llama-recipes)
import time
from typing import Any, Dict, List, Optional, Union

import torch
from torch.utils.flop_counter import FlopCounterMode


class FlopMeasure(FlopCounterMode):
    """
    ``FlopMeasure`` is a customized context manager that counts the number of
    flops within its context. It is based on ``FlopCounterMode``, with an
    additional ``step()`` method so that flop counting only starts after the
    warmup stage. It also supports hierarchical output by passing a module
    (or list of modules) to ``FlopCounterMode`` on construction.

    Example usage

    .. code-block:: python

        model = ...
        flop_counter = FlopMeasure(model, rank=0, warmup_step=3)
        for batch in dataloader:
            with flop_counter:
                model(batch)
                flop_counter.step()
    """

    def __init__(
        self,
        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
        depth: int = 2,
        display: bool = True,
        custom_mapping: Optional[Dict[Any, Any]] = None,
        rank: Optional[int] = None,
        warmup_step: int = 3,
    ):
        super().__init__(mods, depth, display, custom_mapping)
        self.rank = rank
        self.warmup_step = warmup_step
        self.start_time = 0
        self.end_time = 0

    def step(self):
        # Decrement the warmup counter on every step so that flop counting
        # starts once warmup_step reaches 0; stop decrementing at -1.
        if self.warmup_step >= 0:
            self.warmup_step -= 1
        if self.warmup_step == 0 and self.start_time == 0:
            self.start_time = time.time()
        elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
            self.end_time = time.time()

    def __enter__(self):
        if self.warmup_step == 0:
            self.start_time = time.time()
        super().__enter__()
        return self

    def is_done(self):
        return self.warmup_step == -1

    def get_total_flops(self):
        return super().get_total_flops()

    def get_flops_per_sec(self):
        if self.start_time == 0 or self.end_time == 0:
            print("Warning: flop count did not finish correctly")
            return 0
        return super().get_total_flops() / (self.end_time - self.start_time)

    def get_table(self, depth=2):
        return super().get_table(depth)

    def __exit__(self, *args):
        if self.get_total_flops() == 0:
            print(
                "Warning: did not record any flops this time. Skipping the flop report"
            )
        else:
            if self.display:
                if self.rank is None or self.rank == 0:
                    print(
                        "Total time used in this flop counting step is: {}".format(
                            self.end_time - self.start_time
                        )
                    )
                    print(
                        "The total TFlop per second is: {}".format(
                            self.get_flops_per_sec() / 1e12
                        )
                    )
                    print("The tflop_count table is below:")
                    print(self.get_table(self.depth))
                # Disable the display feature so that we don't print the table again
                self.display = False
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Once warmup is done (warmup_step == 0), count flops via FlopCounterMode.
        if self.warmup_step == 0:
            return super().__torch_dispatch__(func, types, args, kwargs)
        # Otherwise skip counting and just execute the op unmodified.
        return func(*args, **kwargs)
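

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the upstream module): a
# minimal, self-contained demo of FlopMeasure around a toy model. The model,
# tensor shapes, and warmup_step value below are assumptions chosen for the
# demo. With warmup_step=1, the first forward pass is warmup, the second is
# counted (step() records end_time), and the report prints on context exit.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(128, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
    )
    flop_counter = FlopMeasure(model, rank=0, warmup_step=1)
    with flop_counter:
        for _ in range(3):
            model(torch.randn(4, 128))
            flop_counter.step()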