Skip to content

Commit

Permalink
added missing function
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Jul 16, 2023
1 parent f3452cf commit 862b973
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .util import SimFn
from .util import SimFn, infer_device
from .indexes import DocnoFile, NilIndex, NumpyIndex, RankedLists, FaissFlat, FaissHnsw, MemIndex, TorchIndex
from .flex import FlexIndex
from .biencoder import BiEncoder, BiQueryEncoder, BiDocEncoder, BiScorer
Expand Down
5 changes: 5 additions & 0 deletions pyterrier_dr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ def wrapped(*args, **kwargs):

def __init__(self, *args, **kwargs):
return super().__init__(*args, **kwargs)

def infer_device(device):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
return torch.device(device)

0 comments on commit 862b973

Please sign in to comment.