Skip to content

Commit

Permalink
Fix weight loading errors when implicit conversion happens.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Sep 15, 2023
1 parent 1be99b7 commit 9cfdad0
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions lib/nnc/ccv_nnc_tensor_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
copying = workspace + data_size;
else
ccv_half_precision_to_float((uint16_t*)(workspace + data_size), (float*)workspace, ccv_min(tensor_count, ccv_min(source_data_size, decoded_size) / sizeof(uint16_t)));
} else
} else {
if (!tensor)
*tensor_out = tensor = ccv_nnc_tensor_new(0, tensor_params, 0);
ccv_half_precision_to_float((uint16_t*)data, (float*)workspace, ccv_min(tensor_count, sqlite3_column_bytes(tensor_select_stmt, 0) / sizeof(uint16_t)));
}
} else if (datatype == CCV_32F && tensor_params.datatype == CCV_16F) {
size_t decoded_size = source_data_size;
if (options->decode(data, sqlite3_column_bytes(tensor_select_stmt, 0), datatype, dim, nd, identifier, options->context, tensor_params, tensor_out, workspace + data_size, &decoded_size))
Expand All @@ -217,8 +220,11 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
copying = workspace + data_size;
else
ccv_float_to_half_precision((float*)(workspace + data_size), (uint16_t*)workspace, ccv_min(tensor_count, ccv_min(source_data_size, decoded_size) / sizeof(float)));
} else
} else {
if (!tensor)
*tensor_out = tensor = ccv_nnc_tensor_new(0, tensor_params, 0);
ccv_float_to_half_precision((float*)data, (uint16_t*)workspace, ccv_min(tensor_count, sqlite3_column_bytes(tensor_select_stmt, 0) / sizeof(float)));
}
} else
{ assert(0); }
}
Expand Down Expand Up @@ -283,7 +289,7 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
else
{ assert(0); }
} else {
workspace = ccmalloc(data_size + source_data_size);
copying = workspace = ccmalloc(data_size + source_data_size);
if (datatype == CCV_16F && tensor_params.datatype == CCV_32F)
{
size_t decoded_size = source_data_size;
Expand All @@ -293,8 +299,11 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
copying = workspace + data_size;
else
ccv_half_precision_to_float((uint16_t*)(workspace + data_size), (float*)workspace, ccv_min(tensor_count, ccv_min(source_data_size, decoded_size) / sizeof(uint16_t)));
} else
} else {
if (!tensor)
*tensor_out = tensor = ccv_nnc_tensor_new(0, tensor_params, 0);
ccv_half_precision_to_float((uint16_t*)data, (float*)workspace, ccv_min(tensor_count, sqlite3_column_bytes(tensor_select_stmt, 0) / sizeof(uint16_t)));
}
} else if (datatype == CCV_32F && tensor_params.datatype == CCV_16F) {
size_t decoded_size = source_data_size;
if (options->decode(data, sqlite3_column_bytes(tensor_select_stmt, 0), datatype, dim, nd, identifier, options->context, tensor_params, tensor_out, workspace + data_size, &decoded_size))
Expand All @@ -303,8 +312,11 @@ int ccv_nnc_tensor_read(void* const handle, const char* const name, const char*
copying = workspace + data_size;
else
ccv_float_to_half_precision((float*)(workspace + data_size), (uint16_t*)workspace, ccv_min(tensor_count, ccv_min(source_data_size, decoded_size) / sizeof(float)));
} else
} else {
if (!tensor)
*tensor_out = tensor = ccv_nnc_tensor_new(0, tensor_params, 0);
ccv_float_to_half_precision((float*)data, (uint16_t*)workspace, ccv_min(tensor_count, sqlite3_column_bytes(tensor_select_stmt, 0) / sizeof(float)));
}
} else
{ assert(0); }
}
Expand Down

0 comments on commit 9cfdad0

Please sign in to comment.