diff --git a/config/diva_flash.yaml b/config/diva_flash.yaml index ae51eb321..7aa49b1b4 100644 --- a/config/diva_flash.yaml +++ b/config/diva_flash.yaml @@ -31,6 +31,8 @@ optimizer: #learning_rate: 5E-5 learning_rate: 5e-4 weight_decay: 0.1 + weight_decay_modules: None + default_weight_decay_mask: False warmup: 0.01 hf_save_path: gs://diva-flash/librispeech-hf-checkpoints diva_training: true diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py index 86cc04cb1..9a8b481c1 100644 --- a/src/levanter/main/train_asr.py +++ b/src/levanter/main/train_asr.py @@ -143,12 +143,12 @@ def compute_loss( model_init=lambda: config.model.build_asr(Vocab, key=model_key), ) + if config.diva_training and config.model.asr_model_type == DivaASRModel: + state = dataclasses.replace(state, model=None) + model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True) + model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) + state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model)) if int(state.step) == 0: - if config.diva_training and config.model.asr_model_type == DivaASRModel: - state = dataclasses.replace(state, model=None) - model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True) - model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) - state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model)) # TODO: I don't love that we init the model twice, but it's not a big deal i think? if config.initialize_from_hf: # initialize from an hf pretrained model diff --git a/src/levanter/models/diva.py b/src/levanter/models/diva.py index 08ac6c518..49bea6589 100644 --- a/src/levanter/models/diva.py +++ b/src/levanter/models/diva.py @@ -103,6 +103,7 @@ class DivaConfig(HFCompatConfig, ASRConfig): reference_encoder: str = "openai/whisper-large-v3-turbo" reference_decoder: str = "meta-llama/Llama-3.1-8B-Instruct" reference_checkpoint: str = "WillHeld/DiVA-llama-3-v0-8b" + max_length: int = 448 init_from_submodel: bool = True # Connector Config @@ -118,7 +119,7 @@ class DivaConfig(HFCompatConfig, ASRConfig): ) prefix = property(lambda self: hax.named(self.pre_audio_prompt, axis="position")) suffix = property(lambda self: hax.named(self.pre_text_prompt, axis="position")) - Pos = property(lambda self: Axis(name="position", size=448)) + Pos = property(lambda self: Axis(name="position", size=self.max_length)) AudioPos = property(lambda self: self.enc_config.AudioPos) KeyPos = property(lambda self: self.Pos.alias("key_position")) @@ -138,6 +139,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[dict] = None) "vocab_size": vocab_size, "reference_encoder": self.reference_encoder, "reference_decoder": self.reference_decoder, + "max_length": self.max_length, } return HfConfig.from_dict(merged_config) @@ -146,7 +148,7 @@ def from_hf_config(cls, hf_config: HfConfig): config_dict = hf_config.to_dict() reference_encoder = config_dict["encoder_reference"] reference_decoder = config_dict["decoder_reference"] - return DivaConfig(reference_encoder, reference_decoder) + return DivaConfig(reference_encoder, reference_decoder, max_length=config_dict["max_length"]) def hf_checkpoint_converter(cls) -> HFCheckpointConverter["DivaModel"]: # type: ignore return DivaHFCheckpointer(cls, cls.reference_checkpoint, trust_remote_code=True) @@ -198,7 +200,7 @@ def init( query_tokens = hax.random.normal(k_query, (config.Pos, config.enc_config.Embed)) * 0.02 projection = hnn.Linear.init( - In=config.enc_config.Embed.alias("whisp_embed"), Out=config.dec_config.Embed, key=key + In=config.enc_config.Embed.alias("whisp_embed"), Out=config.dec_config.Embed, init_scale=0.01, key=key ) if init_from_submodels: @@ -217,8 +219,15 @@ def init( config.enc_config, ) # type: ignore[assignment] encoder = whisper.encoder - connector = whisper.decoder + # connector = whisper.decoder + connector = WhisperDecoder.init(config.enc_config, key=k_connector) decoder = llm + mean_embedding = hax.mean(llm.embeddings.token_embeddings.weight, llm.embeddings.Vocab) + projection = dataclasses.replace( + projection, + weight=hax.rearrange(mean_embedding.broadcast_axis(projection.In), (projection.Out, projection.In)), + ) + else: encoder = WhisperEncoder.init(config.enc_config, key=k_enc) connector = WhisperDecoder.init(config.enc_config, key=k_connector) @@ -235,7 +244,7 @@ def __call__( mel: NamedArray, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, - pad_token_id: int = 128002, + pad_token_id: int = 128255, *, key=None, ) -> NamedArray: @@ -273,7 +282,6 @@ def __call__( suffix.broadcast_axis(OtherAxes), ], ) - text_tokens = hax.concatenate( "position", [ @@ -283,7 +291,6 @@ def __call__( ], ) push_back_padding = hax.argsort(text_tokens == pad_token_id, "position") - text_tokens_left_pad = text_tokens[{"batch": hax.arange(Batch), "position": push_back_padding}] text_embeds = self.decoder.embeddings.embed(text_tokens_left_pad) @@ -292,7 +299,7 @@ def __call__( text = self.decoder.transformer(text_embeds, attn_mask=causal_mask, key=k_decoder) push_forward_padding = hax.argsort(input_ids != pad_token_id, "position") - input_ids_right_pad = text_tokens[{"batch": hax.arange(Batch), "position": push_forward_padding}] + input_ids_right_pad = input_ids[{"batch": hax.arange(Batch), "position": push_forward_padding}] return ( audio["position", -1], text[ @@ -332,11 +339,12 @@ def compute_loss( # Mask Final Tokens So That Initial Tokens can be used for extra computation reversed_loss_mask = corrected_loss_mask["position", -1::-1] diff_contrast = virtual_embeds - real_embeds - loss2 = hax.dot(diff_contrast, diff_contrast, axis="embed") ** 0.5 + tal_loss = hax.dot(diff_contrast, diff_contrast, axis="embed") ** 0.5 if reduction is None: return kl_proxy_loss else: - return reduction(kl_proxy_loss, axis=reduction_axis) + reduction( - loss2, axis=reduction_axis, where=reversed_loss_mask - ) + loss1 = reduction(kl_proxy_loss, axis=reduction_axis) + loss2 = reduction(tal_loss, axis=reduction_axis, where=reversed_loss_mask) + loss = loss1 + loss2 + return loss