Skip to content

Commit

Permalink
Merge pull request #43 from teddykoker/cuda_fail_msg
Browse files Browse the repository at this point in the history
Better error message if CUDA extension is not available
  • Loading branch information
teddykoker authored Dec 30, 2021
2 parents 6b81e6c + 8de04e1 commit 87f7715
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def ext_modules():

setup(
name="torchsort",
version="0.1.7",
version="0.1.8",
description="Differentiable sorting and ranking in PyTorch",
author="Teddy Koker",
url="https://github.com/teddykoker/torchsort",
Expand Down
16 changes: 12 additions & 4 deletions torchsort/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@
from .isotonic_cuda import isotonic_l2 as isotonic_l2_cuda
from .isotonic_cuda import isotonic_l2_backward as isotonic_l2_backward_cuda
except ImportError:
isotonic_l2_cuda = None
isotonic_kl_cuda = None
isotonic_l2_backward_cuda = None
isotonic_kl_backward_cuda = None

def _error(*args, **kwargs):
raise ImportError(
"You are trying to use the torchsort CUDA extension, but it looks like it is not available."
" Make sure you have the CUDA toolchain installed, and reinstall torchsort with `pip install --force-reinstall --no-cache-dir torchsort`"
" to rebuild the extension."
)

isotonic_l2_cuda = _error
isotonic_kl_cuda = _error
isotonic_l2_backward_cuda = _error
isotonic_kl_backward_cuda = _error


def soft_rank(values, regularization="l2", regularization_strength=1.0):
Expand Down

0 comments on commit 87f7715

Please sign in to comment.