Skip to content

Commit

Permalink
Dilation support in convolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 20, 2023
1 parent abd9796 commit 76fb9ea
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 29 deletions.
1 change: 1 addition & 0 deletions lib/nnc/ccv_nnc.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ typedef struct {
struct {
int count; /**< [convolution.count] The number of filters for convolutional layer. */
int groups; /**< [convolution.groups] The number of groups for convolutional layer. */
int dilation[CCV_NNC_MAX_DIM_ALLOC]; /**< [convolution.dilation[]] The dilation factor for convolutional layer. Default to 1. */
} convolution;
struct {
int hidden_size; /**< [rnn.hidden_size] The number of features in the hidden state h. */
Expand Down
42 changes: 42 additions & 0 deletions lib/nnc/cmd/blas/ccv_nnc_blas.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,45 @@ REGISTER_COMMAND(CCV_NNC_SCALAR_MUL_BACKWARD)(ccv_nnc_cmd_registry_t* const regi
#define CMD_SCALAR_MUL_FORWARD(_a) ccv_nnc_cmd(CCV_NNC_SCALAR_MUL_FORWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.blas={.a={_a,}}}, 0)
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALAR_MUL_BACKWARD)
#define CMD_SCALAR_MUL_BACKWARD(_a) ccv_nnc_cmd(CCV_NNC_SCALAR_MUL_BACKWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.blas={.a={_a,}}}, 0)

static int _ccv_nnc_cmul_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
{
	// Forward complex multiplication is valid only when both operands
	// (input bits 0 and 1) are present and exactly one output (bit 0) is
	// produced.
	const uint64_t required_inputs = (1u << 0) | (1u << 1);
	const int inputs_ok = (input_bitmasks[0] & required_inputs) == required_inputs;
	const int outputs_ok = output_bitmasks[0] == (1u << 0);
	return (inputs_ok && outputs_ok) ? 1 : 0;
}

static int _ccv_nnc_cmul_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
{
	const uint64_t in = input_bitmasks[0];
	const uint64_t out = output_bitmasks[0];
	// Gradient w.r.t. both x and y: requires inputs 0, 1 and 2, emits
	// outputs 0 and 1.
	if ((in & 7u) == 7u && out == 3u)
		return 1;
	// Gradient w.r.t. x only: requires inputs 0 and 2, emits output 0.
	if ((in & 5u) == 5u && out == 1u)
		return 1;
	// Gradient w.r.t. y only: requires inputs 0 and 1, emits output 1.
	if ((in & 3u) == 3u && out == 2u)
		return 1;
	return 0;
}

// Register the forward complex multiplication (CMUL) command.
// FIND_BACKEND names the implementation file the build associates with it.
REGISTER_COMMAND(CCV_NNC_CMUL_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
	FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c)
{
	// Validates which input/output combinations this command accepts.
	registry->bitmask = _ccv_nnc_cmul_forw_bitmask;
	// Output shape inference delegated to the shared broadcast helper.
	registry->tensor_auto = _ccv_nnc_broadcast_tensor_auto_forw;
	// Permits the output to reuse the buffer of the same-positioned input.
	registry->allow_inplace = _ccv_nnc_same_pos_inplace;
}

// Register the backward complex multiplication (CMUL) command.
// FIND_BACKEND names the implementation file the build associates with it.
REGISTER_COMMAND(CCV_NNC_CMUL_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
	FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c)
{
	// A NULL incoming gradient is treated as all-ones by the framework.
	registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES;
	// Validates which input/output combinations this command accepts.
	registry->bitmask = _ccv_nnc_cmul_back_bitmask;
	// Gradient tensors take their shapes from the corresponding forward inputs.
	registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_inputs;
}

// Convenience constructors for CMUL command instances with default parameters.
// NOTE(review): the //@REGISTER_EASY_COMMAND_MACRO markers appear to be read
// by a code generator — keep each marker immediately above its #define.
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CMUL_FORWARD)
#define CMD_CMUL_FORWARD() ccv_nnc_cmd(CCV_NNC_CMUL_FORWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}}}, 0)
//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CMUL_BACKWARD)
#define CMD_CMUL_BACKWARD() ccv_nnc_cmd(CCV_NNC_CMUL_BACKWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}}}, 0)
2 changes: 2 additions & 0 deletions lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
assert(bdim[CCV_NNC_MAX_DIM] == cmd.info.convolution.count);
if (cmd.info.convolution.groups != 1)
return CCV_NNC_EXEC_INVALID;
if (cmd.info.convolution.dilation[0] > 1 || cmd.info.convolution.dilation[1] > 1)
return CCV_NNC_EXEC_INVALID;
int i;
// Make sure the weights dimension matches the network dimension
for (i = 1; i < CCV_NNC_MAX_DIM_ALLOC; i++)
Expand Down
104 changes: 84 additions & 20 deletions lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_ref.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,19 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_tensor_view_get_stride(b, bstride);
assert(!bias || bias->info.dim[0] == cmd.info.convolution.count);
const int batch_size = (a_nd == CCV_NNC_MAX_DIM + 2) ? a->info.dim[0] : 1;
const int dilation[CCV_NNC_MAX_DIM] = {
ccv_max(cmd.info.convolution.dilation[0], 1),
ccv_max(cmd.info.convolution.dilation[1], 1)
};
if (a->info.format == CCV_TENSOR_FORMAT_NHWC)
{
// Make sure the weights dimension matches the network dimension
assert(w->info.dim[1] == cmd.info.size.dim[0]);
assert(w->info.dim[2] == cmd.info.size.dim[1]);
const int wdim[CCV_NNC_MAX_DIM] = {
(w->info.dim[1] - 1) * dilation[0] + 1,
(w->info.dim[2] - 1) * dilation[1] + 1
};
assert(w->info.dim[CCV_NNC_MAX_DIM + 1] * groups == adim[CCV_NNC_MAX_DIM]);
assert(b->info.format == CCV_TENSOR_FORMAT_NHWC);
const int channel_size = w->info.dim[CCV_NNC_MAX_DIM + 1];
Expand All @@ -59,25 +67,36 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
int i[CCV_NNC_MAX_DIM];
int n[CCV_NNC_MAX_DIM];
int d[CCV_NNC_MAX_DIM];
int m[CCV_NNC_MAX_DIM];
int j[CCV_NNC_MAX_DIM];
for (i[0] = 0; i[0] < bdim[0]; i[0]++)
{
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, adim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim, n, m);
m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
d[0] = n0 * dilation[0] - n[0];
n[0] = n0;
m[0] = m[0] - n[0];
float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
for (i[1] = 0; i[1] < bdim[1]; i[1]++)
{
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, adim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim, n, m);
m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
d[1] = n1 * dilation[1] - n[1];
n[1] = n1;
m[1] = m[1] - n[1];
float p = biasval;
float* wpz = wpu + n[1] * channel_size;
float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
for (j[0] = 0; j[0] < m[0]; j[0]++)
{
for (j[1] = 0; j[1] < m[1]; j[1]++)
for (c = 0; c < channel_size; c++)
p += wpz[j[1] * channel_size + c] * apz[j[1] * astride[CCV_NNC_MAX_DIM] + c];
p += wpz[j[1] * channel_size + c] * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM] + c];
wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
apz += astride[CCV_NNC_MAX_DIM - 1];
apz += astride[CCV_NNC_MAX_DIM - 1] * dilation[0];
}
bp[i[1] * bstride[CCV_NNC_MAX_DIM]] = p;
}
Expand All @@ -89,6 +108,10 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// Make sure the weights dimension matches the network dimension
assert(w->info.dim[2] == cmd.info.size.dim[0]);
assert(w->info.dim[3] == cmd.info.size.dim[1]);
const int wdim[CCV_NNC_MAX_DIM] = {
(w->info.dim[2] - 1) * dilation[0] + 1,
(w->info.dim[3] - 1) * dilation[1] + 1
};
assert(w->info.dim[1] * groups == adim[0]);
assert(b->info.format == CCV_TENSOR_FORMAT_NCHW);
const int channel_size = w->info.dim[1];
Expand All @@ -107,25 +130,36 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
int i[CCV_NNC_MAX_DIM];
int n[CCV_NNC_MAX_DIM];
int d[CCV_NNC_MAX_DIM];
int m[CCV_NNC_MAX_DIM];
int j[CCV_NNC_MAX_DIM];
for (i[0] = 0; i[0] < bdim[1]; i[0]++)
{
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 2, adim + 1, n, m);
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim + 1, n, m);
m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
d[0] = n0 * dilation[0] - n[0];
n[0] = n0;
m[0] = m[0] - n[0];
float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM + 1];
for (i[1] = 0; i[1] < bdim[2]; i[1]++)
{
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 2, adim + 1, n, m);
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim + 1, n, m);
m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
d[1] = n1 * dilation[1] - n[1];
n[1] = n1;
m[1] = m[1] - n[1];
float p = biasval;
float* wpz = wpu + n[1];
float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM + 1] + gidx * channel_size * astride[1];
float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM + 1] + gidx * channel_size * astride[1];
for (j[0] = 0; j[0] < m[0]; j[0]++)
{
for (j[1] = 0; j[1] < m[1]; j[1]++)
for (c = 0; c < channel_size; c++)
p += wpz[j[1] + c * hw] * apz[j[1] * astride[CCV_NNC_MAX_DIM + 1] + c * astride[1]];
p += wpz[j[1] + c * hw] * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM + 1] + c * astride[1]];
wpz += w->info.dim[CCV_NNC_MAX_DIM + 1];
apz += astride[CCV_NNC_MAX_DIM];
apz += astride[CCV_NNC_MAX_DIM] * dilation[0];
}
bp[i[1]] = p;
}
Expand Down Expand Up @@ -173,6 +207,14 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const int group_size = cmd.info.convolution.count / groups;
const int channel_size = w ? w->info.dim[CCV_NNC_MAX_DIM + 1] : inputs[2]->info.dim[CCV_NNC_MAX_DIM + 1];
const int batch_size = (a_nd == CCV_NNC_MAX_DIM + 2) ? a->info.dim[0] : 1;
const int dilation[CCV_NNC_MAX_DIM] = {
ccv_max(cmd.info.convolution.dilation[0], 1),
ccv_max(cmd.info.convolution.dilation[1], 1)
};
const int wdim[CCV_NNC_MAX_DIM] = {
(w->info.dim[1] - 1) * dilation[0] + 1,
(w->info.dim[2] - 1) * dilation[1] + 1
};
if (w)
{
parallel_for(k, cmd.info.convolution.count) {
Expand All @@ -183,6 +225,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
float biasval = 0;
int i[CCV_NNC_MAX_DIM];
int n[CCV_NNC_MAX_DIM];
int d[CCV_NNC_MAX_DIM];
int m[CCV_NNC_MAX_DIM];
int j[CCV_NNC_MAX_DIM];
int bidx;
Expand All @@ -192,24 +235,34 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const float* gp = g->data.f32 + bidx * gstride[0] + k;
for (i[0] = 0; i[0] < gdim[0]; i[0]++)
{
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, adim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim, n, m);
m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
d[0] = n0 * dilation[0] - n[0];
n[0] = n0;
m[0] = m[0] - n[0];
float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
for (i[1] = 0; i[1] < gdim[1]; i[1]++)
{
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, adim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim, n, m);
m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
d[1] = n1 * dilation[1] - n[1];
n[1] = n1;
m[1] = m[1] - n[1];
const float v = gp[i[1] * gstride[CCV_NNC_MAX_DIM]];
if (v == 0) // shortcut if v is zero
continue;
biasval += v;
float* wpz = wpu + n[1] * channel_size;
const float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
const float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
for (j[0] = 0; j[0] < m[0]; j[0]++)
{
for (j[1] = 0; j[1] < m[1]; j[1]++)
for (c = 0; c < channel_size; c++)
wpz[j[1] * channel_size + c] += v * apz[j[1] * astride[CCV_NNC_MAX_DIM] + c];
wpz[j[1] * channel_size + c] += v * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM] + c];
wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
apz += astride[CCV_NNC_MAX_DIM - 1];
apz += astride[CCV_NNC_MAX_DIM - 1] * dilation[0];
}
}
gp += gstride[CCV_NNC_MAX_DIM - 1];
Expand Down Expand Up @@ -248,27 +301,38 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
int i[CCV_NNC_MAX_DIM];
int n[CCV_NNC_MAX_DIM];
int d[CCV_NNC_MAX_DIM];
int m[CCV_NNC_MAX_DIM];
int j[CCV_NNC_MAX_DIM];
for (i[0] = 0; i[0] < gdim[0]; i[0]++)
{
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, hdim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, hdim, n, m);
m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
d[0] = n0 * dilation[0] - n[0];
n[0] = n0;
m[0] = m[0] - n[0];
const float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
for (i[1] = 0; i[1] < gdim[1]; i[1]++)
{
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, hdim, n, m);
SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, hdim, n, m);
m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
d[1] = n1 * dilation[1] - n[1];
n[1] = n1;
m[1] = m[1] - n[1];
const float v = gp[i[1] * gstride[CCV_NNC_MAX_DIM]];
if (v == 0) // shortcut if v is zero
continue;
const float* wpz = wpu + n[1] * channel_size;
float* hpz = hp + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * hstride[CCV_NNC_MAX_DIM] + gidx * channel_size;
float* hpz = hp + d[0] * hstride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * hstride[CCV_NNC_MAX_DIM] + gidx * channel_size;
for (j[0] = 0; j[0] < m[0]; j[0]++)
{
for (j[1] = 0; j[1] < m[1]; j[1]++)
for (c = 0; c < channel_size; c++)
hpz[j[1] * hstride[CCV_NNC_MAX_DIM] + c] += v * wpz[j[1] * channel_size + c];
hpz[j[1] * dilation[1] * hstride[CCV_NNC_MAX_DIM] + c] += v * wpz[j[1] * channel_size + c];
wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
hpz += hstride[CCV_NNC_MAX_DIM - 1];
hpz += hstride[CCV_NNC_MAX_DIM - 1] * dilation[0];
}
}
gp += gstride[CCV_NNC_MAX_DIM - 1];
Expand Down
6 changes: 5 additions & 1 deletion lib/nnc/cmd/convolution/ccv_nnc_convolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ static void _ccv_nnc_conv_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const
assert(count == cmd.convolution.count);
ccv_nnc_tensor_set_c(outputs, ccv_nnc_tensor_nd(inputs[0].dim), count);
ccv_nnc_tensor_set_n(outputs, ccv_nnc_tensor_get_n(inputs[0]));
ccv_nnc_hint_tensor_forward(cmd, inputs[0], hint, outputs);
ccv_nnc_cmd_param_t modified_cmd = cmd;
int i = 0;
for (i = 0; i < CCV_NNC_MAX_DIM; i++)
modified_cmd.size.dim[i] = (modified_cmd.size.dim[i] - 1) * ccv_max(cmd.convolution.dilation[i], 1) + 1;
ccv_nnc_hint_tensor_forward(modified_cmd, inputs[0], hint, outputs);
}

REGISTER_COMMAND(CCV_NNC_CONVOLUTION_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
Expand Down
8 changes: 4 additions & 4 deletions lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]);
const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, inputs[1]->info.datatype);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, inputs[1]->info.datatype);
cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);

cudnnConvolutionFwdAlgo_t algo;
Expand Down Expand Up @@ -124,7 +124,7 @@ static int _ccv_nnc_conv_forw_autotune(const ccv_nnc_cmd_t cmd, size_t max_works
const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]);
const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, inputs[1]->info.datatype);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, inputs[1]->info.datatype);
cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
int count = 0;
cudnnConvolutionFwdAlgoPerf_t perfs[CCV_NNC_CMD_CUDNN_CONV_FWD_ALGO_COUNT];
Expand Down Expand Up @@ -210,7 +210,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
const ccv_nnc_cudnn_tensor_view_descriptor_t g = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
const int is_w_nhwc = (output_size > 1 && outputs[1]) ? outputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC : inputs[2]->info.format == CCV_TENSOR_FORMAT_NHWC;
const int w_datatype = (output_size > 1 && outputs[1]) ? outputs[1]->info.datatype : inputs[2]->info.datatype;
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);

static const float one = 1, zero = 0;
Expand Down Expand Up @@ -370,7 +370,7 @@ static int _ccv_nnc_conv_back_autotune(const ccv_nnc_cmd_t cmd, size_t max_works
int count = 0;
const int is_w_nhwc = (output_size > 1 && outputs[1]) ? outputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC : inputs[2]->info.format == CCV_TENSOR_FORMAT_NHWC;
const int w_datatype = (output_size > 1 && outputs[1]) ? outputs[1]->info.datatype : inputs[2]->info.datatype;
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
cudnnConvolutionBwdFilterAlgo_t filter_algorithm = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
if (output_size > 1 && outputs[1])
Expand Down
Loading

0 comments on commit 76fb9ea

Please sign in to comment.