Commit
Update with a not-yet-merged version of MFA which supports upcast internally.
liuliu committed Dec 12, 2023
1 parent fb458d9 commit b88beca
Showing 11 changed files with 23 additions and 11 deletions.
7 changes: 5 additions & 2 deletions lib/nnc/ccv_cnnp_model_addons.c
@@ -3090,6 +3090,7 @@ typedef struct {
float scale;
int is_causal;
int has_attn_mask;
int upcast;
int fused_unify_head_weights;
int no_bias;
} ccv_cnnp_model_scaled_dot_product_attention_t;
@@ -3116,6 +3117,7 @@ static void _ccv_cnnp_scaled_dot_product_attention_build(ccv_cnnp_model_t* const
cmd.cmd = CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD;
cmd.info.scaled_dot_product_attention.scale = self->scale;
cmd.info.scaled_dot_product_attention.is_causal = self->is_causal;
cmd.info.scaled_dot_product_attention.upcast = self->upcast;
ccv_nnc_tensor_param_t output_params[3];
ccv_nnc_tensor_symbol_t output;
ccv_nnc_tensor_symbol_t saved_softmax;
@@ -3202,7 +3204,7 @@ static const ccv_cnnp_model_vtab_t ccv_cnnp_scaled_dot_product_attention_fused_i
.copy = _ccv_cnnp_scaled_dot_product_attention_copy,
};

ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name)
ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int upcast, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name)
{
ccv_cnnp_model_scaled_dot_product_attention_t* const model_scaled_dot_product_attention = (ccv_cnnp_model_scaled_dot_product_attention_t*)cccalloc(1, sizeof(ccv_cnnp_model_scaled_dot_product_attention_t));
model_scaled_dot_product_attention->super.isa = fused_unify_head_weights ? &ccv_cnnp_scaled_dot_product_attention_fused_isa : &ccv_cnnp_scaled_dot_product_attention_isa;
@@ -3218,6 +3220,7 @@ ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const
model_scaled_dot_product_attention->scale = scale;
model_scaled_dot_product_attention->is_causal = is_causal;
model_scaled_dot_product_attention->has_attn_mask = has_attn_mask;
model_scaled_dot_product_attention->upcast = upcast;
model_scaled_dot_product_attention->fused_unify_head_weights = fused_unify_head_weights;
model_scaled_dot_product_attention->no_bias = no_bias;
return (ccv_cnnp_model_t*)model_scaled_dot_product_attention;
@@ -3226,5 +3229,5 @@ ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const
static ccv_cnnp_model_t* _ccv_cnnp_scaled_dot_product_attention_copy(const ccv_cnnp_model_t* const super, void* const context)
{
const ccv_cnnp_model_scaled_dot_product_attention_t* const self = (const ccv_cnnp_model_scaled_dot_product_attention_t*)super;
return ccv_cnnp_scaled_dot_product_attention(self->scale, self->is_causal, self->has_attn_mask, self->fused_unify_head_weights, self->no_bias, self->super.is_trainable, self->super.name);
return ccv_cnnp_scaled_dot_product_attention(self->scale, self->is_causal, self->has_attn_mask, self->upcast, self->fused_unify_head_weights, self->no_bias, self->super.is_trainable, self->super.name);
}
4 changes: 3 additions & 1 deletion lib/nnc/ccv_nnc.h
@@ -235,6 +235,7 @@ typedef struct {
struct {
float scale; /**< [scaled_dot_product_attention.scale] The scale we multiply into the dot product of Q & K */
int is_causal; /**< [scaled_dot_product_attention.is_causal] Whether we have a causal matrix associated with the attention. The attention mask will be cut to triangular if provided. */
int upcast; /**< [scaled_dot_product_attention.upcast] Whether we want to run the attention computation at higher precision (from FP16 to FP32). */
} scaled_dot_product_attention;
void* userdata;
};
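As a quick illustration of what the new field controls, here is a minimal, hedged sketch of setting it at the raw command level (the ccv_cnnp model constructor declared further down wraps this). Only parameter fields visible in this diff are set; the numeric values, the include path, and the zeroed remaining fields are illustrative assumptions, not part of this commit.

#include <nnc/ccv_nnc.h> /* include path is an assumption; adjust to your install layout */

/* Minimal sketch: build a forward attention command that requests FP32
 * accumulation internally (upcast = 1) for otherwise-FP16 tensors. */
static ccv_nnc_cmd_t example_attention_cmd(void)
{
	ccv_nnc_cmd_param_t params = {
		.scaled_dot_product_attention = {
			.scale = 0.125, /* e.g. 1 / sqrt(64) for 64-dim heads */
			.is_causal = 1,
			.upcast = 1, /* the new flag added in this commit */
		},
	};
	return ccv_nnc_cmd(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, 0, params, 0);
}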
@@ -4451,13 +4452,14 @@ CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_scalar(const int type, const int for
* @param scale The scale to be applied to the qk dot product.
* @param is_causal Whether to apply a causal mask to it. If both attn_mask and is_causal are supplied, we will cut attn_mask to the upper-right triangle.
* @param has_attn_mask Whether the input accepts a 4th parameter, the attention mask.
* @param upcast Whether the attention computation will be run at higher precision (from FP16 to FP32).
* @param fused_unify_head_weights Whether we also have the unifying head weight fused into it. The output will be in the shape of (N, S, H * Ev).
* @param no_bias Whether or not we have a bias for the unifying head output.
* @param is_trainable Whether or not it is trainable (if weight / bias provided).
* @param name The unique name of the model.
* @return A model that can apply scaled dot product attention compute.
*/
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name);
CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int upcast, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name);

/** @} */
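A hedged usage sketch of the updated model constructor follows; the argument values and the variable name are illustrative only, with the signature taken from the declaration above.

/* Sketch: a causal attention layer (no mask input, no fused head projection)
 * that upcasts its internal computation from FP16 to FP32. */
ccv_cnnp_model_t* const attention = ccv_cnnp_scaled_dot_product_attention(
	0.125, /* scale, e.g. 1 / sqrt(64) for 64-dim heads */
	1,     /* is_causal */
	0,     /* has_attn_mask */
	1,     /* upcast: run the attention math at FP32 internally */
	0,     /* fused_unify_head_weights */
	1,     /* no_bias */
	0,     /* is_trainable */
	"attention");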

@@ -167,6 +167,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.alpha = cmd.info.scaled_dot_product_attention.scale,
.batched = (attention_is_batched ? 1 : 0),
.masked = (attn_mask != NULL ? 1 : 0),
.upcast = cmd.info.scaled_dot_product_attention.upcast,

.batch_dims_q = { 0 },
.batch_dims_mask = { 0 },
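To make the new flag's path into MFA concrete, here is a hedged, partial sketch of the parameter block filled in the backend hunk above. Only members that appear somewhere in this diff are listed; the real struct has further fields (data type, sequence/head dimensions, transpose flags, batch counts) that are elided here and must be set for an actual dispatch, and the values shown are illustrative.

/* Partial sketch only — not sufficient for a real kernel launch. */
ccv_nnc_mfa_attention_params_t attention_params = {
	.alpha = 0.125,  /* scale applied to the Q * K^T product */
	.batched = 1,
	.masked = 0,
	.upcast = 1,     /* hashed into the pipeline key and bound as Metal function constant 114 */
	.batch_dims_q = { 0 },
	.batch_dims_mask = { 0 },
};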
Binary file removed lib/nnc/mfa/3rdparty/libmfaios16-v0.2.metallib
Binary file removed lib/nnc/mfa/3rdparty/libmfamacos13-v0.2.metallib
4 changes: 2 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
@@ -168,9 +168,9 @@ mfa::context::context(MTL::Device* device)
this->library = NS::TransferPtr(device->newLibrary(url, &error));
} else {
#if TARGET_OS_IPHONE
dispatch_data_t data = dispatch_data_create(libmfaios16_v1_0_1_metallib, sizeof(libmfaios16_v1_0_1_metallib), NULL, 0);
dispatch_data_t data = dispatch_data_create(libmfaios16_v1_0_2_a_metallib, sizeof(libmfaios16_v1_0_2_a_metallib), NULL, 0);
#else
dispatch_data_t data = dispatch_data_create(libmfamacos13_v1_0_1_metallib, sizeof(libmfamacos13_v1_0_1_metallib), NULL, 0);
dispatch_data_t data = dispatch_data_create(libmfamacos13_v1_0_2_a_metallib, sizeof(libmfamacos13_v1_0_2_a_metallib), NULL, 0);
#endif
this->library = NS::TransferPtr(device->newLibrary(data, &error));
dispatch_release(data);
8 changes: 6 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -153,6 +153,7 @@ mfa::attention::hash::hash(ccv_nnc_mfa_attention_params_t params) {
alpha = params.alpha;
batched = params.batched;
masked = params.masked;
upcast = params.upcast;
}

bool mfa::attention::hash::operator==(const mfa::attention::hash& hash) const {
@@ -168,7 +169,8 @@ bool mfa::attention::hash::operator==(const mfa::attention::hash& hash) const {
(O_trans == hash.O_trans) &&
(alpha == hash.alpha) &&
(batched == hash.batched) &&
(masked == hash.masked);
(masked == hash.masked) &&
(upcast == hash.upcast);
}

std::ostream& operator<<(std::ostream& os, const mfa::attention::hash& hash) {
@@ -184,7 +186,8 @@ std::ostream& operator<<(std::ostream& os, const mfa::attention::hash& hash) {
os << " .O_trans = " << bool(hash.O_trans) << ',';
os << " .alpha = " << double(hash.alpha) << ',';
os << " .batched = " << bool(hash.batched) << ',';
os << " .masked = " << bool(hash.masked) << " ";
os << " .masked = " << bool(hash.masked) << ", ";
os << " .upcast = " << bool(hash.upcast) << " ";
os << "}";
return os;
}
@@ -218,6 +221,7 @@ mfa::attention::pipeline::pipeline(mfa::context* context, mfa::attention::hash h
constants->setConstantValue(&hash.data_type, MTL::DataTypeUInt, 30);
constants->setConstantValue(&hash.batched, MTL::DataTypeBool, 100);
constants->setConstantValue(&hash.masked, MTL::DataTypeBool, 50000);
constants->setConstantValue(&hash.upcast, MTL::DataTypeBool, 114);

{
bool block_sparse = hash.masked;
2 changes: 2 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.hpp
@@ -14,6 +14,7 @@ typedef struct {
float alpha;
uint8_t batched;
uint8_t masked;
uint8_t upcast;

// Since grouped queries are not supported yet, assume Q, K, V, and O all have
// the same batch dimensions.
@@ -45,6 +46,7 @@ class hash {
float alpha;
uint8_t batched;
uint8_t masked;
uint8_t upcast;

hash(ccv_nnc_mfa_attention_params_t);

8 changes: 4 additions & 4 deletions lib/nnc/mfa/libmfa.inc

Large diffs are not rendered by default.
