Skip to content

Commit

Permalink
set PYTORCH_JIT_USE_NNC_NOT_NVFUSER by default
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp committed Jul 29, 2023
1 parent 0b02c41 commit 3f03c77
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions nequip/utils/_global_options.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from packaging import version
import os

import torch

Expand Down Expand Up @@ -86,22 +87,24 @@ def _set_global_options(config, warn_on_override: bool = False) -> None:
# fuser1 is NNC, fuser2 is nvFuser
# See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#fusers
# And https://github.com/pytorch/pytorch/blob/e0a0f37a11164f59b42bc80a6f95b54f722d47ce/torch/jit/_fuser.py#L46
default_fuser = (
"fuser2" # TODO: does this make sense for ROCm?
if torch.cuda.is_available()
else "fuser1" # default to NNC on CPU for now no matter what
if version.parse(torch.__version__) >= version.parse("1.12")
else "fuser1"
)
fuser = config.get("_jit_fuser", default_fuser)
# context manager just restores old fuser afterwards
if torch.cuda.is_available():
torch.jit.fuser(fuser).__enter__()
if warn_on_override and fuser != default_fuser:
# ^ meh assumption, but better than hardcoding getting the old state
# Also https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md
# Also https://github.com/pytorch/pytorch/blob/66fb83293e6a6f527d3fde632e3547fda20becea/torch/csrc/jit/OVERVIEW.md?plain=1#L1201
# https://github.com/search?q=repo%3Apytorch%2Fpytorch%20PYTORCH_JIT_USE_NNC_NOT_NVFUSER&type=code
# We follow the approach they have explicitly built for disabling nvFuser in favor of NNC:
# https://github.com/pytorch/pytorch/blob/66fb83293e6a6f527d3fde632e3547fda20becea/torch/csrc/jit/codegen/cuda/README.md?plain=1#L214
#
# There are three ways to disable nvfuser. Listed below with descending priorities:
# - Force using NNC instead of nvfuser for GPU fusion with env variable `export PYTORCH_JIT_USE_NNC_NOT_NVFUSER=1`.
# - Disabling nvfuser with torch API `torch._C._jit_set_nvfuser_enabled(False)`.
# - Disable nvfuser with env variable `export PYTORCH_JIT_ENABLE_NVFUSER=0`.
#
k = "PYTORCH_JIT_USE_NNC_NOT_NVFUSER"
if k in os.environ:
warnings.warn(
f"Setting the GLOBAL value for JIT fuser to `{fuser}`, which is different than the default for your current PyTorch version ({torch.__version__}) of `{default_fuser}`"
"Do NOT manually set PYTORCH_JIT_USE_NNC_NOT_NVFUSER=0 unless you know exactly what you're doing!"
)
else:
os.environ[k] = "1"

# TODO: warn_on_override for the rest here?
if config.get("model_debug_mode", False):
Expand Down

0 comments on commit 3f03c77

Please sign in to comment.