Skip to content

Commit

Permalink
Relax constraints on weight loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 13, 2023
1 parent 2330081 commit 1ca3108
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lib/nnc/ccv_nnc_tensor_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
int dim[CCV_NNC_MAX_DIM_ALLOC];
memcpy(dim, sqlite3_column_blob(tensor_select_stmt, 4), ccv_min(sizeof(dim), sqlite3_column_bytes(tensor_select_stmt, 4)));
const int nd = ccv_nnc_tensor_nd(dim);
if (datatype != tensor_params.datatype)
if (datatype != tensor_params.datatype && CCV_GET_DATA_TYPE(tensor_params.datatype) != CCV_QX)
{
// Only ever works for 16F to 32F or 32F to 16F transparently.
assert((datatype == CCV_16F && tensor_params.datatype == CCV_32F) || (datatype == CCV_32F && tensor_params.datatype == CCV_16F));
Expand Down Expand Up @@ -429,6 +429,9 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
}
#endif
} else {
// If it is QX, we need to have a custom decoder to decode properly.
if (datatype != tensor_params.datatype)
{ assert(options && options->decode); }
size_t data_size = ccv_nnc_tensor_data_size(tensor_params);
#ifdef HAVE_CUDA
if (!options || !options->decode)
Expand Down

0 comments on commit 1ca3108

Please sign in to comment.