diff --git a/nequip/train/callbacks/__init__.py b/nequip/train/callbacks/__init__.py
index 05f09397..5cb02a36 100644
--- a/nequip/train/callbacks/__init__.py
+++ b/nequip/train/callbacks/__init__.py
@@ -4,6 +4,7 @@
 from .nemo_ema import NeMoExponentialMovingAverage
 from .write_xyz import TestTimeXYZFileWriter
 from .wandb_watch import WandbWatch
+from .profiler import Profiler
 
 __all__ = [
     SoftAdapt,
@@ -12,4 +13,5 @@
     NeMoExponentialMovingAverage,
     TestTimeXYZFileWriter,
     WandbWatch,
+    Profiler,
 ]
diff --git a/nequip/train/callbacks/profiler.py b/nequip/train/callbacks/profiler.py
new file mode 100644
index 00000000..3ce267ab
--- /dev/null
+++ b/nequip/train/callbacks/profiler.py
@@ -0,0 +1,54 @@
+import torch
+import lightning
+from lightning.pytorch.callbacks import Callback
+
+from nequip.data import AtomicDataDict
+from nequip.train import NequIPLightningModule
+
+
+class Profiler(Callback):
+    """Proxy class for the `PyTorch profiler <https://pytorch.org/docs/stable/profiler.html>`_.
+
+    Collects traces viewable in TensorBoard via ``tensorboard_trace_handler``.
+
+    Example usage in config:
+    ::
+
+        trainer:
+          ...
+          callbacks:
+            - _target_: nequip.train.callbacks.Profiler
+              trace_output: "./proflog"
+
+    Args:
+        trace_output (str): directory where profile data is stored
+    """
+
+    def __init__(self, trace_output: str = "proflog"):
+        super().__init__()
+        # schedule: skip 1 step, warm up for 1 step, record 3 steps, then stop (repeat=1)
+        self.prof = torch.profiler.profile(
+            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_output),
+            record_shapes=True,
+            profile_memory=True,
+            with_stack=True,
+        )
+
+    def on_train_start(self, trainer, pl_module):
+        """Start the profiler when training begins."""
+        self.prof.start()
+
+    def on_train_end(self, trainer, pl_module):
+        """Stop the profiler and flush any pending traces when training ends."""
+        self.prof.stop()
+
+    def on_train_batch_start(
+        self,
+        trainer: lightning.Trainer,
+        pl_module: NequIPLightningModule,
+        batch: AtomicDataDict.Type,
+        batch_idx: int,
+    ):
+        """Advance the profiler schedule once per training batch."""
+        self.prof.step()