diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 21f8ef0260fe1c..69cb1f05163def 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -312,31 +312,7 @@ def is_torch_bf16_gpu_available(): import torch - # since currently no utility function is available we build our own. - # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51 - # with additional check for torch version - # to succeed: - # 1. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal) - # 2. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU) - # 3. if using gpu, CUDA >= 11 - # 4. torch.autocast exists - # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's - # really only correct for the 0th gpu (or currently set default device if different from 0) - if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"): - return False - - if torch.cuda.is_available() and torch.version.cuda is not None: - if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: - return False - if int(torch.version.cuda.split(".")[0]) < 11: - return False - if not hasattr(torch.cuda.amp, "autocast"): - return False - else: - return False - - return True - + return torch.cuda.is_bf16_supported() def is_torch_bf16_cpu_available(): if not is_torch_available():