From 456e3daf2ff958d3cb2b983915121363f27b56ad Mon Sep 17 00:00:00 2001 From: RoyYang0714 Date: Wed, 7 Aug 2024 15:56:05 +0200 Subject: [PATCH] feat: Add compute FLOPs flag. --- vis4d/common/ckpt.py | 4 +++- vis4d/pl/run.py | 1 + vis4d/zoo/base/runtime.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vis4d/common/ckpt.py b/vis4d/common/ckpt.py index fdd5b6e4a..c4ad64b49 100644 --- a/vis4d/common/ckpt.py +++ b/vis4d/common/ckpt.py @@ -205,7 +205,9 @@ def load_from_local( filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f"{filename} can not be found.") - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load( + filename, weights_only=True, map_location=map_location + ) return checkpoint diff --git a/vis4d/pl/run.py b/vis4d/pl/run.py index 574f35728..bfa43ab1e 100644 --- a/vis4d/pl/run.py +++ b/vis4d/pl/run.py @@ -155,6 +155,7 @@ def main(argv: ArgsType) -> None: hyper_params, config.seed, ckpt_path if not resume else None, + config.compute_flops, ) data_module = DataModule(config.data) diff --git a/vis4d/zoo/base/runtime.py b/vis4d/zoo/base/runtime.py index 65e842f97..c457638c7 100644 --- a/vis4d/zoo/base/runtime.py +++ b/vis4d/zoo/base/runtime.py @@ -58,6 +58,7 @@ def get_default_cfg( config.use_tf32 = False config.tf32_matmul_precision = "highest" config.benchmark = False + config.compute_flops = False return config