Skip to content

Commit

Permalink
[tokenizers] Converting encoding to int32 NDList (#3468)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Sep 12, 2024
1 parent 8247e57 commit 83589bb
Showing 1 changed file with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

import java.util.Arrays;

/** A class holds token encoding information. */
public class Encoding {

Expand Down Expand Up @@ -58,11 +60,15 @@ protected Encoding(
* @return the {@link NDList}
*/
public NDList toNDList(NDManager manager, boolean withTokenType) {
// Converting encoding to int32 NDList because candle can't convert int64 to fp16 in cuda
NDList list = new NDList(withTokenType ? 3 : 2);
list.add(manager.create(ids));
list.add(manager.create(attentionMask));
int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray();
int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intIds));
list.add(manager.create(intAttentionMask));
if (withTokenType) {
list.add(manager.create(typeIds));
int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray();
list.add(manager.create(intTypeIds));
}
return list;
}
Expand Down

0 comments on commit 83589bb

Please sign in to comment.