diff --git a/.gitignore b/.gitignore index ad84f2ea923..af083e92af7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ libs *.dylib *.dll *.class +*.log # Eclipse .settings diff --git a/pytorch/pytorch-engine/build.gradle b/pytorch/pytorch-engine/build.gradle index 216b9fada05..dc95a03b3eb 100644 --- a/pytorch/pytorch-engine/build.gradle +++ b/pytorch/pytorch-engine/build.gradle @@ -27,11 +27,11 @@ processResources { ] def classesDir = "${project.buildDir}/jnilib" files.each { entry -> - project.logger.lifecycle("Downloading ${url}/${entry.key}") def file = new File("${classesDir}/${entry.value}") if (file.exists()) { project.logger.lifecycle("prebuilt or cached file found for ${entry.value}") } else { + project.logger.lifecycle("Downloading ${url}/${entry.key}") file.getParentFile().mkdirs() new URL("${url}/${entry.key}").withInputStream { i -> file.withOutputStream { it << i } } } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 7215aba5a48..0d671ff0331 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.SymbolBlock; +import ai.djl.pytorch.jni.IValue; import ai.djl.pytorch.jni.IValueUtils; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.ParameterStore; @@ -88,6 +89,16 @@ public void close() { } } + /** + * Runs the forward of this PyTorch module. + * + * @param inputs the input {@link IValue} + * @return the result {@link IValue} + */ + public IValue forward(IValue... inputs) { + return IValueUtils.forward(this, inputs); + } + /** {@inheritDoc} */ @Override protected NDList forwardInternal( diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java new file mode 100644 index 00000000000..71d31e3420b --- /dev/null +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -0,0 +1,402 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.jni; + +import ai.djl.ndarray.NDList; +import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDManager; +import ai.djl.util.NativeResource; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A class represent a PyTorch {@code IValue} data. + * + *

DJL doesn't support creating nested IValue. + */ +public class IValue extends NativeResource { + + IValue(long handle) { + super(handle); + } + + /** + * Returns if the IValue is a {@code Tensor} type. + * + * @return if the IValue is a Tensor type + */ + public boolean isTensor() { + return PyTorchLibrary.LIB.iValueIsTensor(getHandle()); + } + + /** + * Returns if the IValue is a {@code boolean} type. + * + * @return if the IValue is a boolean type + */ + public boolean isBoolean() { + return PyTorchLibrary.LIB.iValueIsBool(getHandle()); + } + + /** + * Returns if the IValue is a {@code long} type. + * + * @return if the IValue is a long type + */ + public boolean isLong() { + return PyTorchLibrary.LIB.iValueIsLong(getHandle()); + } + + /** + * Returns if the IValue is a {@code double} type. + * + * @return if the IValue is a double type + */ + public boolean isDouble() { + return PyTorchLibrary.LIB.iValueIsDouble(getHandle()); + } + + /** + * Returns if the IValue is a {@code String} type. + * + * @return if the IValue is a String type + */ + public boolean isString() { + return PyTorchLibrary.LIB.iValueIsString(getHandle()); + } + + /** + * Returns if the IValue is a {@code boolean[]} type. + * + * @return if the IValue is a boolean[] type + */ + public boolean isBooleanList() { + return PyTorchLibrary.LIB.iValueIsBoolList(getHandle()); + } + + /** + * Returns if the IValue is a {@code long[]} type. + * + * @return if the IValue is a long[] type + */ + public boolean isLongList() { + return PyTorchLibrary.LIB.iValueIsLongList(getHandle()); + } + + /** + * Returns if the IValue is a {@code double[]} type. + * + * @return if the IValue is a double[] type + */ + public boolean isDoubleList() { + return PyTorchLibrary.LIB.iValueIsDoubleList(getHandle()); + } + + /** + * Returns if the IValue is a {@code IValue[]} type. + * + *

The elements in the array must have the same type. + * + * @return if the IValue is a IValue[] type + */ + public boolean isTensorList() { + return PyTorchLibrary.LIB.iValueIsTensorList(getHandle()); + } + + /** + * Returns if the IValue is a {@code IValue[]} type. + * + *

The elements in the array must have the same type. + * + * @return if the IValue is a IValue[] type + */ + public boolean isList() { + return PyTorchLibrary.LIB.iValueIsList(getHandle()); + } + + /** + * Returns if the IValue is a {@code Map<String, V>} type. + * + * @return if the IValue is a Map<String, V> type + */ + public boolean isMap() { + return PyTorchLibrary.LIB.iValueIsMap(getHandle()); + } + + /** + * Creates a new {@code IValue} of type {@code PtNDArray}. + * + * @param value the NDArray value + * @return a new {@code IValue} of type {@code PtNDArray} + */ + public static IValue from(PtNDArray value) { + return new IValue(PyTorchLibrary.LIB.iValueFromTensor(value.getHandle())); + } + + /** + * Creates a new {@code IValue} of type {@code boolean}. + * + * @param value the boolean value + * @return a new {@code IValue} of type {@code boolean} + */ + public static IValue from(boolean value) { + return new IValue(PyTorchLibrary.LIB.iValueFromBool(value)); + } + + /** + * Creates a new {@code IValue} of type {@code long}. + * + * @param value the long value + * @return a new {@code IValue} of type {@code long} + */ + public static IValue from(long value) { + return new IValue(PyTorchLibrary.LIB.iValueFromLong(value)); + } + + /** + * Creates a new {@code IValue} of type {@code double}. + * + * @param value the double value + * @return a new {@code IValue} of type {@code double} + */ + public static IValue from(double value) { + return new IValue(PyTorchLibrary.LIB.iValueFromDouble(value)); + } + + /** + * Creates a new {@code IValue} of type {@code String}. + * + * @param value the String value + * @return a new {@code IValue} of type {@code String} + */ + public static IValue from(String value) { + return new IValue(PyTorchLibrary.LIB.iValueFromString(value)); + } + + /** + * Creates a new {@code IValue} of type {@code boolean[]}. + * + * @param list the boolean[] value + * @return a new {@code IValue} of type {@code boolean[]} + */ + public static IValue listFrom(boolean... list) { + return new IValue(PyTorchLibrary.LIB.iValueFromBoolList(list)); + } + + /** + * Creates a new {@code IValue} of type {@code long[]}. + * + * @param list the long[] value + * @return a new {@code IValue} of type {@code long[]} + */ + public static IValue listFrom(long... list) { + return new IValue(PyTorchLibrary.LIB.iValueFromLongList(list)); + } + + /** + * Creates a new {@code IValue} of type {@code double[]}. + * + * @param list the double[] value + * @return a new {@code IValue} of type {@code double[]} + */ + public static IValue listFrom(double... list) { + return new IValue(PyTorchLibrary.LIB.iValueFromDoubleList(list)); + } + + /** + * Creates a new {@code IValue} of type {@code NDArray[]}. + * + * @param list the NDArray[] value + * @return a new {@code IValue} of type {@code NDArray[]} + */ + public static IValue listFrom(PtNDArray... list) { + long[] tensors = Arrays.stream(list).mapToLong(PtNDArray::getHandle).toArray(); + return new IValue(PyTorchLibrary.LIB.iValueFromTensorList(tensors)); + } + + /** + * Creates a new {@code IValue} of type {@code Map[String, PtNDArray]}. + * + * @param map the Map[String, IValue] value + * @return a new {@code IValue} of type {@code Map[String, PtNDArray]} + */ + public static IValue stringMapFrom(Map map) { + String[] keys = new String[map.size()]; + long[] handles = new long[map.size()]; + int i = 0; + for (Map.Entry entry : map.entrySet()) { + keys[i] = entry.getKey(); + handles[i] = entry.getValue().getHandle(); + ++i; + } + return new IValue(PyTorchLibrary.LIB.iValueFromStringMap(keys, handles)); + } + + /** + * Returns the {@code boolean} value of this IValue. + * + * @return the boolean value of this IValue + */ + public boolean toBoolean() { + return PyTorchLibrary.LIB.iValueToBool(getHandle()); + } + + /** + * Returns the {@code long} value of this IValue. + * + * @return the long value of this IValue + */ + public long toLong() { + return PyTorchLibrary.LIB.iValueToLong(getHandle()); + } + + /** + * Returns the {@code double} value of this IValue. + * + * @return the double value of this IValue + */ + public double toDouble() { + return PyTorchLibrary.LIB.iValueToDouble(getHandle()); + } + + /** + * Returns the {@code String} value of this IValue. + * + * @return the String value of this IValue + */ + public String toStringValue() { + return PyTorchLibrary.LIB.iValueToString(getHandle()); + } + + /** + * Returns the {@code boolean[]} value of this IValue. + * + * @return the boolean[] value of this IValue + */ + public boolean[] toBooleanArray() { + return PyTorchLibrary.LIB.iValueToBoolList(getHandle()); + } + + /** + * Returns the {@code long[]} value of this IValue. + * + * @return the long[] value of this IValue + */ + public long[] toLongArray() { + return PyTorchLibrary.LIB.iValueToLongList(getHandle()); + } + + /** + * Returns the {@code double[]} value of this IValue. + * + * @return the double[] value of this IValue + */ + public double[] toDoubleArray() { + return PyTorchLibrary.LIB.iValueToDoubleList(getHandle()); + } + + /** + * Returns the {@code NDArray} value of this IValue. + * + * @param manager the {@code NDManager} to create the NDArray + * @return the NDArray value of this IValue + */ + public PtNDArray toTensor(PtNDManager manager) { + return new PtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); + } + + /** + * Returns the {@code NDArray[]} value of this IValue. + * + * @param manager the NDManager to create NDArray + * @return the NDArray[] value of this IValue + */ + public PtNDArray[] toTensorArray(PtNDManager manager) { + long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); + PtNDArray[] ret = new PtNDArray[handles.length]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = new PtNDArray(manager, handles[i]); + } + return ret; + } + + /** + * Returns the {@code IValue[]} value of this IValue list. + * + * @return the IValue[] value of this IValue list + */ + public IValue[] toIValueArray() { + long[] handles = PyTorchLibrary.LIB.iValueToIValueList(getHandle()); + IValue[] ret = new IValue[handles.length]; + for (int i = 0; i < ret.length; ++i) { + ret[i] = new IValue(handles[i]); + } + return ret; + } + + /** + * Returns the {@code Map<String, IValue>} value of this IValue. + * + * @return the Map<String, IValue> value of this IValue + */ + public Map toIValueMap() { + long[] handles = PyTorchLibrary.LIB.iValueToMap(getHandle()); + Map map = new ConcurrentHashMap<>(); + for (int i = 0; i < handles.length; i += 2) { + IValue key = new IValue(handles[i]); + map.put(key.toStringValue(), new IValue(handles[i + 1])); + key.close(); + } + return map; + } + + /** + * Returns the {@code NDList} value of this IValue. + * + * @param manager the NDManager to create NDArray + * @return the {@code NDList} value of this IValue + */ + public NDList toNDList(PtNDManager manager) { + if (isTensor()) { + return new NDList(toTensor(manager)); + } else if (isTensorList()) { + return new NDList(toTensorArray(manager)); + } else if (isMap()) { + // Only allows one level type of map + NDList list = new NDList(); + Map map = toIValueMap(); + for (Map.Entry entry : map.entrySet()) { + IValue iv = entry.getValue(); + if (!iv.isTensor()) { + throw new UnsupportedOperationException("Only one level of map is supported."); + } + PtNDArray value = entry.getValue().toTensor(manager); + value.setName(entry.getKey()); + list.add(value); + iv.close(); + } + return list; + } + throw new UnsupportedOperationException("Unsupported IValue type."); + } + + /** {@inheritDoc} */ + @Override + public void close() { + Long pointer = handle.getAndSet(null); + if (pointer != null) { + PyTorchLibrary.LIB.torchDeleteIValue(pointer); + } + } +} diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java index e1898713b2e..41cb8f57cf7 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java @@ -20,13 +20,12 @@ import ai.djl.pytorch.engine.PtSymbolBlock; import ai.djl.util.Pair; import ai.djl.util.PairList; -import ai.djl.util.Preconditions; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Pattern; -import java.util.stream.Stream; /** IValueUtils is utility class to deal with IValue in PyTorch. */ public final class IValueUtils { @@ -34,278 +33,85 @@ public final class IValueUtils { private IValueUtils() {} /** - * Create IValue Pointer from NDArray. + * Runs the forward of PyTorch module. * - * @param arrayHandle the handle for PyTorch Tensor - * @return IValue Pointer - */ - public static long toIValuePointer(long arrayHandle) { - return PyTorchLibrary.LIB.iValueFromTensor(arrayHandle); - } - - /** - * Create List IValue Pointer from pointer array. - * - * @param pointers pointer array - * @return IValue Pointer - */ - public static long iValueFromList(long[] pointers) { - return PyTorchLibrary.LIB.iValueFromList(pointers); - } - - /** - * Create Dict IValue Pointer from pointer with its key name. - * - * @param pointers pointer array - * @param names the key value of the pointer - * @return IValue Pointer - */ - public static long iValueFromDict(long[] pointers, String[] names) { - return PyTorchLibrary.LIB.iValueFromDict(pointers, names); - } - - /** - * Check IValue is a container of {@link PtNDArray}. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isNDArray(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsTensor(iValueHandle); - } - - /** - * Check IValue is a container of {@link NDList}. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isNDList(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsTensorList(iValueHandle); - } - - /** - * Check IValue is a container of IValue List. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isList(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsList(iValueHandle); - } - - /** - * Check IValue is a container of IValue Tuple. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isTuple(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsTuple(iValueHandle); - } - - /** - * Check IValue is a container of IValue Map. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isMap(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsMap(iValueHandle); - } - - /** - * Check IValue is a container of String. - * - * @param iValueHandle IValue pointer - * @return result - */ - public static boolean isString(long iValueHandle) { - return PyTorchLibrary.LIB.iValueIsString(iValueHandle); - } - - /** - * Extract IValue with a {@link PtNDArray} value. - * - * @param iValueHandle IValue pointer - * @param manager {@link PtNDManager} that creates {@link PtNDArray} - * @return {@link ai.djl.ndarray.NDArray} - */ - public static PtNDArray toNDArray(long iValueHandle, PtNDManager manager) { - long ndHandle = PyTorchLibrary.LIB.iValueToTensor(iValueHandle); - return new PtNDArray(manager, ndHandle); - } - - /** - * Extract IValue to {@link NDList}. - * - * @param iValueHandle IValue pointer - * @param manager {@link PtNDManager} that creates {@link PtNDArray} - * @return {@link NDList} - */ - public static NDList toNDList(long iValueHandle, PtNDManager manager) { - long[] ndHandles = PyTorchLibrary.LIB.iValueToTensorList(iValueHandle); - NDList list = new NDList(); - for (long handle : ndHandles) { - list.add(new PtNDArray(manager, handle)); - } - return list; - } - - /** - * Extract IValue to String. - * - * @param iValueHandle IValue pointer - * @return String - */ - public static String toString(long iValueHandle) { - return PyTorchLibrary.LIB.iValueToString(iValueHandle); - } - - /** - * Extract IValue to an IValue Array. - * - * @param iValueHandle IValue pointer - * @return IValue array - */ - public static long[] toIValueArray(long iValueHandle) { - if (isTuple(iValueHandle)) { - return PyTorchLibrary.LIB.iValueToListFromTuple(iValueHandle); - } - return PyTorchLibrary.LIB.iValueToList(iValueHandle); - } - - /** - * Extract IValue to a Map. - * - * @param iValueHandle IValue pointer - * @return IValue Map + * @param block the block that contains PyTorch module + * @param inputs the input {@link NDList} + * @param isTrain if running on training mode + * @return the result {@link NDList} */ - public static Map toIValueMap(long iValueHandle) { - long[] iValueHandles = PyTorchLibrary.LIB.iValueToMap(iValueHandle); - Map map = new ConcurrentHashMap<>(); - for (int i = 0; i < iValueHandles.length; i += 2) { - map.put(iValueHandles[i], iValueHandles[i + 1]); - } - return map; - } - - private static NDList forwardHelper(long iValueHandle, PtNDManager manager) { - NDList list = new NDList(); - if (isNDArray(iValueHandle)) { - list.add(toNDArray(iValueHandle, manager)); - } else if (isNDList(iValueHandle)) { - list.addAll(toNDList(iValueHandle, manager)); - } else if (isList(iValueHandle) || isTuple(iValueHandle)) { - for (long handle : toIValueArray(iValueHandle)) { - list.addAll(forwardHelper(handle, manager)); - } - } else if (isMap(iValueHandle)) { - // Only allows type of map - Map map = toIValueMap(iValueHandle); - for (Map.Entry entry : map.entrySet()) { - String name = toString(entry.getKey()); - // free the IValue handle - PyTorchLibrary.LIB.torchDeleteIValue(entry.getKey()); - PtNDArray value = toNDArray(entry.getValue(), manager); - // free the IValue handle - PyTorchLibrary.LIB.torchDeleteIValue(entry.getValue()); - value.setName(name); - list.add(value); - } - } else { - // free the IValue handle - PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle); - throw new UnsupportedOperationException("Unsupported IValue type"); + public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) { + IValue[] iValues = getInputs(inputs); + long[] iValueHandles = Arrays.stream(iValues).mapToLong(IValue::getHandle).toArray(); + long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueHandles, isTrain); + PtNDManager manager = (PtNDManager) inputs.get(0).getManager(); + Arrays.stream(iValues).forEach(IValue::close); + try (IValue iValue = new IValue(result)) { + return iValue.toNDList(manager); } - // free the IValue handle - PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle); - return list; } /** - * Run the forward of PyTorch module. + * Runs the forward of PyTorch module. * * @param block the block that contains PyTorch module - * @param inputs input {@link NDList} - * @param isTrain is running on training mode - * @return result {@link NDList} + * @param inputs the input {@link IValue} + * @return the result {@link IValue} */ - public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) { - long[] arrayHandles = - inputs.stream().mapToLong(input -> ((PtNDArray) input).getHandle()).toArray(); - String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new); - long[] iValueInputs = getInputs(arrayHandles, names); - long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueInputs, isTrain); - PtNDManager manager = (PtNDManager) inputs.get(0).getManager(); - return forwardHelper(result, manager); + public static IValue forward(PtSymbolBlock block, IValue... inputs) { + long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray(); + return new IValue(PyTorchLibrary.LIB.moduleForward(block.getHandle(), handles, false)); } - private static boolean isNameList(String name) { - return Pattern.matches("\\w+\\[]", name); + private static int addToMap( + Map map, String key, List> list) { + return map.computeIfAbsent( + key, + k -> { + list.add(new PairList<>()); + return list.size() - 1; + }); } - private static boolean isNameDict(String name) { - return name.contains("."); - } - - private static long[] getInputs(long[] arrays, String[] names) { - List> outputs = new ArrayList<>(); + private static IValue[] getInputs(NDList ndList) { + List> outputs = new ArrayList<>(); Map indexMap = new ConcurrentHashMap<>(); - for (int i = 0; i < arrays.length; i++) { - String name = names[i]; - if (name == null || (!isNameList(name) && !isNameDict(name))) { - PairList list = new PairList<>(); - list.add(new Pair<>(null, toIValuePointer(arrays[i]))); - outputs.add(list); - continue; - } - String mapKey = null; - boolean isDict = isNameDict(names[i]); - if (isDict) { - String[] strings = names[i].split("\\."); - Preconditions.checkArgument( - strings.length == 2, - "Please make sure you only include one '.' in the name. Nested Map is not supported!"); - name = strings[0]; - mapKey = strings[1]; - } - if (!indexMap.containsKey(name)) { - outputs.add(new PairList<>()); - indexMap.put(name, outputs.size() - 1); - } - if (isDict) { - outputs.get(indexMap.get(name)).add(new Pair<>(mapKey, arrays[i])); + for (NDArray array : ndList) { + String name = array.getName(); + if (name != null && name.contains(".")) { + String[] strings = name.split("\\.", 2); + int index = addToMap(indexMap, strings[0], outputs); + PairList pl = outputs.get(index); + pl.add(strings[1], (PtNDArray) array); + } else if (name != null && Pattern.matches("\\w+\\[]", name)) { + int index = addToMap(indexMap, name, outputs); + PairList pl = outputs.get(index); + pl.add("[]", (PtNDArray) array); } else { - outputs.get(indexMap.get(name)).add(new Pair<>(name, arrays[i])); + PairList pl = new PairList<>(); + pl.add(null, (PtNDArray) array); + outputs.add(pl); } } - long[] pointers = new long[outputs.size()]; + IValue[] ret = new IValue[outputs.size()]; for (int i = 0; i < outputs.size(); ++i) { - // not List, Dict input - if (outputs.get(i).size() == 1 && outputs.get(i).get(0).getKey() == null) { - pointers[i] = outputs.get(i).get(0).getValue(); - } else if (isNameList(outputs.get(i).get(0).getKey())) { - pointers[i] = - iValueFromList( - toPrimitiveLongArray(outputs.get(i).valueArray(new Long[0]))); + PairList pl = outputs.get(i); + String key = pl.get(0).getKey(); + if (key == null) { + // not List, Dict input + ret[i] = IValue.from(pl.get(0).getValue()); + } else if ("[]".equals(key)) { + // list + PtNDArray[] arrays = pl.values().toArray(new PtNDArray[0]); + ret[i] = IValue.listFrom(arrays); } else { - PairList dict = outputs.get(i); - pointers[i] = - iValueFromDict( - toPrimitiveLongArray(dict.valueArray(new Long[0])), - dict.keyArray(new String[0])); + Map map = new ConcurrentHashMap<>(); + for (Pair pair : pl) { + map.put(pair.getKey(), pair.getValue()); + } + ret[i] = IValue.stringMapFrom(map); } } - return pointers; - } - - private static long[] toPrimitiveLongArray(Long[] array) { - if (array == null) { - return null; - } else if (array.length == 0) { - return new long[0]; - } - return Stream.of(array).mapToLong(Long::longValue).toArray(); + return ret; } } diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index c078e7c7cef..dc2dcb19e4c 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -491,34 +491,74 @@ native long moduleLoad( native long iValueFromTensor(long tensorHandle); - native long iValueFromList(long[] tensorHandles); + native long iValueFromBool(boolean value); - native long iValueFromDict(long[] tensorHandles, String[] names); + native long iValueFromLong(long value); + + native long iValueFromDouble(double value); + + native long iValueFromString(String value); + + native long iValueFromBoolList(boolean... value); + + native long iValueFromLongList(long... value); + + native long iValueFromDoubleList(double... value); + + native long iValueFromTensorList(long[] tensorHandles); + + native long iValueFromTuple(long[] ivalueHandles); + + native long iValueFromStringMap(String[] keys, long[] tensorHandles); native long iValueToTensor(long iValueHandle); + native boolean iValueToBool(long iValueHandle); + + native long iValueToLong(long iValueHandle); + + native double iValueToDouble(long iValueHandle); + + native String iValueToString(long iValueHandle); + + native boolean[] iValueToBoolList(long iValueHandle); + + native long[] iValueToLongList(long iValueHandle); + + native double[] iValueToDoubleList(long iValueHandle); + native long[] iValueToTensorList(long iValueHandle); - native long[] iValueToList(long iValueHandle); + native long[] iValueToIValueList(long iValueHandle); - native long[] iValueToListFromTuple(long iValueHandle); + native long[] iValueToIValueTuple(long iValueHandle); native long[] iValueToMap(long iValueHandle); - native String iValueToString(long iValueHandle); + native boolean iValueIsTensor(long iValueHandle); + + native boolean iValueIsBool(long iValueHandle); + + native boolean iValueIsLong(long iValueHandle); + + native boolean iValueIsDouble(long iValueHandle); native boolean iValueIsString(long iValueHandle); - native boolean iValueIsTensor(long iValueHandle); + native boolean iValueIsBoolList(long iValueHandle); + + native boolean iValueIsLongList(long iValueHandle); + + native boolean iValueIsDoubleList(long iValueHandle); native boolean iValueIsTensorList(long iValueHandle); native boolean iValueIsList(long iValueHandle); - native boolean iValueIsMap(long iValueHandle); - native boolean iValueIsTuple(long iValueHandle); + native boolean iValueIsMap(long iValueHandle); + native void zeroGrad(long handle); native void adamUpdate( diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java new file mode 100644 index 00000000000..6c776f59678 --- /dev/null +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration; + +import ai.djl.ModelException; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.Shape; +import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDManager; +import ai.djl.pytorch.engine.PtSymbolBlock; +import ai.djl.pytorch.jni.IValue; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class IValueTest { + + @Test + public void testIValue() { + try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { + PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1)); + PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1)); + + try (IValue ivalue = IValue.from(array1)) { + Assert.assertTrue(ivalue.isTensor()); + NDArray ret = ivalue.toTensor(manager); + Assert.assertEquals(ret, array1); + NDList list = ivalue.toNDList(manager); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.head(), array1); + } + + try (IValue ivalue = IValue.from(true)) { + Assert.assertTrue(ivalue.isBoolean()); + Assert.assertTrue(ivalue.toBoolean()); + } + + try (IValue ivalue = IValue.from(1)) { + Assert.assertTrue(ivalue.isLong()); + Assert.assertEquals(ivalue.toLong(), 1); + } + + try (IValue ivalue = IValue.from(1d)) { + Assert.assertTrue(ivalue.isDouble()); + Assert.assertEquals(ivalue.toDouble(), 1d); + } + + try (IValue ivalue = IValue.from("test")) { + Assert.assertTrue(ivalue.isString()); + Assert.assertEquals(ivalue.toStringValue(), "test"); + } + + try (IValue ivalue = IValue.listFrom(true, false)) { + Assert.assertTrue(ivalue.isList()); + Assert.assertTrue(ivalue.isBooleanList()); + Assert.assertEquals(ivalue.toBooleanArray(), new boolean[] {true, false}); + } + + try (IValue ivalue = IValue.listFrom(1, 2)) { + Assert.assertTrue(ivalue.isLongList()); + Assert.assertEquals(ivalue.toLongArray(), new long[] {1, 2}); + } + + try (IValue ivalue = IValue.listFrom(1d, 2d)) { + Assert.assertTrue(ivalue.isDoubleList()); + Assert.assertEquals(ivalue.toDoubleArray(), new double[] {1d, 2d}); + } + + try (IValue ivalue = IValue.listFrom(array1, array2)) { + Assert.assertTrue(ivalue.isTensorList()); + NDArray[] ret = ivalue.toTensorArray(manager); + Assert.assertEquals(ret.length, 2); + NDList list = ivalue.toNDList(manager); + Assert.assertEquals(list.size(), 2); + Assert.assertEquals(list.head(), array1); + + IValue[] iValues = ivalue.toIValueArray(); + Assert.assertEquals(iValues.length, 2); + Assert.assertTrue(iValues[0].isTensor()); + Arrays.stream(iValues).forEach(IValue::close); + } + + Map map = new ConcurrentHashMap<>(); + map.put("data1", array1); + map.put("data2", array2); + try (IValue ivalue = IValue.stringMapFrom(map)) { + Assert.assertTrue(ivalue.isMap()); + Map ret = ivalue.toIValueMap(); + Assert.assertEquals(ret.size(), 2); + + NDList list = ivalue.toNDList(manager); + Assert.assertEquals(list.size(), 2); + Assert.assertEquals(list.get("data1"), array1); + } + } + } + + @Test + public void testIValueModel() throws IOException, ModelException { + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelUrls("https://resources.djl.ai/test-models/ivalue_jit.zip") + .optProgress(new ProgressBar()) + .build(); + + try (ZooModel model = criteria.loadModel()) { + PtSymbolBlock block = (PtSymbolBlock) model.getBlock(); + IValue tokens = IValue.listFrom(1, 2, 3); + IValue cls = IValue.from(0); + IValue sep = IValue.from(4); + IValue ret = block.forward(tokens, cls, sep); + long[] actual = ret.toLongArray(); + Assert.assertEquals(actual, new long[] {0, 1, 2, 3, 4}); + + tokens.close(); + cls.close(); + sep.close(); + ret.close(); + } + } +} diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java index cda9b86d8dd..ab2bcf16b82 100644 --- a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java @@ -13,30 +13,43 @@ package ai.djl.pytorch.integration; import ai.djl.Model; +import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtModel; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; import ai.djl.translate.NoopTranslator; import ai.djl.translate.TranslateException; import java.io.IOException; -import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; import org.testng.Assert; import org.testng.annotations.Test; public class PtModelTest { @Test - public void testLoadFromStream() throws IOException, TranslateException { - URL url = - new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt"); - try (PtModel model = (PtModel) Model.newInstance("test model")) { - model.load(url.openStream()); - try (Predictor predictor = model.newPredictor(new NoopTranslator())) { - NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224)); - NDArray result = predictor.predict(new NDList(array)).singletonOrThrow(); - Assert.assertEquals(result.getShape(), new Shape(1, 1000)); + public void testLoadFromStream() throws IOException, TranslateException, ModelException { + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18") + .optProgress(new ProgressBar()) + .build(); + try (ZooModel zooModel = criteria.loadModel()) { + Path modelFile = zooModel.getModelPath().resolve("traced_resnet18.pt"); + try (PtModel model = (PtModel) Model.newInstance("test model")) { + model.load(Files.newInputStream(modelFile)); + try (Predictor predictor = + model.newPredictor(new NoopTranslator())) { + NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224)); + NDArray result = predictor.predict(new NDList(array)).singletonOrThrow(); + Assert.assertEquals(result.getShape(), new Shape(1, 1000)); + } } } } diff --git a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java index d84565025a9..a57f5cafb90 100644 --- a/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java +++ b/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java @@ -15,10 +15,10 @@ import ai.djl.ModelException; import ai.djl.inference.Predictor; -import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.pytorch.engine.PtNDArray; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; import ai.djl.pytorch.jni.JniUtils; @@ -30,7 +30,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.net.URL; +import java.nio.file.Files; import java.nio.file.Path; import org.testng.Assert; import org.testng.annotations.Test; @@ -51,7 +51,7 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti Path modelFile; try (ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor()) { - NDArray array = manager.ones(new Shape(2, 2)); + PtNDArray array = (PtNDArray) manager.ones(new Shape(2, 2)); array.setName("input1.input"); NDList output = predictor.predict(new NDList(array)); Assert.assertEquals(output.singletonOrThrow(), array); @@ -73,11 +73,17 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti } @Test - public void testInputOutput() throws IOException { - URL url = - new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt"); - try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { - try (InputStream is = url.openStream()) { + public void testInputOutput() throws IOException, ModelException { + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18") + .optProgress(new ProgressBar()) + .build(); + try (ZooModel model = criteria.loadModel()) { + PtNDManager manager = (PtNDManager) model.getNDManager(); + Path modelFile = model.getModelPath().resolve("traced_resnet18.pt"); + try (InputStream is = Files.newInputStream(modelFile)) { PtSymbolBlock block = JniUtils.loadModule(manager, is, manager.getDevice(), false); ByteArrayOutputStream os = new ByteArrayOutputStream(); JniUtils.writeModule(block, os, true); diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc index 66f2e171ba4..628c862b2ef 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_inference.cc @@ -183,11 +183,6 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleForward( JITCallGuard guard; return module_ptr->forward(inputs); }(); - // release resource - // each IValue is created by new, free the memory after the inference - for (auto i = 0; i < len; ++i) { - delete reinterpret_cast(jptrs[i]); - } env->ReleaseLongArrayElements(jivalue_ptrs, jptrs, djl::utils::jni::RELEASE_MODE); const auto* result_ptr = new torch::IValue(output); return reinterpret_cast(result_ptr); diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc index fea7582bc95..a20ac6f5369 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_ivalue.cc @@ -26,35 +26,120 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromTensor( API_END_RETURN() } -JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromList( - JNIEnv* env, jobject jthis, jlongArray jtensor_ptrs) { - jsize len = env->GetArrayLength(jtensor_ptrs); - jlong* jptrs = env->GetLongArrayElements(jtensor_ptrs, JNI_FALSE); +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromBool( + JNIEnv* env, jobject jthis, jboolean jvalue) { + API_BEGIN() + const auto* ivalue_ptr = new torch::IValue((bool) jvalue); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromLong( + JNIEnv* env, jobject jthis, jlong jvalue) { + API_BEGIN() + const auto* ivalue_ptr = new torch::IValue((int64_t) jvalue); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDouble( + JNIEnv* env, jobject jthis, jdouble jvalue) { + API_BEGIN() + const auto* ivalue_ptr = new torch::IValue(jvalue); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromString( + JNIEnv* env, jobject jthis, jstring jvalue) { + API_BEGIN() + const std::string value = djl::utils::jni::GetStringFromJString(env, jvalue); + const auto* ivalue_ptr = new torch::IValue(value); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromBoolList( + JNIEnv* env, jobject jthis, jbooleanArray jvalues) { + API_BEGIN() + jsize len = env->GetArrayLength(jvalues); + jboolean* jptrs = env->GetBooleanArrayElements(jvalues, JNI_FALSE); + torch::List list; + list.reserve(len); + for (size_t i = 0; i < len; ++i) { + list.emplace_back(jptrs[i]); + } + env->ReleaseBooleanArrayElements(jvalues, jptrs, JNI_ABORT); + const auto* ivalue_ptr = new torch::IValue(list); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromLongList( + JNIEnv* env, jobject jthis, jlongArray jvalues) { + API_BEGIN() + jsize len = env->GetArrayLength(jvalues); + jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE); + torch::List list; + list.reserve(len); + for (size_t i = 0; i < len; ++i) { + list.emplace_back(jptrs[i]); + } + env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT); + const auto* ivalue_ptr = new torch::IValue(list); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDoubleList( + JNIEnv* env, jobject jthis, jdoubleArray jvalues) { + API_BEGIN() + jsize len = env->GetArrayLength(jvalues); + jdouble* jptrs = env->GetDoubleArrayElements(jvalues, JNI_FALSE); + torch::List list; + list.reserve(len); + for (size_t i = 0; i < len; ++i) { + list.emplace_back(jptrs[i]); + } + env->ReleaseDoubleArrayElements(jvalues, jptrs, JNI_ABORT); + const auto* ivalue_ptr = new torch::IValue(list); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromTensorList( + JNIEnv* env, jobject jthis, jlongArray jvalues) { + API_BEGIN() + jsize len = env->GetArrayLength(jvalues); + jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE); torch::List list; list.reserve(len); for (size_t i = 0; i < len; ++i) { list.emplace_back(*reinterpret_cast(jptrs[i])); } - env->ReleaseLongArrayElements(jtensor_ptrs, jptrs, JNI_ABORT); - auto* result_ptr = new torch::IValue(list); - return reinterpret_cast(result_ptr); + env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT); + const auto* ivalue_ptr = new torch::IValue(list); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() } -JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromDict( - JNIEnv* env, jobject jthis, jlongArray jtensor_ptrs, jobjectArray jnames) { - auto len = static_cast(env->GetArrayLength(jtensor_ptrs)); - jlong* jptrs = env->GetLongArrayElements(jtensor_ptrs, JNI_FALSE); +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueFromStringMap( + JNIEnv* env, jobject jthis, jobjectArray jkeys, jlongArray jvalues) { + API_BEGIN() + auto len = static_cast(env->GetArrayLength(jvalues)); + jlong* jptrs = env->GetLongArrayElements(jvalues, JNI_FALSE); torch::Dict dict; dict.reserve(len); for (size_t i = 0; i < len; ++i) { - auto jname = (jstring) env->GetObjectArrayElement(jnames, i); + auto jname = (jstring) env->GetObjectArrayElement(jkeys, i); std::string name = djl::utils::jni::GetStringFromJString(env, jname); dict.insert(name, *reinterpret_cast(jptrs[i])); } - env->ReleaseLongArrayElements(jtensor_ptrs, jptrs, JNI_ABORT); - env->DeleteLocalRef(jnames); - auto* result_ptr = new torch::IValue(dict); - return reinterpret_cast(result_ptr); + env->ReleaseLongArrayElements(jvalues, jptrs, JNI_ABORT); + env->DeleteLocalRef(jkeys); + const auto* ivalue_ptr = new torch::IValue(dict); + return reinterpret_cast(ivalue_ptr); + API_END_RETURN() } JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor( @@ -65,12 +150,85 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTensor( API_END_RETURN() } -JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToListFromTuple( +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToBool( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() auto* ivalue_ptr = reinterpret_cast(jhandle); - std::vector ivalue_vec = ivalue_ptr->toTuple()->elements(); - return djl::utils::jni::GetPtrArrayFromContainer, torch::IValue>(env, ivalue_vec); + return ivalue_ptr->toBool(); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLong(JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + return ivalue_ptr->toInt(); + API_END_RETURN() +} + +JNIEXPORT jdouble JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToDouble( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + return ivalue_ptr->toDouble(); + API_END_RETURN() +} + +JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + return env->NewStringUTF(ivalue_ptr->toString()->string().c_str()); + API_END_RETURN() +} + +JNIEXPORT jbooleanArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToBoolList( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + torch::List list = ivalue_ptr->toBoolList(); + size_t len = list.size(); + jbooleanArray jarray = env->NewBooleanArray(len); + std::vector jptrs; + jptrs.reserve(len); + for (size_t i = 0; i < len; ++i) { + jptrs[i] = list[i]; + } + env->SetBooleanArrayRegion(jarray, 0, len, jptrs.data()); + return jarray; + API_END_RETURN() +} + +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToLongList( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + torch::List list = ivalue_ptr->toIntList(); + size_t len = list.size(); + jlongArray jarray = env->NewLongArray(len); + std::vector jptrs; + jptrs.reserve(len); + for (size_t i = 0; i < len; ++i) { + jptrs[i] = list[i]; + } + env->SetLongArrayRegion(jarray, 0, len, jptrs.data()); + return jarray; + API_END_RETURN() +} + +JNIEXPORT jdoubleArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToDoubleList( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + torch::List list = ivalue_ptr->toDoubleList(); + size_t len = list.size(); + jdoubleArray jarray = env->NewDoubleArray(len); + std::vector jptrs; + jptrs.reserve(len); + for (size_t i = 0; i < len; ++i) { + jptrs[i] = list[i]; + } + env->SetDoubleArrayRegion(jarray, 0, len, jptrs.data()); + return jarray; API_END_RETURN() } @@ -83,7 +241,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToTens API_END_RETURN() } -JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList( +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToIValueList( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() auto* ivalue_ptr = reinterpret_cast(jhandle); @@ -92,6 +250,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToList API_END_RETURN() } +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToIValueTuple( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + auto* ivalue_ptr = reinterpret_cast(jhandle); + std::vector ivalue_vec = ivalue_ptr->toTuple()->elements(); + return djl::utils::jni::GetPtrArrayFromContainer, torch::IValue>(env, ivalue_vec); + API_END_RETURN() +} + JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() @@ -113,11 +280,31 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToMap( API_END_RETURN() } -JNIEXPORT jstring JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueToString( +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() - auto* ivalue_ptr = reinterpret_cast(jhandle); - return env->NewStringUTF(ivalue_ptr->toString()->string().c_str()); + return reinterpret_cast(jhandle)->isTensor(); + API_END_RETURN() +} + +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsBool( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + return reinterpret_cast(jhandle)->isBool(); + API_END_RETURN() +} + +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsLong( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + return reinterpret_cast(jhandle)->isInt(); + API_END_RETURN() +} + +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsDouble( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + return reinterpret_cast(jhandle)->isDouble(); API_END_RETURN() } @@ -128,10 +315,24 @@ JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsString API_END_RETURN() } -JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsTensor( +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsBoolList( JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() - return reinterpret_cast(jhandle)->isTensor(); + return reinterpret_cast(jhandle)->isBoolList(); + API_END_RETURN() +} + +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsLongList( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + return reinterpret_cast(jhandle)->isIntList(); + API_END_RETURN() +} + +JNIEXPORT jboolean JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_iValueIsDoubleList( + JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + return reinterpret_cast(jhandle)->isDoubleList(); API_END_RETURN() }