From 26fd49270b06e943ba1302753757a48928337342 Mon Sep 17 00:00:00 2001 From: JorgeDC Date: Mon, 20 Apr 2020 04:34:24 +0200 Subject: [PATCH] Map to torchtext device parameter (#55) * Map to torchtext device parameter Co-authored-by: jorge.decorte@pttrns.ai Co-authored-by: Sidharth Mudgal --- deepmatcher/data/iterator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmatcher/data/iterator.py b/deepmatcher/data/iterator.py index 03ddaf1..ba8a446 100644 --- a/deepmatcher/data/iterator.py +++ b/deepmatcher/data/iterator.py @@ -17,13 +17,14 @@ def __init__(self, train, batch_size, sort_in_buckets=None, + device=None, **kwargs): if sort_in_buckets is None: sort_in_buckets = train self.sort_in_buckets = sort_in_buckets self.train_info = train_info super(MatchingIterator, self).__init__( - dataset, batch_size, train=train, repeat=False, sort=False, **kwargs) + dataset, batch_size, train=train, repeat=False, sort=False, device=device, **kwargs) @classmethod def splits(cls, datasets, batch_sizes=None, **kwargs):