Add amsgrad into ADAM.
liuliu committed Oct 2, 2023
1 parent 0ce86f0 commit 0546f0d
Showing 12 changed files with 1,762 additions and 752 deletions.
6 changes: 3 additions & 3 deletions bin/nnc/iwslt.c
@@ -296,7 +296,7 @@ static void eval_wmt(const int max_length, const int embedding_size, const char*
.dropout = 0.1,
};
ccv_cnnp_model_t* const wmt = ccv_cnnp_dynamic_new(_dynamic_encoder_decoder, &encoder_decoder_params, 0);
ccv_nnc_cmd_t adam = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9);
ccv_nnc_cmd_t adam = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9, 0);
ccv_nnc_tensor_param_t inputs[4];
inputs[0] = GPU_TENSOR_NCHW(000, 32F, 1, max_length, embedding_size);
inputs[1] = GPU_TENSOR_NCHW(000, 32F, 1, max_length, embedding_size);
@@ -488,7 +488,7 @@ static void train_wmt(const int epoch_limit, const int src_vocab_size, const int
ccv_cnnp_model_set_data_parallel(wmt, device_count);
const int epoch_end = (ccv_cnnp_dataframe_row_count(train_data) + device_count * batch_size - 1) / (device_count * batch_size);
ccv_cnnp_dataframe_shuffle(train_data);
ccv_nnc_cmd_t adam = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9);
ccv_nnc_cmd_t adam = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9, 0);
const int aux_size = ccv_nnc_minimizer_saved_aux_size(adam);
ccv_nnc_dynamic_graph_t* const dynamic_graph = ccv_nnc_dynamic_graph_new();
ccv_nnc_tensor_t* const seq_vec_ = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, max_length, embedding_size), 0);
@@ -772,7 +772,7 @@ static void train_wmt(const int epoch_limit, const int src_vocab_size, const int
if ((i + 1) % big_step == 0)
{
float learn_rate = 1. / sqrt_d_model * ccv_min(1. / sqrtf((i + 1) / big_step), (float)((i + 1) / big_step) / (sqrtf(warmup_steps) * warmup_steps));
adam = CMD_ADAM_FORWARD((i + 1) / big_step, learn_rate, 0.9, 0.98, 0, 1e-9);
adam = CMD_ADAM_FORWARD((i + 1) / big_step, learn_rate, 0.9, 0.98, 0, 1e-9, 0);
ccv_cnnp_model_set_minimizer(wmt, adam, 0, 0, 0);
for (j = 0; j < device_count; j++)
tvin[j * 2] = src_vocab_vec_grad[j], tvin[j * 2 + 1] = tgt_vocab_vec_grad[j], tvout[j * 2] = src_vocab_vec[j], tvout[j * 2 + 1] = tgt_vocab_vec[j];
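The iwslt.c changes above illustrate the new calling convention: CMD_ADAM_FORWARD now takes a seventh, trailing argument that turns AMSGrad on (nonzero) or off (0). A minimal sketch, reusing the hyper-parameters from the call sites above:

ccv_nnc_cmd_t adam = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9, 0);  /* plain Adam, as before */
ccv_nnc_cmd_t adam_amsgrad = CMD_ADAM_FORWARD(1, 0.0001, 0.9, 0.98, 0, 1e-9, 1);  /* AMSGrad variant */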
1 change: 1 addition & 0 deletions lib/nnc/ccv_nnc.h
@@ -158,6 +158,7 @@ typedef struct {
float beta2; /**< [adam.beta2] The beta2 hyper-parameter in adam optimizer. */
float decay; /**< [adam.decay] This is the weight decay parameter, which represents L2 regularization. */
float epsilon; /**< [adam.epsilon] The epsilon for standard derivation. */
int amsgrad; /**< [adam.amsgrad] Whether use amsgrad. */
} adam;
struct {
int step; /**< [lamb.step] Step t in lamb optimizer. */
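The new field selects the AMSGrad variant of the Adam update, which divides the step by a running maximum of the bias-corrected second moment instead of the second moment itself. A minimal single-element sketch of that rule, assuming the standard Adam bias corrections; names mirror the CPU reference kernel later in this diff, gradient scaling and weight decay are omitted, and fmaxf stands in for ccv_max:

#include <math.h>

static float adam_amsgrad_step(const float a, const float grad, float* const m, float* const v, float* const v_max,
  const int step, const float rate, const float beta1, const float beta2, const float epsilon)
{
  const float mom = *m = beta1 * *m + (1 - beta1) * grad;          /* first moment */
  const float vel = *v = beta2 * *v + (1 - beta2) * grad * grad;   /* second moment */
  const float vel_hat = vel / (1 - powf(beta2, step));             /* bias-corrected second moment */
  const float vel_max_hat = *v_max = fmaxf(*v_max, vel_hat);       /* AMSGrad: keep the running maximum */
  return a - rate / (1 - powf(beta1, step)) * mom / (sqrtf(vel_max_hat) + epsilon);
}

With amsgrad left at 0 the kernel keeps the original behavior and divides by the bias-corrected second moment directly.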
20 changes: 14 additions & 6 deletions lib/nnc/cmd/adam/ccv_nnc_adam.c
@@ -4,10 +4,18 @@

static int _ccv_nnc_adam_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
{
// 3 inputs (gradient, x, momentum, velocity)
// 2 outputs (y, new momentum, new velocity)
if (input_bitmasks[0] == 15u && output_bitmasks[0] == 7u)
return 1;
if (cmd.adam.amsgrad)
{
// 5 inputs (gradient, x, momentum, velocity, v_max)
// 4 outputs (y, new momentum, new velocity, new v_max)
if (input_bitmasks[0] == 31u && output_bitmasks[0] == 15u)
return 1;
} else {
// 4 inputs (gradient, x, momentum, velocity)
// 3 outputs (y, new momentum, new velocity)
if (input_bitmasks[0] == 15u && output_bitmasks[0] == 7u)
return 1;
}
return 0;
}
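For reference, the masks simply require every tensor to be bound: 31u sets the low five bits for the five inputs and 15u the low four bits for the four outputs in the amsgrad case, versus 15u (four inputs) and 7u (three outputs) otherwise.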

@@ -52,7 +60,7 @@ REGISTER_COMMAND(CCV_NNC_ADAM_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
}

//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_ADAM_FORWARD)
#define CMD_ADAM_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon) ccv_nnc_cmd(CCV_NNC_ADAM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon}}), 0)
#define CMD_ADAM_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon, _amsgrad) ccv_nnc_cmd(CCV_NNC_ADAM_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon,.amsgrad=_amsgrad}}), 0)

REGISTER_COMMAND(CCV_NNC_ADAMW_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
FIND_BACKEND(ccv_nnc_adamw_cpu_ref.c, gpu/ccv_nnc_adamw_gpu_ref.cu, mps/ccv_nnc_adamw_mps.m)
@@ -70,4 +78,4 @@ REGISTER_COMMAND(CCV_NNC_ADAMW_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
}

//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_ADAMW_FORWARD)
#define CMD_ADAMW_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon) ccv_nnc_cmd(CCV_NNC_ADAMW_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon}}), 0)
#define CMD_ADAMW_FORWARD(_step, _rate, _beta1, _beta2, _decay, _epsilon, _amsgrad) ccv_nnc_cmd(CCV_NNC_ADAMW_FORWARD, 0, ((ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.adam={.step=_step,.rate=_rate,.scale=1,.beta1=_beta1,.beta2=_beta2,.decay=_decay,.epsilon=_epsilon,.amsgrad=_amsgrad}}), 0)
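CMD_ADAMW_FORWARD gains the same trailing _amsgrad argument, so existing six-argument call sites of either macro need an explicit 0 to keep their previous behavior (as the iwslt.c updates above do). An illustrative call, with hyper-parameters chosen only for the example:

ccv_nnc_cmd_t adamw = CMD_ADAMW_FORWARD(1, 0.001, 0.9, 0.999, 0.01, 1e-8, 0);  /* AdamW without AMSGrad */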
128 changes: 96 additions & 32 deletions lib/nnc/cmd/adam/ccv_nnc_adam_cpu_ref.c
@@ -15,15 +15,17 @@

static int _ccv_nnc_adam_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
assert(input_size == 4);
assert(output_size == 3);
assert(input_size >= 4);
assert(output_size >= 3);
ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
ccv_nnc_tensor_view_t* const a = (ccv_nnc_tensor_view_t*)inputs[1];
ccv_nnc_tensor_view_t* const m = (ccv_nnc_tensor_view_t*)inputs[2];
ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[3];
ccv_nnc_tensor_view_t* const vm = input_size >= 5 ? (ccv_nnc_tensor_view_t*)inputs[4] : 0;
ccv_nnc_tensor_view_t* const b = (ccv_nnc_tensor_view_t*)outputs[0];
ccv_nnc_tensor_view_t* const n = (ccv_nnc_tensor_view_t*)outputs[1];
ccv_nnc_tensor_view_t* const u = (ccv_nnc_tensor_view_t*)outputs[2];
ccv_nnc_tensor_view_t* const um = output_size >= 4 ? (ccv_nnc_tensor_view_t*)outputs[3] : 0;
assert(ccv_nnc_tensor_nd(a->info.dim) <= CCV_NNC_MAX_DIM + 2);
assert(ccv_nnc_tensor_nd(b->info.dim) <= CCV_NNC_MAX_DIM + 2);
// Assuming this is float 32.
@@ -32,24 +34,34 @@ static int _ccv_nnc_adam_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
assert(ccv_nnc_tensor_view_check_dim(g, adim));
assert(ccv_nnc_tensor_view_check_dim(m, adim));
assert(ccv_nnc_tensor_view_check_dim(v, adim));
if (vm)
{ assert(ccv_nnc_tensor_view_check_dim(vm, adim)); }
assert(ccv_nnc_tensor_view_check_dim(b, adim));
assert(ccv_nnc_tensor_view_check_dim(n, adim));
assert(ccv_nnc_tensor_view_check_dim(u, adim));
if (um)
{ assert(ccv_nnc_tensor_view_check_dim(um, adim)); }
assert(CCV_NNC_MAX_DIM == 2); // Need to change this logic for CCV_NNC_MAX_DIM == other number.
int gstride[CCV_NNC_MAX_DIM_ALLOC];
int astride[CCV_NNC_MAX_DIM_ALLOC];
int mstride[CCV_NNC_MAX_DIM_ALLOC];
int vstride[CCV_NNC_MAX_DIM_ALLOC];
int vmstride[CCV_NNC_MAX_DIM_ALLOC];
int bstride[CCV_NNC_MAX_DIM_ALLOC];
int nstride[CCV_NNC_MAX_DIM_ALLOC];
int ustride[CCV_NNC_MAX_DIM_ALLOC];
int umstride[CCV_NNC_MAX_DIM_ALLOC];
ccv_nnc_tensor_view_get_stride(g, gstride);
ccv_nnc_tensor_view_get_stride(a, astride);
ccv_nnc_tensor_view_get_stride(m, mstride);
ccv_nnc_tensor_view_get_stride(v, vstride);
if (vm)
ccv_nnc_tensor_view_get_stride(vm, vmstride);
ccv_nnc_tensor_view_get_stride(b, bstride);
ccv_nnc_tensor_view_get_stride(n, nstride);
ccv_nnc_tensor_view_get_stride(u, ustride);
if (um)
ccv_nnc_tensor_view_get_stride(um, umstride);
const int step = cmd.info.adam.step;
const float rate = cmd.info.adam.rate;
const float scale = cmd.info.adam.scale;
@@ -69,41 +81,93 @@ static int _ccv_nnc_adam_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
float* const bp = b->data.f32;
float* const np = n->data.f32;
float* const up = u->data.f32;
if (cmd.info.adam.amsgrad && vm && um)
{
  float* const vmp = vm->data.f32;
  float* const ump = um->data.f32;
  for (i[0] = 0; i[0] < adim[0]; i[0]++)
  {
    float* const gp0 = gp + i[0] * gstride[0];
    float* const ap0 = ap + i[0] * astride[0];
    float* const mp0 = mp + i[0] * mstride[0];
    float* const vp0 = vp + i[0] * vstride[0];
    float* const vmp0 = vmp + i[0] * vmstride[0];
    float* const bp0 = bp + i[0] * bstride[0];
    float* const np0 = np + i[0] * nstride[0];
    float* const up0 = up + i[0] * ustride[0];
    float* const ump0 = ump + i[0] * umstride[0];
    for (i[1] = 0; i[1] < adim[1]; i[1]++)
    {
      float* gp1 = gp0 + i[1] * gstride[1];
      float* ap1 = ap0 + i[1] * astride[1];
      float* mp1 = mp0 + i[1] * mstride[1];
      float* vp1 = vp0 + i[1] * vstride[1];
      float* vmp1 = vmp0 + i[1] * vmstride[1];
      float* bp1 = bp0 + i[1] * bstride[1];
      float* np1 = np0 + i[1] * nstride[1];
      float* up1 = up0 + i[1] * ustride[1];
      float* ump1 = ump0 + i[1] * umstride[1];
      for (i[2] = 0; i[2] < adim[2]; i[2]++)
      {
        for (x = 0; x < adim[3]; x++)
        {
          float grad = scale * gp1[x];
          grad += decay * ap1[x];
          const float mom = np1[x] = beta1 * mp1[x] + (1 - beta1) * grad;
          const float vel = up1[x] = beta2 * vp1[x] + (1 - beta2) * grad * grad;
          const float vel_hat = vel * inv_bias_correction2;
          const float vel_max_hat = ump1[x] = ccv_max(vmp1[x], vel_hat);
          bp1[x] = ap1[x] - (mom * rate_inv_bias_correction1) / (sqrtf(vel_max_hat) + epsilon);
        }
        gp1 += gstride[2];
        ap1 += astride[2];
        mp1 += mstride[2];
        vp1 += vstride[2];
        vmp1 += vmstride[2];
        bp1 += bstride[2];
        np1 += nstride[2];
        up1 += ustride[2];
        ump1 += umstride[2];
      }
    }
  }
} else {
  for (i[0] = 0; i[0] < adim[0]; i[0]++)
  {
    float* const gp0 = gp + i[0] * gstride[0];
    float* const ap0 = ap + i[0] * astride[0];
    float* const mp0 = mp + i[0] * mstride[0];
    float* const vp0 = vp + i[0] * vstride[0];
    float* const bp0 = bp + i[0] * bstride[0];
    float* const np0 = np + i[0] * nstride[0];
    float* const up0 = up + i[0] * ustride[0];
    for (i[1] = 0; i[1] < adim[1]; i[1]++)
    {
      float* gp1 = gp0 + i[1] * gstride[1];
      float* ap1 = ap0 + i[1] * astride[1];
      float* mp1 = mp0 + i[1] * mstride[1];
      float* vp1 = vp0 + i[1] * vstride[1];
      float* bp1 = bp0 + i[1] * bstride[1];
      float* np1 = np0 + i[1] * nstride[1];
      float* up1 = up0 + i[1] * ustride[1];
      for (i[2] = 0; i[2] < adim[2]; i[2]++)
      {
        for (x = 0; x < adim[3]; x++)
        {
          float grad = scale * gp1[x];
          grad += decay * ap1[x];
          const float mom = np1[x] = beta1 * mp1[x] + (1 - beta1) * grad;
          const float vel = up1[x] = beta2 * vp1[x] + (1 - beta2) * grad * grad;
          bp1[x] = ap1[x] - (mom * rate_inv_bias_correction1) / (sqrtf(vel * inv_bias_correction2) + epsilon);
        }
        gp1 += gstride[2];
        ap1 += astride[2];
        mp1 += mstride[2];
        vp1 += vstride[2];
        bp1 += bstride[2];
        np1 += nstride[2];
        up1 += ustride[2];
      }
    }
  }
}
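With amsgrad enabled the forward command therefore consumes one extra input (the running maximum of the velocity) and produces one extra output (its updated value). A rough usage sketch in the style of the nnc test helpers; the shape and hyper-parameters are illustrative only:

ccv_nnc_tensor_t* const g = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* gradient */
ccv_nnc_tensor_t* const a = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* parameters */
ccv_nnc_tensor_t* const m = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* momentum */
ccv_nnc_tensor_t* const v = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* velocity */
ccv_nnc_tensor_t* const vm = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);  /* running max of velocity */
ccv_nnc_tensor_t* const b = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* updated parameters */
ccv_nnc_tensor_t* const n = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* new momentum */
ccv_nnc_tensor_t* const u = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);   /* new velocity */
ccv_nnc_tensor_t* const um = ccv_nnc_tensor_new(0, CPU_TENSOR_NCHW(32F, 128), 0);  /* new running max */
ccv_nnc_cmd_exec(CMD_ADAM_FORWARD(1, 0.001, 0.9, 0.999, 0, 1e-8, 1), ccv_nnc_no_hint, 0,
  TENSOR_LIST(g, a, m, v, vm), TENSOR_LIST(b, n, u, um), 0);

The momentum, velocity, and running-max tensors are assumed to be zero-initialized before the first step.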
