diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java index 69eb6914b0e..62f036469e9 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java @@ -88,14 +88,35 @@ public ByteBuffer toByteBuffer(boolean tryDirect) { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { - if (handle != null && handle.get() != 0L) { - long pointer = handle.getAndSet(0L); - JniUtils.deleteDMatrix(pointer); + if (replaced == null) { + throw new IllegalArgumentException("The replaced NDArray cannot be null."); + } + if (!(replaced instanceof XgbNDArray)) { + throw new IllegalArgumentException("The replaced NDArray must be an instance of XgbNDArray."); } XgbNDArray array = (XgbNDArray) replaced; - data = array.data; - handle = array.handle; - format = array.format; + + synchronized (this) { + if (handle != null && handle.get() != 0L) { + long pointer = handle.getAndSet(0L); + JniUtils.deleteDMatrix(pointer); + } + + data = array.data; + format = array.format; + + if (array.handle != null && array.handle.get() != 0L) { + if (handle == null) { + handle = new AtomicLong(); + } + handle.set(array.handle.getAndSet(0L)); + } + } + + array.data = null; + array.handle = null; + array.format = null; + array.close(); } /** {@inheritDoc} */