diff --git a/.gitignore b/.gitignore
index ad84f2ea923..af083e92af7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ libs
*.dylib
*.dll
*.class
+*.log
# Eclipse
.settings
diff --git a/pytorch/pytorch-engine/build.gradle b/pytorch/pytorch-engine/build.gradle
index 216b9fada05..dc95a03b3eb 100644
--- a/pytorch/pytorch-engine/build.gradle
+++ b/pytorch/pytorch-engine/build.gradle
@@ -27,11 +27,11 @@ processResources {
]
def classesDir = "${project.buildDir}/jnilib"
files.each { entry ->
- project.logger.lifecycle("Downloading ${url}/${entry.key}")
def file = new File("${classesDir}/${entry.value}")
if (file.exists()) {
project.logger.lifecycle("prebuilt or cached file found for ${entry.value}")
} else {
+ project.logger.lifecycle("Downloading ${url}/${entry.key}")
file.getParentFile().mkdirs()
new URL("${url}/${entry.key}").withInputStream { i -> file.withOutputStream { it << i } }
}
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
index 7215aba5a48..0d671ff0331 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.SymbolBlock;
+import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.IValueUtils;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.ParameterStore;
@@ -88,6 +89,16 @@ public void close() {
}
}
+ /**
+ * Runs the forward of this PyTorch module.
+ *
+ * @param inputs the input {@link IValue}
+ * @return the result {@link IValue}
+ */
+ public IValue forward(IValue... inputs) {
+ return IValueUtils.forward(this, inputs);
+ }
+
/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java
new file mode 100644
index 00000000000..71d31e3420b
--- /dev/null
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java
@@ -0,0 +1,402 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.pytorch.jni;
+
+import ai.djl.ndarray.NDList;
+import ai.djl.pytorch.engine.PtNDArray;
+import ai.djl.pytorch.engine.PtNDManager;
+import ai.djl.util.NativeResource;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * A class represent a PyTorch {@code IValue} data.
+ *
+ *
DJL doesn't support creating nested IValue.
+ */
+public class IValue extends NativeResource {
+
+ IValue(long handle) {
+ super(handle);
+ }
+
+ /**
+ * Returns if the IValue is a {@code Tensor} type.
+ *
+ * @return if the IValue is a Tensor type
+ */
+ public boolean isTensor() {
+ return PyTorchLibrary.LIB.iValueIsTensor(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code boolean} type.
+ *
+ * @return if the IValue is a boolean type
+ */
+ public boolean isBoolean() {
+ return PyTorchLibrary.LIB.iValueIsBool(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code long} type.
+ *
+ * @return if the IValue is a long type
+ */
+ public boolean isLong() {
+ return PyTorchLibrary.LIB.iValueIsLong(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code double} type.
+ *
+ * @return if the IValue is a double type
+ */
+ public boolean isDouble() {
+ return PyTorchLibrary.LIB.iValueIsDouble(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code String} type.
+ *
+ * @return if the IValue is a String type
+ */
+ public boolean isString() {
+ return PyTorchLibrary.LIB.iValueIsString(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code boolean[]} type.
+ *
+ * @return if the IValue is a boolean[] type
+ */
+ public boolean isBooleanList() {
+ return PyTorchLibrary.LIB.iValueIsBoolList(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code long[]} type.
+ *
+ * @return if the IValue is a long[] type
+ */
+ public boolean isLongList() {
+ return PyTorchLibrary.LIB.iValueIsLongList(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code double[]} type.
+ *
+ * @return if the IValue is a double[] type
+ */
+ public boolean isDoubleList() {
+ return PyTorchLibrary.LIB.iValueIsDoubleList(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code IValue[]} type.
+ *
+ * The elements in the array must have the same type.
+ *
+ * @return if the IValue is a IValue[] type
+ */
+ public boolean isTensorList() {
+ return PyTorchLibrary.LIB.iValueIsTensorList(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code IValue[]} type.
+ *
+ *
The elements in the array must have the same type.
+ *
+ * @return if the IValue is a IValue[] type
+ */
+ public boolean isList() {
+ return PyTorchLibrary.LIB.iValueIsList(getHandle());
+ }
+
+ /**
+ * Returns if the IValue is a {@code Map<String, V>} type.
+ *
+ * @return if the IValue is a Map<String, V> type
+ */
+ public boolean isMap() {
+ return PyTorchLibrary.LIB.iValueIsMap(getHandle());
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code PtNDArray}.
+ *
+ * @param value the NDArray value
+ * @return a new {@code IValue} of type {@code PtNDArray}
+ */
+ public static IValue from(PtNDArray value) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromTensor(value.getHandle()));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code boolean}.
+ *
+ * @param value the boolean value
+ * @return a new {@code IValue} of type {@code boolean}
+ */
+ public static IValue from(boolean value) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromBool(value));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code long}.
+ *
+ * @param value the long value
+ * @return a new {@code IValue} of type {@code long}
+ */
+ public static IValue from(long value) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromLong(value));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code double}.
+ *
+ * @param value the double value
+ * @return a new {@code IValue} of type {@code double}
+ */
+ public static IValue from(double value) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromDouble(value));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code String}.
+ *
+ * @param value the String value
+ * @return a new {@code IValue} of type {@code String}
+ */
+ public static IValue from(String value) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromString(value));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code boolean[]}.
+ *
+ * @param list the boolean[] value
+ * @return a new {@code IValue} of type {@code boolean[]}
+ */
+ public static IValue listFrom(boolean... list) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromBoolList(list));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code long[]}.
+ *
+ * @param list the long[] value
+ * @return a new {@code IValue} of type {@code long[]}
+ */
+ public static IValue listFrom(long... list) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromLongList(list));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code double[]}.
+ *
+ * @param list the double[] value
+ * @return a new {@code IValue} of type {@code double[]}
+ */
+ public static IValue listFrom(double... list) {
+ return new IValue(PyTorchLibrary.LIB.iValueFromDoubleList(list));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code NDArray[]}.
+ *
+ * @param list the NDArray[] value
+ * @return a new {@code IValue} of type {@code NDArray[]}
+ */
+ public static IValue listFrom(PtNDArray... list) {
+ long[] tensors = Arrays.stream(list).mapToLong(PtNDArray::getHandle).toArray();
+ return new IValue(PyTorchLibrary.LIB.iValueFromTensorList(tensors));
+ }
+
+ /**
+ * Creates a new {@code IValue} of type {@code Map[String, PtNDArray]}.
+ *
+ * @param map the Map[String, IValue] value
+ * @return a new {@code IValue} of type {@code Map[String, PtNDArray]}
+ */
+ public static IValue stringMapFrom(Map map) {
+ String[] keys = new String[map.size()];
+ long[] handles = new long[map.size()];
+ int i = 0;
+ for (Map.Entry entry : map.entrySet()) {
+ keys[i] = entry.getKey();
+ handles[i] = entry.getValue().getHandle();
+ ++i;
+ }
+ return new IValue(PyTorchLibrary.LIB.iValueFromStringMap(keys, handles));
+ }
+
+ /**
+ * Returns the {@code boolean} value of this IValue.
+ *
+ * @return the boolean value of this IValue
+ */
+ public boolean toBoolean() {
+ return PyTorchLibrary.LIB.iValueToBool(getHandle());
+ }
+
+ /**
+ * Returns the {@code long} value of this IValue.
+ *
+ * @return the long value of this IValue
+ */
+ public long toLong() {
+ return PyTorchLibrary.LIB.iValueToLong(getHandle());
+ }
+
+ /**
+ * Returns the {@code double} value of this IValue.
+ *
+ * @return the double value of this IValue
+ */
+ public double toDouble() {
+ return PyTorchLibrary.LIB.iValueToDouble(getHandle());
+ }
+
+ /**
+ * Returns the {@code String} value of this IValue.
+ *
+ * @return the String value of this IValue
+ */
+ public String toStringValue() {
+ return PyTorchLibrary.LIB.iValueToString(getHandle());
+ }
+
+ /**
+ * Returns the {@code boolean[]} value of this IValue.
+ *
+ * @return the boolean[] value of this IValue
+ */
+ public boolean[] toBooleanArray() {
+ return PyTorchLibrary.LIB.iValueToBoolList(getHandle());
+ }
+
+ /**
+ * Returns the {@code long[]} value of this IValue.
+ *
+ * @return the long[] value of this IValue
+ */
+ public long[] toLongArray() {
+ return PyTorchLibrary.LIB.iValueToLongList(getHandle());
+ }
+
+ /**
+ * Returns the {@code double[]} value of this IValue.
+ *
+ * @return the double[] value of this IValue
+ */
+ public double[] toDoubleArray() {
+ return PyTorchLibrary.LIB.iValueToDoubleList(getHandle());
+ }
+
+ /**
+ * Returns the {@code NDArray} value of this IValue.
+ *
+ * @param manager the {@code NDManager} to create the NDArray
+ * @return the NDArray value of this IValue
+ */
+ public PtNDArray toTensor(PtNDManager manager) {
+ return new PtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle()));
+ }
+
+ /**
+ * Returns the {@code NDArray[]} value of this IValue.
+ *
+ * @param manager the NDManager to create NDArray
+ * @return the NDArray[] value of this IValue
+ */
+ public PtNDArray[] toTensorArray(PtNDManager manager) {
+ long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle());
+ PtNDArray[] ret = new PtNDArray[handles.length];
+ for (int i = 0; i < ret.length; ++i) {
+ ret[i] = new PtNDArray(manager, handles[i]);
+ }
+ return ret;
+ }
+
+ /**
+ * Returns the {@code IValue[]} value of this IValue list.
+ *
+ * @return the IValue[] value of this IValue list
+ */
+ public IValue[] toIValueArray() {
+ long[] handles = PyTorchLibrary.LIB.iValueToIValueList(getHandle());
+ IValue[] ret = new IValue[handles.length];
+ for (int i = 0; i < ret.length; ++i) {
+ ret[i] = new IValue(handles[i]);
+ }
+ return ret;
+ }
+
+ /**
+ * Returns the {@code Map<String, IValue>} value of this IValue.
+ *
+ * @return the Map<String, IValue> value of this IValue
+ */
+ public Map toIValueMap() {
+ long[] handles = PyTorchLibrary.LIB.iValueToMap(getHandle());
+ Map map = new ConcurrentHashMap<>();
+ for (int i = 0; i < handles.length; i += 2) {
+ IValue key = new IValue(handles[i]);
+ map.put(key.toStringValue(), new IValue(handles[i + 1]));
+ key.close();
+ }
+ return map;
+ }
+
+ /**
+ * Returns the {@code NDList} value of this IValue.
+ *
+ * @param manager the NDManager to create NDArray
+ * @return the {@code NDList} value of this IValue
+ */
+ public NDList toNDList(PtNDManager manager) {
+ if (isTensor()) {
+ return new NDList(toTensor(manager));
+ } else if (isTensorList()) {
+ return new NDList(toTensorArray(manager));
+ } else if (isMap()) {
+ // Only allows one level type of map
+ NDList list = new NDList();
+ Map map = toIValueMap();
+ for (Map.Entry entry : map.entrySet()) {
+ IValue iv = entry.getValue();
+ if (!iv.isTensor()) {
+ throw new UnsupportedOperationException("Only one level of map is supported.");
+ }
+ PtNDArray value = entry.getValue().toTensor(manager);
+ value.setName(entry.getKey());
+ list.add(value);
+ iv.close();
+ }
+ return list;
+ }
+ throw new UnsupportedOperationException("Unsupported IValue type.");
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ Long pointer = handle.getAndSet(null);
+ if (pointer != null) {
+ PyTorchLibrary.LIB.torchDeleteIValue(pointer);
+ }
+ }
+}
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java
index e1898713b2e..41cb8f57cf7 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java
@@ -20,13 +20,12 @@
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
-import ai.djl.util.Preconditions;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
-import java.util.stream.Stream;
/** IValueUtils is utility class to deal with IValue in PyTorch. */
public final class IValueUtils {
@@ -34,278 +33,85 @@ public final class IValueUtils {
private IValueUtils() {}
/**
- * Create IValue Pointer from NDArray.
+ * Runs the forward of PyTorch module.
*
- * @param arrayHandle the handle for PyTorch Tensor
- * @return IValue Pointer
- */
- public static long toIValuePointer(long arrayHandle) {
- return PyTorchLibrary.LIB.iValueFromTensor(arrayHandle);
- }
-
- /**
- * Create List IValue Pointer from pointer array.
- *
- * @param pointers pointer array
- * @return IValue Pointer
- */
- public static long iValueFromList(long[] pointers) {
- return PyTorchLibrary.LIB.iValueFromList(pointers);
- }
-
- /**
- * Create Dict IValue Pointer from pointer with its key name.
- *
- * @param pointers pointer array
- * @param names the key value of the pointer
- * @return IValue Pointer
- */
- public static long iValueFromDict(long[] pointers, String[] names) {
- return PyTorchLibrary.LIB.iValueFromDict(pointers, names);
- }
-
- /**
- * Check IValue is a container of {@link PtNDArray}.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isNDArray(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsTensor(iValueHandle);
- }
-
- /**
- * Check IValue is a container of {@link NDList}.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isNDList(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsTensorList(iValueHandle);
- }
-
- /**
- * Check IValue is a container of IValue List.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isList(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsList(iValueHandle);
- }
-
- /**
- * Check IValue is a container of IValue Tuple.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isTuple(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsTuple(iValueHandle);
- }
-
- /**
- * Check IValue is a container of IValue Map.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isMap(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsMap(iValueHandle);
- }
-
- /**
- * Check IValue is a container of String.
- *
- * @param iValueHandle IValue pointer
- * @return result
- */
- public static boolean isString(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueIsString(iValueHandle);
- }
-
- /**
- * Extract IValue with a {@link PtNDArray} value.
- *
- * @param iValueHandle IValue pointer
- * @param manager {@link PtNDManager} that creates {@link PtNDArray}
- * @return {@link ai.djl.ndarray.NDArray}
- */
- public static PtNDArray toNDArray(long iValueHandle, PtNDManager manager) {
- long ndHandle = PyTorchLibrary.LIB.iValueToTensor(iValueHandle);
- return new PtNDArray(manager, ndHandle);
- }
-
- /**
- * Extract IValue to {@link NDList}.
- *
- * @param iValueHandle IValue pointer
- * @param manager {@link PtNDManager} that creates {@link PtNDArray}
- * @return {@link NDList}
- */
- public static NDList toNDList(long iValueHandle, PtNDManager manager) {
- long[] ndHandles = PyTorchLibrary.LIB.iValueToTensorList(iValueHandle);
- NDList list = new NDList();
- for (long handle : ndHandles) {
- list.add(new PtNDArray(manager, handle));
- }
- return list;
- }
-
- /**
- * Extract IValue to String.
- *
- * @param iValueHandle IValue pointer
- * @return String
- */
- public static String toString(long iValueHandle) {
- return PyTorchLibrary.LIB.iValueToString(iValueHandle);
- }
-
- /**
- * Extract IValue to an IValue Array.
- *
- * @param iValueHandle IValue pointer
- * @return IValue array
- */
- public static long[] toIValueArray(long iValueHandle) {
- if (isTuple(iValueHandle)) {
- return PyTorchLibrary.LIB.iValueToListFromTuple(iValueHandle);
- }
- return PyTorchLibrary.LIB.iValueToList(iValueHandle);
- }
-
- /**
- * Extract IValue to a Map.
- *
- * @param iValueHandle IValue pointer
- * @return IValue Map
+ * @param block the block that contains PyTorch module
+ * @param inputs the input {@link NDList}
+ * @param isTrain if running on training mode
+ * @return the result {@link NDList}
*/
- public static Map toIValueMap(long iValueHandle) {
- long[] iValueHandles = PyTorchLibrary.LIB.iValueToMap(iValueHandle);
- Map map = new ConcurrentHashMap<>();
- for (int i = 0; i < iValueHandles.length; i += 2) {
- map.put(iValueHandles[i], iValueHandles[i + 1]);
- }
- return map;
- }
-
- private static NDList forwardHelper(long iValueHandle, PtNDManager manager) {
- NDList list = new NDList();
- if (isNDArray(iValueHandle)) {
- list.add(toNDArray(iValueHandle, manager));
- } else if (isNDList(iValueHandle)) {
- list.addAll(toNDList(iValueHandle, manager));
- } else if (isList(iValueHandle) || isTuple(iValueHandle)) {
- for (long handle : toIValueArray(iValueHandle)) {
- list.addAll(forwardHelper(handle, manager));
- }
- } else if (isMap(iValueHandle)) {
- // Only allows type of map
- Map map = toIValueMap(iValueHandle);
- for (Map.Entry entry : map.entrySet()) {
- String name = toString(entry.getKey());
- // free the IValue handle
- PyTorchLibrary.LIB.torchDeleteIValue(entry.getKey());
- PtNDArray value = toNDArray(entry.getValue(), manager);
- // free the IValue handle
- PyTorchLibrary.LIB.torchDeleteIValue(entry.getValue());
- value.setName(name);
- list.add(value);
- }
- } else {
- // free the IValue handle
- PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle);
- throw new UnsupportedOperationException("Unsupported IValue type");
+ public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
+ IValue[] iValues = getInputs(inputs);
+ long[] iValueHandles = Arrays.stream(iValues).mapToLong(IValue::getHandle).toArray();
+ long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueHandles, isTrain);
+ PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
+ Arrays.stream(iValues).forEach(IValue::close);
+ try (IValue iValue = new IValue(result)) {
+ return iValue.toNDList(manager);
}
- // free the IValue handle
- PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle);
- return list;
}
/**
- * Run the forward of PyTorch module.
+ * Runs the forward of PyTorch module.
*
* @param block the block that contains PyTorch module
- * @param inputs input {@link NDList}
- * @param isTrain is running on training mode
- * @return result {@link NDList}
+ * @param inputs the input {@link IValue}
+ * @return the result {@link IValue}
*/
- public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
- long[] arrayHandles =
- inputs.stream().mapToLong(input -> ((PtNDArray) input).getHandle()).toArray();
- String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new);
- long[] iValueInputs = getInputs(arrayHandles, names);
- long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueInputs, isTrain);
- PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
- return forwardHelper(result, manager);
+ public static IValue forward(PtSymbolBlock block, IValue... inputs) {
+ long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
+ return new IValue(PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false));
}
- private static boolean isNameList(String name) {
- return Pattern.matches("\\w+\\[]", name);
+ private static int addToMap(
+ Map map, String key, List> list) {
+ return map.computeIfAbsent(
+ key,
+ k -> {
+ list.add(new PairList<>());
+ return list.size() - 1;
+ });
}
- private static boolean isNameDict(String name) {
- return name.contains(".");
- }
-
- private static long[] getInputs(long[] arrays, String[] names) {
- List> outputs = new ArrayList<>();
+ private static IValue[] getInputs(NDList ndList) {
+ List> outputs = new ArrayList<>();
Map indexMap = new ConcurrentHashMap<>();
- for (int i = 0; i < arrays.length; i++) {
- String name = names[i];
- if (name == null || (!isNameList(name) && !isNameDict(name))) {
- PairList list = new PairList<>();
- list.add(new Pair<>(null, toIValuePointer(arrays[i])));
- outputs.add(list);
- continue;
- }
- String mapKey = null;
- boolean isDict = isNameDict(names[i]);
- if (isDict) {
- String[] strings = names[i].split("\\.");
- Preconditions.checkArgument(
- strings.length == 2,
- "Please make sure you only include one '.' in the name. Nested Map is not supported!");
- name = strings[0];
- mapKey = strings[1];
- }
- if (!indexMap.containsKey(name)) {
- outputs.add(new PairList<>());
- indexMap.put(name, outputs.size() - 1);
- }
- if (isDict) {
- outputs.get(indexMap.get(name)).add(new Pair<>(mapKey, arrays[i]));
+ for (NDArray array : ndList) {
+ String name = array.getName();
+ if (name != null && name.contains(".")) {
+ String[] strings = name.split("\\.", 2);
+ int index = addToMap(indexMap, strings[0], outputs);
+ PairList pl = outputs.get(index);
+ pl.add(strings[1], (PtNDArray) array);
+ } else if (name != null && Pattern.matches("\\w+\\[]", name)) {
+ int index = addToMap(indexMap, name, outputs);
+ PairList pl = outputs.get(index);
+ pl.add("[]", (PtNDArray) array);
} else {
- outputs.get(indexMap.get(name)).add(new Pair<>(name, arrays[i]));
+ PairList pl = new PairList<>();
+ pl.add(null, (PtNDArray) array);
+ outputs.add(pl);
}
}
- long[] pointers = new long[outputs.size()];
+ IValue[] ret = new IValue[outputs.size()];
for (int i = 0; i < outputs.size(); ++i) {
- // not List, Dict input
- if (outputs.get(i).size() == 1 && outputs.get(i).get(0).getKey() == null) {
- pointers[i] = outputs.get(i).get(0).getValue();
- } else if (isNameList(outputs.get(i).get(0).getKey())) {
- pointers[i] =
- iValueFromList(
- toPrimitiveLongArray(outputs.get(i).valueArray(new Long[0])));
+ PairList pl = outputs.get(i);
+ String key = pl.get(0).getKey();
+ if (key == null) {
+ // not List, Dict input
+ ret[i] = IValue.from(pl.get(0).getValue());
+ } else if ("[]".equals(key)) {
+ // list
+ PtNDArray[] arrays = pl.values().toArray(new PtNDArray[0]);
+ ret[i] = IValue.listFrom(arrays);
} else {
- PairList dict = outputs.get(i);
- pointers[i] =
- iValueFromDict(
- toPrimitiveLongArray(dict.valueArray(new Long[0])),
- dict.keyArray(new String[0]));
+ Map map = new ConcurrentHashMap<>();
+ for (Pair pair : pl) {
+ map.put(pair.getKey(), pair.getValue());
+ }
+ ret[i] = IValue.stringMapFrom(map);
}
}
- return pointers;
- }
-
- private static long[] toPrimitiveLongArray(Long[] array) {
- if (array == null) {
- return null;
- } else if (array.length == 0) {
- return new long[0];
- }
- return Stream.of(array).mapToLong(Long::longValue).toArray();
+ return ret;
}
}
diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
index c078e7c7cef..dc2dcb19e4c 100644
--- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
+++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java
@@ -491,34 +491,74 @@ native long moduleLoad(
native long iValueFromTensor(long tensorHandle);
- native long iValueFromList(long[] tensorHandles);
+ native long iValueFromBool(boolean value);
- native long iValueFromDict(long[] tensorHandles, String[] names);
+ native long iValueFromLong(long value);
+
+ native long iValueFromDouble(double value);
+
+ native long iValueFromString(String value);
+
+ native long iValueFromBoolList(boolean... value);
+
+ native long iValueFromLongList(long... value);
+
+ native long iValueFromDoubleList(double... value);
+
+ native long iValueFromTensorList(long[] tensorHandles);
+
+ native long iValueFromTuple(long[] ivalueHandles);
+
+ native long iValueFromStringMap(String[] keys, long[] tensorHandles);
native long iValueToTensor(long iValueHandle);
+ native boolean iValueToBool(long iValueHandle);
+
+ native long iValueToLong(long iValueHandle);
+
+ native double iValueToDouble(long iValueHandle);
+
+ native String iValueToString(long iValueHandle);
+
+ native boolean[] iValueToBoolList(long iValueHandle);
+
+ native long[] iValueToLongList(long iValueHandle);
+
+ native double[] iValueToDoubleList(long iValueHandle);
+
native long[] iValueToTensorList(long iValueHandle);
- native long[] iValueToList(long iValueHandle);
+ native long[] iValueToIValueList(long iValueHandle);
- native long[] iValueToListFromTuple(long iValueHandle);
+ native long[] iValueToIValueTuple(long iValueHandle);
native long[] iValueToMap(long iValueHandle);
- native String iValueToString(long iValueHandle);
+ native boolean iValueIsTensor(long iValueHandle);
+
+ native boolean iValueIsBool(long iValueHandle);
+
+ native boolean iValueIsLong(long iValueHandle);
+
+ native boolean iValueIsDouble(long iValueHandle);
native boolean iValueIsString(long iValueHandle);
- native boolean iValueIsTensor(long iValueHandle);
+ native boolean iValueIsBoolList(long iValueHandle);
+
+ native boolean iValueIsLongList(long iValueHandle);
+
+ native boolean iValueIsDoubleList(long iValueHandle);
native boolean iValueIsTensorList(long iValueHandle);
native boolean iValueIsList(long iValueHandle);
- native boolean iValueIsMap(long iValueHandle);
-
native boolean iValueIsTuple(long iValueHandle);
+ native boolean iValueIsMap(long iValueHandle);
+
native void zeroGrad(long handle);
native void adamUpdate(
diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java
new file mode 100644
index 00000000000..6c776f59678
--- /dev/null
+++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.pytorch.integration;
+
+import ai.djl.ModelException;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.pytorch.engine.PtNDArray;
+import ai.djl.pytorch.engine.PtNDManager;
+import ai.djl.pytorch.engine.PtSymbolBlock;
+import ai.djl.pytorch.jni.IValue;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class IValueTest {
+
+ @Test
+ public void testIValue() {
+ try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) {
+ PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1));
+ PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1));
+
+ try (IValue ivalue = IValue.from(array1)) {
+ Assert.assertTrue(ivalue.isTensor());
+ NDArray ret = ivalue.toTensor(manager);
+ Assert.assertEquals(ret, array1);
+ NDList list = ivalue.toNDList(manager);
+ Assert.assertEquals(list.size(), 1);
+ Assert.assertEquals(list.head(), array1);
+ }
+
+ try (IValue ivalue = IValue.from(true)) {
+ Assert.assertTrue(ivalue.isBoolean());
+ Assert.assertTrue(ivalue.toBoolean());
+ }
+
+ try (IValue ivalue = IValue.from(1)) {
+ Assert.assertTrue(ivalue.isLong());
+ Assert.assertEquals(ivalue.toLong(), 1);
+ }
+
+ try (IValue ivalue = IValue.from(1d)) {
+ Assert.assertTrue(ivalue.isDouble());
+ Assert.assertEquals(ivalue.toDouble(), 1d);
+ }
+
+ try (IValue ivalue = IValue.from("test")) {
+ Assert.assertTrue(ivalue.isString());
+ Assert.assertEquals(ivalue.toStringValue(), "test");
+ }
+
+ try (IValue ivalue = IValue.listFrom(true, false)) {
+ Assert.assertTrue(ivalue.isList());
+ Assert.assertTrue(ivalue.isBooleanList());
+ Assert.assertEquals(ivalue.toBooleanArray(), new boolean[] {true, false});
+ }
+
+ try (IValue ivalue = IValue.listFrom(1, 2)) {
+ Assert.assertTrue(ivalue.isLongList());
+ Assert.assertEquals(ivalue.toLongArray(), new long[] {1, 2});
+ }
+
+ try (IValue ivalue = IValue.listFrom(1d, 2d)) {
+ Assert.assertTrue(ivalue.isDoubleList());
+ Assert.assertEquals(ivalue.toDoubleArray(), new double[] {1d, 2d});
+ }
+
+ try (IValue ivalue = IValue.listFrom(array1, array2)) {
+ Assert.assertTrue(ivalue.isTensorList());
+ NDArray[] ret = ivalue.toTensorArray(manager);
+ Assert.assertEquals(ret.length, 2);
+ NDList list = ivalue.toNDList(manager);
+ Assert.assertEquals(list.size(), 2);
+ Assert.assertEquals(list.head(), array1);
+
+ IValue[] iValues = ivalue.toIValueArray();
+ Assert.assertEquals(iValues.length, 2);
+ Assert.assertTrue(iValues[0].isTensor());
+ Arrays.stream(iValues).forEach(IValue::close);
+ }
+
+ Map map = new ConcurrentHashMap<>();
+ map.put("data1", array1);
+ map.put("data2", array2);
+ try (IValue ivalue = IValue.stringMapFrom(map)) {
+ Assert.assertTrue(ivalue.isMap());
+ Map ret = ivalue.toIValueMap();
+ Assert.assertEquals(ret.size(), 2);
+
+ NDList list = ivalue.toNDList(manager);
+ Assert.assertEquals(list.size(), 2);
+ Assert.assertEquals(list.get("data1"), array1);
+ }
+ }
+ }
+
+ @Test
+ public void testIValueModel() throws IOException, ModelException {
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelUrls("https://resources.djl.ai/test-models/ivalue_jit.zip")
+ .optProgress(new ProgressBar())
+ .build();
+
+ try (ZooModel model = criteria.loadModel()) {
+ PtSymbolBlock block = (PtSymbolBlock) model.getBlock();
+ IValue tokens = IValue.listFrom(1, 2, 3);
+ IValue cls = IValue.from(0);
+ IValue sep = IValue.from(4);
+ IValue ret = block.forward(tokens, cls, sep);
+ long[] actual = ret.toLongArray();
+ Assert.assertEquals(actual, new long[] {0, 1, 2, 3, 4});
+
+ tokens.close();
+ cls.close();
+ sep.close();
+ ret.close();
+ }
+ }
+}
diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java
index cda9b86d8dd..ab2bcf16b82 100644
--- a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java
+++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java
@@ -13,30 +13,43 @@
package ai.djl.pytorch.integration;
import ai.djl.Model;
+import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtModel;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import java.io.IOException;
-import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.testng.Assert;
import org.testng.annotations.Test;
public class PtModelTest {
@Test
- public void testLoadFromStream() throws IOException, TranslateException {
- URL url =
- new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
- try (PtModel model = (PtModel) Model.newInstance("test model")) {
- model.load(url.openStream());
- try (Predictor predictor = model.newPredictor(new NoopTranslator())) {
- NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224));
- NDArray result = predictor.predict(new NDList(array)).singletonOrThrow();
- Assert.assertEquals(result.getShape(), new Shape(1, 1000));
+ public void testLoadFromStream() throws IOException, TranslateException, ModelException {
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18")
+ .optProgress(new ProgressBar())
+ .build();
+ try (ZooModel zooModel = criteria.loadModel()) {
+ Path modelFile = zooModel.getModelPath().resolve("traced_resnet18.pt");
+ try (PtModel model = (PtModel) Model.newInstance("test model")) {
+ model.load(Files.newInputStream(modelFile));
+ try (Predictor predictor =
+ model.newPredictor(new NoopTranslator())) {
+ NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224));
+ NDArray result = predictor.predict(new NDList(array)).singletonOrThrow();
+ Assert.assertEquals(result.getShape(), new Shape(1, 1000));
+ }
}
}
}
diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java
index d84565025a9..a57f5cafb90 100644
--- a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java
+++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java
@@ -15,10 +15,10 @@
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
-import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
+import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.JniUtils;
@@ -30,7 +30,7 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
-import java.net.URL;
+import java.nio.file.Files;
import java.nio.file.Path;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -51,7 +51,7 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti
Path modelFile;
try (ZooModel model = criteria.loadModel();
Predictor predictor = model.newPredictor()) {
- NDArray array = manager.ones(new Shape(2, 2));
+ PtNDArray array = (PtNDArray) manager.ones(new Shape(2, 2));
array.setName("input1.input");
NDList output = predictor.predict(new NDList(array));
Assert.assertEquals(output.singletonOrThrow(), array);
@@ -73,11 +73,17 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti
}
@Test
- public void testInputOutput() throws IOException {
- URL url =
- new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
- try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) {
- try (InputStream is = url.openStream()) {
+ public void testInputOutput() throws IOException, ModelException {
+ Criteria criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18")
+ .optProgress(new ProgressBar())
+ .build();
+ try (ZooModel model = criteria.loadModel()) {
+ PtNDManager manager = (PtNDManager) model.getNDManager();
+ Path modelFile = model.getModelPath().resolve("traced_resnet18.pt");
+ try (InputStream is = Files.newInputStream(modelFile)) {
PtSymbolBlock block = JniUtils.loadModule(manager, is, manager.getDevice(), false);
ByteArrayOutputStream os = new ByteArrayOutputStream();
JniUtils.writeModule(block, os, true);
diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
index 66f2e171ba4..628c862b2ef 100644
--- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
+++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc
@@ -183,11 +183,6 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward(
JITCallGuard guard;
return module_ptr->forward(inputs);
}();
- // release resource
- // each IValue is created by new, free the memory after the inference
- for (auto i = 0; i < len; ++i) {
- delete reinterpret_cast(jptrs[i]);
- }
env->ReleaseLongArrayElements(jivalue_ptrs, jptrs, djl::utils::jni::RELEASE_MODE);
const auto* result_ptr = new torch::IValue(output);
return reinterpret_cast(result_ptr);
diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc
index fea7582bc95..a20ac6f5369 100644
--- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc
+++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc
@@ -26,35 +26,120 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromTensor(
API_END_RETURN()
}
-JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromList(
- JNIEnv* env, jobject jthis, jlongArray jtensor_ptrs) {
- jsize len = env->GetArrayLength(jtensor_ptrs);
- jlong* jptrs = env->GetLongArrayElements(jtensor_ptrs, JNI_FALSE);
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromBool(
+ JNIEnv* env, jobject jthis, jboolean jvalue) {
+ API_BEGIN()
+ const auto* ivalue_ptr = new torch::IValue((bool) jvalue);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromLong(
+ JNIEnv* env, jobject jthis, jlong jvalue) {
+ API_BEGIN()
+ const auto* ivalue_ptr = new torch::IValue((int64_t) jvalue);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDouble(
+ JNIEnv* env, jobject jthis, jdouble jvalue) {
+ API_BEGIN()
+ const auto* ivalue_ptr = new torch::IValue(jvalue);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromString(
+ JNIEnv* env, jobject jthis, jstring jvalue) {
+ API_BEGIN()
+ const std::string value = djl::utils::jni::GetStringFromJString(env, jvalue);
+ const auto* ivalue_ptr = new torch::IValue(value);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromBoolList(
+ JNIEnv* env, jobject jthis, jbooleanArray jvalues) {
+ API_BEGIN()
+ jsize len = env->GetArrayLength(jvalues);
+ jboolean* jptrs = env->GetBooleanArrayElements(jvalues, JNI_FALSE);
+ torch::List list;
+ list.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ list.emplace_back(jptrs[i]);
+ }
+ env->ReleaseBooleanArrayElements(jvalues, jptrs, JNI_ABORT);
+ const auto* ivalue_ptr = new torch::IValue(list);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromLongList(
+ JNIEnv* env, jobject jthis, jlongArray jvalues) {
+ API_BEGIN()
+ jsize len = env->GetArrayLength(jvalues);
+ jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE);
+ torch::List list;
+ list.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ list.emplace_back(jptrs[i]);
+ }
+ env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT);
+ const auto* ivalue_ptr = new torch::IValue(list);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDoubleList(
+ JNIEnv* env, jobject jthis, jdoubleArray jvalues) {
+ API_BEGIN()
+ jsize len = env->GetArrayLength(jvalues);
+ jdouble* jptrs = env->GetDoubleArrayElements(jvalues, JNI_FALSE);
+ torch::List list;
+ list.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ list.emplace_back(jptrs[i]);
+ }
+ env->ReleaseDoubleArrayElements(jvalues, jptrs, JNI_ABORT);
+ const auto* ivalue_ptr = new torch::IValue(list);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromTensorList(
+ JNIEnv* env, jobject jthis, jlongArray jvalues) {
+ API_BEGIN()
+ jsize len = env->GetArrayLength(jvalues);
+ jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE);
torch::List list;
list.reserve(len);
for (size_t i = 0; i < len; ++i) {
list.emplace_back(*reinterpret_cast(jptrs[i]));
}
- env->ReleaseLongArrayElements(jtensor_ptrs, jptrs, JNI_ABORT);
- auto* result_ptr = new torch::IValue(list);
- return reinterpret_cast(result_ptr);
+ env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT);
+ const auto* ivalue_ptr = new torch::IValue(list);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
}
-JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDict(
- JNIEnv* env, jobject jthis, jlongArray jtensor_ptrs, jobjectArray jnames) {
- auto len = static_cast(env->GetArrayLength(jtensor_ptrs));
- jlong* jptrs = env->GetLongArrayElements(jtensor_ptrs, JNI_FALSE);
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromStringMap(
+ JNIEnv* env, jobject jthis, jobjectArray jkeys, jlongArray jvalues) {
+ API_BEGIN()
+ auto len = static_cast(env->GetArrayLength(jvalues));
+ jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE);
torch::Dict dict;
dict.reserve(len);
for (size_t i = 0; i < len; ++i) {
- auto jname = (jstring) env->GetObjectArrayElement(jnames, i);
+ auto jname = (jstring) env->GetObjectArrayElement(jkeys, i);
std::string name = djl::utils::jni::GetStringFromJString(env, jname);
dict.insert(name, *reinterpret_cast(jptrs[i]));
}
- env->ReleaseLongArrayElements(jtensor_ptrs, jptrs, JNI_ABORT);
- env->DeleteLocalRef(jnames);
- auto* result_ptr = new torch::IValue(dict);
- return reinterpret_cast(result_ptr);
+ env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT);
+ env->DeleteLocalRef(jkeys);
+ const auto* ivalue_ptr = new torch::IValue(dict);
+ return reinterpret_cast(ivalue_ptr);
+ API_END_RETURN()
}
JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor(
@@ -65,12 +150,85 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor(
API_END_RETURN()
}
-JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToListFromTuple(
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToBool(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
auto* ivalue_ptr = reinterpret_cast(jhandle);
- std::vector ivalue_vec = ivalue_ptr->toTuple()->elements();
- return djl::utils::jni::GetPtrArrayFromContainer, torch::IValue>(env, ivalue_vec);
+ return ivalue_ptr->toBool();
+ API_END_RETURN()
+}
+
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLong(JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ return ivalue_ptr->toInt();
+ API_END_RETURN()
+}
+
+JNIEXPORT jdouble JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToDouble(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ return ivalue_ptr->toDouble();
+ API_END_RETURN()
+}
+
+JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ return env->NewStringUTF(ivalue_ptr->toString()->string().c_str());
+ API_END_RETURN()
+}
+
+JNIEXPORT jbooleanArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToBoolList(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ torch::List list = ivalue_ptr->toBoolList();
+ size_t len = list.size();
+ jbooleanArray jarray = env->NewBooleanArray(len);
+ std::vector jptrs;
+ jptrs.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ jptrs[i] = list[i];
+ }
+ env->SetBooleanArrayRegion(jarray, 0, len, jptrs.data());
+ return jarray;
+ API_END_RETURN()
+}
+
+JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLongList(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ torch::List list = ivalue_ptr->toIntList();
+ size_t len = list.size();
+ jlongArray jarray = env->NewLongArray(len);
+ std::vector jptrs;
+ jptrs.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ jptrs[i] = list[i];
+ }
+ env->SetLongArrayRegion(jarray, 0, len, jptrs.data());
+ return jarray;
+ API_END_RETURN()
+}
+
+JNIEXPORT jdoubleArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToDoubleList(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ torch::List list = ivalue_ptr->toDoubleList();
+ size_t len = list.size();
+ jdoubleArray jarray = env->NewDoubleArray(len);
+ std::vector jptrs;
+ jptrs.reserve(len);
+ for (size_t i = 0; i < len; ++i) {
+ jptrs[i] = list[i];
+ }
+ env->SetDoubleArrayRegion(jarray, 0, len, jptrs.data());
+ return jarray;
API_END_RETURN()
}
@@ -83,7 +241,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTens
API_END_RETURN()
}
-JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList(
+JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToIValueList(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
auto* ivalue_ptr = reinterpret_cast(jhandle);
@@ -92,6 +250,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList
API_END_RETURN()
}
+JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToIValueTuple(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ auto* ivalue_ptr = reinterpret_cast(jhandle);
+ std::vector ivalue_vec = ivalue_ptr->toTuple()->elements();
+ return djl::utils::jni::GetPtrArrayFromContainer, torch::IValue>(env, ivalue_vec);
+ API_END_RETURN()
+}
+
JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
@@ -113,11 +280,31 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap(
API_END_RETURN()
}
-JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString(
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
- auto* ivalue_ptr = reinterpret_cast(jhandle);
- return env->NewStringUTF(ivalue_ptr->toString()->string().c_str());
+ return reinterpret_cast(jhandle)->isTensor();
+ API_END_RETURN()
+}
+
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsBool(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ return reinterpret_cast(jhandle)->isBool();
+ API_END_RETURN()
+}
+
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsLong(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ return reinterpret_cast(jhandle)->isInt();
+ API_END_RETURN()
+}
+
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsDouble(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ return reinterpret_cast(jhandle)->isDouble();
API_END_RETURN()
}
@@ -128,10 +315,24 @@ JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsString
API_END_RETURN()
}
-JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor(
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsBoolList(
JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
- return reinterpret_cast(jhandle)->isTensor();
+ return reinterpret_cast(jhandle)->isBoolList();
+ API_END_RETURN()
+}
+
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsLongList(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ return reinterpret_cast(jhandle)->isIntList();
+ API_END_RETURN()
+}
+
+JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsDoubleList(
+ JNIEnv* env, jobject jthis, jlong jhandle) {
+ API_BEGIN()
+ return reinterpret_cast(jhandle)->isDoubleList();
API_END_RETURN()
}