diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8a1fc8871ac..a799c70f600 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -59,7 +59,7 @@ public abstract class Engine { private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); - private static final String DEFAULT_ENGINE = initEngine(); + private static String defaultEngine = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); @@ -69,6 +69,10 @@ public abstract class Engine { private Integer seed; private static synchronized String initEngine() { + if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) { + return null; + } + ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); @@ -80,21 +84,21 @@ private static synchronized String initEngine() { } String def = System.getProperty("ai.djl.default_engine"); - String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); - if (defaultEngine == null || defaultEngine.isEmpty()) { + String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); + if (newDefaultEngine == null || newDefaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { - defaultEngine = provider.getEngineName(); + newDefaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } - } else if (!ALL_ENGINES.containsKey(defaultEngine)) { - throw new EngineException("Unknown default engine: " + defaultEngine); + } else if (!ALL_ENGINES.containsKey(newDefaultEngine)) { + throw new EngineException("Unknown default engine: " + newDefaultEngine); } - logger.debug("Found default engine: {}", defaultEngine); - Ec2Utils.callHome(defaultEngine); - return defaultEngine; + logger.debug("Found default engine: {}", newDefaultEngine); + Ec2Utils.callHome(newDefaultEngine); + return newDefaultEngine; } /** @@ -124,7 +128,7 @@ private static synchronized String initEngine() { * @return the default Engine name */ public static String getDefaultEngineName() { - return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); + return System.getProperty("ai.djl.default_engine", defaultEngine); } /** @@ -134,7 +138,7 @@ public static String getDefaultEngineName() { * @see EngineProvider */ public static Engine getInstance() { - if (DEFAULT_ENGINE == null) { + if (defaultEngine == null) { throw new EngineException( "No deep learning engine found." + System.lineSeparator() @@ -163,7 +167,29 @@ public static boolean hasEngine(String engineName) { */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); - ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); + ALL_ENGINES.put(provider.getEngineName(), provider); + } + + /** + * Returns the default engine. + * + * @return the default engine + */ + public static String getDefaultEngine() { + return defaultEngine; + } + + /** + * Sets the default engine returned by {@link #getInstance()}. + * + * @param engineName the new default engine's name + */ + public static void setDefaultEngine(String engineName) { + // Requires an engine to be loaded (without exception) before being the default + getEngine(engineName); + + logger.debug("Setting new default engine: {}", engineName); + defaultEngine = engineName; } /** @@ -187,7 +213,12 @@ public static Engine getEngine(String engineName) { if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } - return provider.getEngine(); + Engine engine = provider.getEngine(); + if (engine == null) { + throw new IllegalStateException( + "The engine " + engineName + " was not able to initialize"); + } + return engine; } /** diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index ff03d32648e..1a04592dc12 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -105,6 +105,11 @@ For more information, please refer to [DJL Cache Management](cache_management.md It happened when you had a wrong version with DJL and Deep Engines. You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue. +### 1.6 Manual initialization + +If you are using manual engine initialization, you must both register an engine and set it as the default. +This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`. + ## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception. The following exception may appear after running the `./gradlew clean` command: diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index f8c84c753ef..583cd8132b2 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,6 +18,9 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = LgbmEngine.newInstance(); + if (!initialized) { + synchronized (LgbmEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = LgbmEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 5859f3f344d..8b534d5196c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,6 +18,9 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = XgbEngine.newInstance(); + if (!initialized) { + synchronized (XgbEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = XgbEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 5f45116f615..2a5ab970560 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,6 +18,9 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = MxEngine.newInstance(); + if (!initialized) { + synchronized (MxEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = MxEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 005c0fa25f1..5616eb80edb 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,6 +18,9 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = OrtEngine.newInstance(); + if (!initialized) { + synchronized (OrtEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = OrtEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index 59e5cd90724..e2fb86974f5 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,6 +18,9 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PpEngine.newInstance(); + if (!initialized) { + synchronized (PpEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = PpEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 42ca3c5b8a5..24be3e91d7a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PtEngineProvider.class) { - if (engine == null) { + if (!initialized) { + initialized = true; engine = PtEngine.newInstance(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index ad440a47951..fa7813a49fb 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfEngineProvider.class) { - if (engine == null) { + if (!initialized) { + initialized = true; engine = TfEngine.newInstance(); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index d92ed9e449d..8c90859c6c6 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,6 +18,9 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TrtEngine.newInstance(); + if (!initialized) { + synchronized (TrtEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = TrtEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index efd9d89e509..96066b380e1 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,7 +26,7 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Assert.assertEquals(version, "8.4.1"); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 09001f0e2da..24d734af54c 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 99cbc6f763e..105e057ba0a 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index fb61551a3bf..b46cad53b99 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,6 +18,9 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfLiteEngine.newInstance(); + if (!initialized) { + synchronized (TfLiteEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = TfLiteEngine.newInstance(); + } + } + } + return engine; } }