diff --git a/mlonmcu/session/postprocess/__init__.py b/mlonmcu/session/postprocess/__init__.py index c02fa9400..3c225a5bd 100644 --- a/mlonmcu/session/postprocess/__init__.py +++ b/mlonmcu/session/postprocess/__init__.py @@ -35,6 +35,7 @@ ValidateOutputsPostprocess, ValidateLabelsPostprocess, ExportOutputsPostprocess, + StageTimesGanttPostprocess, ) SUPPORTED_POSTPROCESSES = { @@ -54,4 +55,5 @@ "validate_outputs": ValidateOutputsPostprocess, "validate_labels": ValidateLabelsPostprocess, "export_outputs": ExportOutputsPostprocess, + "stage_times_gantt": StageTimesGanttPostprocess, } diff --git a/mlonmcu/session/postprocess/postprocesses.py b/mlonmcu/session/postprocess/postprocesses.py index 98aab6885..9f0a94956 100644 --- a/mlonmcu/session/postprocess/postprocesses.py +++ b/mlonmcu/session/postprocess/postprocesses.py @@ -23,6 +23,7 @@ import tempfile from pathlib import Path from io import StringIO +from collections import defaultdict import numpy as np import pandas as pd @@ -1903,3 +1904,58 @@ def dequant_helper(quant, data): if temp_dir: temp_dir.cleanup() return artifacts + + +class StageTimesGanttPostprocess(SessionPostprocess): + """TODO.""" + + DEFAULTS = { + **SessionPostprocess.DEFAULTS, + } + + def __init__(self, features=None, config=None): + super().__init__("stage_times_gantt", features=features, config=config) + + def post_session(self, report): + """Called at the end of a session.""" + artifacts = [] + content = """gantt + title Flow + dateFormat x + axisFormat %H:%M:%S +""" + for i, row in report.main_df.iterrows(): + content += f" section Run {i}\n" + stage_times = defaultdict(dict) + for key, value in row.items(): + # if " Stage Time [s]" in key: + # key = key.replace(" Stage Time [s]", "") + # stage_times[key]["time_s"] = value + if " Start Time [s]" in key: + key = key.replace(" Start Time [s]", "") + stage_times[key]["start"] = value + if " End Time [s]" in key: + key = key.replace(" End Time [s]", "") + stage_times[key]["end"] = value + # stage_times = dict(reversed(list(stage_times.items()))) + print("stage_times", stage_times) + first = True + for stage, times in stage_times.items(): + start = times.get("start") + end = times.get("end") + # time_s = times.get("time_s") + time_s = None + start = int(start * 1e3) + end = int(end * 1e3) + if False: + if first: + first = False + content += f" {stage} : 0, {time_s}s\n" + else: + + content += f" {stage} : {time_s}s\n" + else: + content += f" {stage} : {start}, {end}\n" + artifact = Artifact("stage_times.mermaid", content=content, fmt=ArtifactFormat.TEXT) + artifacts.append(artifact) + return artifacts diff --git a/mlonmcu/session/run.py b/mlonmcu/session/run.py index 20b869837..a83553cf4 100644 --- a/mlonmcu/session/run.py +++ b/mlonmcu/session/run.py @@ -18,6 +18,7 @@ # """Definition of a MLonMCU Run which represents a single benchmark instance for a given set of options.""" import itertools +import time import os import copy import tempfile @@ -83,6 +84,7 @@ class Run: "target_optimized_layouts": False, "target_optimized_schedules": False, "stage_subdirs": False, + "profile_stages": False, } REQUIRED = set() @@ -135,6 +137,7 @@ def __init__( self.sub_names = [] self.sub_parents = {} self.result = None + self.times = {} self.failing = False # -> RunStatus self.reason = None self.failed_stage = None @@ -190,6 +193,11 @@ def stage_subdirs(self): value = self.run_config["stage_subdirs"] return str2bool(value) + @property + def profile_stages(self): + value = self.run_config["profile_stages"] + return str2bool(value) + @property def build_platform(self): """Get platform for build stage.""" @@ -1087,7 +1095,12 @@ def process(self, until=RunStage.RUN, skip=None, export=False): if func: self.failing = False try: + if self.profile_stages: + start = time.time() func() + if self.profile_stages: + end = time.time() + self.times[stage] = (start, end) except Exception as e: self.failing = True self.reason = e @@ -1303,6 +1316,12 @@ def metrics_helper(stage, subs): main = metrics_by_sub[sub].get_data(include_optional=self.export_optional) else: main = {} + if self.profile_stages: + for stage, stage_times in self.times.items(): + assert len(stage_times) == 2 + start, end = stage_times + main[f"{RunStage(stage).name.capitalize()} Stage Start Time [s]"] = start + main[f"{RunStage(stage).name.capitalize()} Stage End Time [s]"] = end mains.append(main if len(main) > 0 else {"Incomplete": True}) posts.append(post) # TODO: omit for subs? report.set(pre=pres, main=mains, post=posts)