diff --git a/api/src/main/java/ai/djl/translate/Batchifier.java b/api/src/main/java/ai/djl/translate/Batchifier.java index 32fa23f94fb..9ffcb781d6b 100644 --- a/api/src/main/java/ai/djl/translate/Batchifier.java +++ b/api/src/main/java/ai/djl/translate/Batchifier.java @@ -13,6 +13,7 @@ package ai.djl.translate; import ai.djl.ndarray.NDList; +import ai.djl.util.ClassLoaderUtils; import java.io.Serializable; import java.util.Arrays; @@ -47,7 +48,12 @@ static Batchifier fromString(String name) { case "none": return null; default: - throw new IllegalArgumentException("Invalid batchifier name"); + ClassLoader cl = ClassLoaderUtils.getContextClassLoader(); + Batchifier b = ClassLoaderUtils.initClass(cl, Batchifier.class, name); + if (b == null) { + throw new IllegalArgumentException("Invalid batchifier name: " + name); + } + return b; } } diff --git a/api/src/test/java/ai/djl/translate/BatchifierTest.java b/api/src/test/java/ai/djl/translate/BatchifierTest.java index c7ef460e951..92afd9ea4a1 100644 --- a/api/src/test/java/ai/djl/translate/BatchifierTest.java +++ b/api/src/test/java/ai/djl/translate/BatchifierTest.java @@ -22,6 +22,17 @@ public void testBatchifier() { Batchifier batchifier = Batchifier.fromString("stack"); Assert.assertEquals(batchifier, Batchifier.STACK); + batchifier = Batchifier.fromString("none"); + Assert.assertNull(batchifier); + + batchifier = Batchifier.fromString("padding"); + Assert.assertNotNull(batchifier); + Assert.assertEquals(batchifier.getClass(), SimplePaddingStackBatchifier.class); + + batchifier = Batchifier.fromString("ai.djl.translate.SimplePaddingStackBatchifier"); + Assert.assertNotNull(batchifier); + Assert.assertEquals(batchifier.getClass(), SimplePaddingStackBatchifier.class); + Assert.assertThrows(() -> Batchifier.fromString("invalid")); } } diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/MnistTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/MnistTest.java index 8d0e609ccd4..792ff77b7f7 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/MnistTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/MnistTest.java @@ -157,7 +157,7 @@ public void testBulkEqualsNonBulk() throws IOException, TranslateException { .optUsage(Dataset.Usage.TEST) .optRepository(repository) .setSampling(32, false) - .optLabelBatchifier(new StackBatchifier() {}) + .optLabelBatchifier(new StackBatchifier()) .build(); try (Trainer trainer = model.newTrainer(config)) {