Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693513345
  • Loading branch information
tensorflower-gardener committed Nov 6, 2024
1 parent 38f781c commit ca66e59
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions official/nlp/modeling/layers/rezero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(self,
num_kv_heads=None,
src_block_size=None,
tgt_block_size=None,
linformer_dim=None,
linformer_shared_kv_projection=True,
use_sigmoid_attn=False,
sigmoid_attn_bias=None,
**kwargs):
# attention_dropout will override attention_dropout_rate.
# This is to unify the input params with TransformerEncoderBlock.
Expand Down Expand Up @@ -115,6 +119,15 @@ def __init__(self,
self._num_kv_heads = num_kv_heads
self._src_block_size = src_block_size
self._tgt_block_size = tgt_block_size
self._linformer_dim = linformer_dim
self._linformer_shared_kv_projection = linformer_shared_kv_projection
self._use_sigmoid_attn = use_sigmoid_attn
self._sigmoid_attn_bias = sigmoid_attn_bias
if self._linformer_dim is not None or self._use_sigmoid_attn:
raise ValueError(
"Linformer and Sigmoid attention are not supported in ReZero"
" Transformer."
)
if self._num_kv_heads is not None and self._src_block_size is not None:
raise ValueError(
"Block sparse attention does not support Multi-query attention."
Expand Down Expand Up @@ -284,6 +297,12 @@ def get_config(self):
tf_keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf_keras.constraints.serialize(self._bias_constraint),
"linformer_dim": self._linformer_dim,
"linformer_shared_kv_projection": (
self._linformer_shared_kv_projection
),
"use_sigmoid_attn": self._use_sigmoid_attn,
"sigmoid_attn_bias": self._sigmoid_attn_bias,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down

0 comments on commit ca66e59

Please sign in to comment.