diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java index 887f01646dc..301490a2bb4 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java @@ -16,6 +16,8 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import java.util.Arrays; + /** A class holds token encoding information. */ public class Encoding { @@ -58,11 +60,15 @@ protected Encoding( * @return the {@link NDList} */ public NDList toNDList(NDManager manager, boolean withTokenType) { + // Converting encoding to int32 NDList because candle can't convert int64 to fp16 in cuda NDList list = new NDList(withTokenType ? 3 : 2); - list.add(manager.create(ids)); - list.add(manager.create(attentionMask)); + int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray(); + int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray(); + list.add(manager.create(intIds)); + list.add(manager.create(intAttentionMask)); if (withTokenType) { - list.add(manager.create(typeIds)); + int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray(); + list.add(manager.create(intTypeIds)); } return list; }