Skip to content

Commit

Permalink
Fix issues when reinitializing the tensor arena: it doesn't handle quantized tens…
Browse files Browse the repository at this point in the history
…ors well.
  • Loading branch information
liuliu committed Dec 22, 2023
1 parent d6045cf commit 54f82af
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
29 changes: 29 additions & 0 deletions lib/nnc/ccv_nnc_easy.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,35 @@ static inline size_t ccv_nnc_tensor_data_size(const ccv_nnc_tensor_param_t param
return ((data_size + 63) & -64);
}

// Unpadded byte size of a tensor's payload after decompression. For
// palettized (CCV_QX) tensors this is the size the data occupies once
// expanded to the palette's element type, not the compressed footprint.
static inline size_t ccv_nnc_tensor_decompressed_data_size_without_padding(const ccv_nnc_tensor_param_t params)
{
	const ssize_t count = (ssize_t)ccv_nnc_tensor_count(params);
	int datatype = params.datatype;
	if (CCV_GET_DATA_TYPE(datatype) == CCV_QX)
		// QX currently only does palettization; the palette's element
		// datatype is packed in the low 8 bits — unpack it into the
		// regular datatype encoding before sizing.
		datatype = (datatype & 0xff) << 12;
	return CCV_GET_DATA_TYPE_SIZE(datatype) * count;
}

// Padded byte size needed to hold a tensor's decompressed payload,
// rounded up per backend alignment rules: 128 bytes for CUDA GPU
// buffers, PAGE_SIZE for MPS GPU buffers, 64 bytes for everything else.
// Mirrors ccv_nnc_tensor_data_size, but sizes CCV_QX tensors by their
// expanded (de-palettized) footprint.
static inline size_t ccv_nnc_tensor_decompressed_data_size(const ccv_nnc_tensor_param_t params)
{
ssize_t data_size = ccv_nnc_tensor_decompressed_data_size_without_padding(params);
#ifdef HAVE_CUDA // For CUDA, we align to 128-bytes.
if (CCV_TENSOR_GET_MEMORY(params.type) == CCV_TENSOR_GPU_MEMORY)
return ((data_size + 127) & -128);
else
#elif defined(HAVE_MPS) // For MPS, we have to align to PAGE_SIZE.
if (CCV_TENSOR_GET_MEMORY(params.type) == CCV_TENSOR_GPU_MEMORY)
return ((data_size + PAGE_SIZE - 1) & -PAGE_SIZE);
else
#endif
// NOTE: when HAVE_CUDA or HAVE_MPS is defined, this return is the `else`
// branch above (non-GPU memory); otherwise it is the unconditional path.
return ((data_size + 63) & -64);
}

static inline void ccv_nnc_tensor_view_get_dim(const ccv_nnc_tensor_view_t* const tv, int dim[CCV_NNC_MAX_DIM_ALLOC])
{
int x;
Expand Down
12 changes: 9 additions & 3 deletions lib/nnc/ccv_nnc_symbolic_graph_compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -4214,7 +4214,7 @@ int ccv_nnc_tensor_arena_reinit(ccv_nnc_tensor_arena_t* const tensor_arena, cons
mv = (ccv_nnc_tensor_multiview_t*)(mv->it ? mv->it : CCV_NNC_MULTIVIEW_DATA(mv)[0]);
tensor = (ccv_nnc_tensor_t*)mv;
}
tensor_arena->vt_sizes[i] = ccv_nnc_tensor_data_size(tensor->info);
tensor_arena->vt_sizes[i] = ccv_nnc_tensor_decompressed_data_size(tensor->info);
}
}
int flag = 0;
Expand All @@ -4235,11 +4235,17 @@ int ccv_nnc_tensor_arena_reinit(ccv_nnc_tensor_arena_t* const tensor_arena, cons
{
assert(!tensor_arena->vt_alias_refs[i]);
_ccv_nnc_multiview_update_params((ccv_nnc_tensor_multiview_t*)tensor, symbol_info->info);
} else if (!tensor_arena->vt_alias_refs[i])
} else if (!tensor_arena->vt_alias_refs[i]) {
ccv_nnc_tensor_param_t params = tensor->info;
tensor->info = symbol_info->info;
else {
tensor->info.datatype = params.datatype;
tensor->info.reserved = params.reserved;
} else {
off_t off = ccv_nnc_tensor_view_offset(tensor->info.datatype, symbol_info->stride, symbol_info->ofs);
ccv_nnc_tensor_param_t params = tensor->info;
tensor->info = symbol_info->info;
tensor->info.datatype = params.datatype;
tensor->info.reserved = params.reserved;
const int alias_ref = tensor_arena->vt_alias_refs[i] - 1;
ccv_nnc_tensor_data(tensor->info, tensor_arena->vt_tensors[alias_ref]->data.u8, off + tensor_arena->vt_tensors[alias_ref]->dataof, &tensor->data, &tensor->dataof);
if (CCV_IS_TENSOR_VIEW(tensor))
Expand Down

0 comments on commit 54f82af

Please sign in to comment.