diff --git a/lib/nnc/ccv_nnc_tensor_io.c b/lib/nnc/ccv_nnc_tensor_io.c
index 112c34c30..9f31dad5b 100644
--- a/lib/nnc/ccv_nnc_tensor_io.c
+++ b/lib/nnc/ccv_nnc_tensor_io.c
@@ -188,7 +188,7 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
 	int dim[CCV_NNC_MAX_DIM_ALLOC];
 	memcpy(dim, sqlite3_column_blob(tensor_select_stmt, 4), ccv_min(sizeof(dim), sqlite3_column_bytes(tensor_select_stmt, 4)));
 	const int nd = ccv_nnc_tensor_nd(dim);
-	if (datatype != tensor_params.datatype)
+	if (datatype != tensor_params.datatype && CCV_GET_DATA_TYPE(tensor_params.datatype) != CCV_QX)
 	{
 		// Only ever works for 16F to 32F or 32F to 16F transparently.
 		assert((datatype == CCV_16F && tensor_params.datatype == CCV_32F) || (datatype == CCV_32F && tensor_params.datatype == CCV_16F));
@@ -429,6 +429,9 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
 		}
 #endif
 	} else {
+		// If it is QX, we need to have a custom decoder to decode properly.
+		if (datatype != tensor_params.datatype)
+			{ assert(options && options->decode); }
 		size_t data_size = ccv_nnc_tensor_data_size(tensor_params);
 #ifdef HAVE_CUDA
 		if (!options || !options->decode)