(Under review). After the review period, we will open-source the code on our GitHub.
- Calculate importance scores for all tokens: npy_tome_layer/models/hubert_transformer_encoder.py, line 1190
- Calculate the number of remaining tokens: npy_tome_layer/models/hubert_transformer_encoder.py, line 1192
- Calculate the indices of the remaining tokens: npy_tome_layer/models/hubert_transformer_encoder.py, lines 1196-1198
- Use the torch.gather() function to extract all remaining tokens: npy_tome_layer/models/hubert_transformer_encoder.py, line 1199
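```python
import math
import torch

# Excerpt from the encoder's layer loop (i, self, z and padding_mask come from that loop)
# x: (sequence_length, batch, features_embedding)
# attn: (batch, sequence_length, sequence_length)
if i >= self.pruning_init_layer:
    N, b, c = x.size()
    attn = z  # attention weights returned by the current layer
    # calculate the importance scores: total attention each token receives from all queries
    pruning_scores = attn.view(b, N, N).sum(dim=1)  # [B, N]
    # calculate the number of remaining tokens
    left_tokens = math.ceil((1 - self.pruning_rate) * N)  # N = token_num
    x = x.transpose(0, 1)  # [B, N, C]
    # select the remaining tokens: top-k by importance score ...
    _, idx = torch.topk(pruning_scores, left_tokens, dim=1, largest=True, sorted=True)  # [B, left_tokens]
    # ... then sort the selected indices ascending to keep the original temporal order
    true_idx, _ = torch.topk(idx, left_tokens, dim=1, largest=False, sorted=True)  # [B, left_tokens]
    index = true_idx.unsqueeze(-1).expand(-1, -1, c)  # [B, left_tokens, C]
    x = torch.gather(x, dim=1, index=index)  # [B, left_tokens, C]
    x = x.transpose(0, 1)  # [left_tokens, B, C]
    # MASK alignment: keep the padding mask consistent with the remaining tokens
    padding_mask = torch.gather(padding_mask, dim=1, index=true_idx)  # [B, left_tokens]
```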
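The excerpt above depends on the surrounding encoder loop (`i`, `self`, `z`). The following is a minimal standalone sketch of the same steps, assuming head-averaged attention weights laid out as (batch, query, key) and a boolean padding mask; the function name `prune_tokens` is hypothetical and not part of the repository. It restores temporal order with `torch.sort`, which gives the same result as an ascending `torch.topk` over all selected indices.

```python
import math

import torch


def prune_tokens(x, attn, padding_mask, pruning_rate):
    """Rate-based attention pruning (hypothetical standalone helper).

    x:            (N, B, C) hidden states in fairseq's time-first layout
    attn:         (B, N, N) attention weights, rows = queries, columns = keys
    padding_mask: (B, N) bool mask (True = padding) or None
    """
    n, b, c = x.size()
    # Importance of each token = total attention it receives from all queries
    scores = attn.reshape(b, n, n).sum(dim=1)  # (B, N)
    keep = math.ceil((1 - pruning_rate) * n)
    # Top-`keep` tokens by score, then re-sorted to restore temporal order
    _, idx = torch.topk(scores, keep, dim=1, largest=True)
    idx, _ = torch.sort(idx, dim=1)  # (B, keep)
    x = x.transpose(0, 1)  # (B, N, C)
    x = torch.gather(x, dim=1, index=idx.unsqueeze(-1).expand(-1, -1, c))  # (B, keep, C)
    x = x.transpose(0, 1)  # (keep, B, C)
    if padding_mask is not None:
        padding_mask = torch.gather(padding_mask, dim=1, index=idx)  # (B, keep)
    return x, padding_mask, idx


# Toy usage: 10 frames, batch of 2, 4 features, prune half of the tokens
torch.manual_seed(0)
x = torch.randn(10, 2, 4)
attn = torch.softmax(torch.randn(2, 10, 10), dim=-1)
padding_mask = torch.zeros(2, 10, dtype=torch.bool)
x_pruned, mask_pruned, kept_idx = prune_tokens(x, attn, padding_mask, pruning_rate=0.5)
print(x_pruned.shape)  # torch.Size([5, 2, 4])
```

Because the kept count depends only on the padded length `N`, every sequence in the batch keeps the same number of tokens and the output stays a regular tensor.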
The table presents the numerical values behind the visualization in Fig. 3 of the paper. Many attention scores in the table are 0, indicating that some tokens have no influence on each other. Furthermore, token merging incurs additional computation time; therefore, our method uses token pruning rather than token merging.
Using a score threshold would prune a different number of tokens from each input sequence, so the pruned sequences would have different lengths and the model could no longer process them as a batch. Alternatively, masking the pruned tokens with a MASK matrix would keep the batch shape, but the masked tokens would still pass through every subsequent layer, which defeats the purpose of model acceleration.
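To make the batching argument concrete, here is a minimal sketch with two hypothetical score vectors; the scores, threshold, and rate are illustrative only. A threshold keeps a different number of tokens per sequence, so the results cannot be stacked into one tensor; a mask keeps the full length; a pruning rate keeps the same count for every sequence.

```python
import math

import torch

# Hypothetical importance scores for a batch of two sequences (padded length 4)
scores = torch.tensor([[0.9, 0.8, 0.7, 0.1],
                       [0.9, 0.2, 0.1, 0.05]])

# Threshold pruning: the number of surviving tokens differs per sequence
threshold = 0.5
kept_per_row = (scores > threshold).sum(dim=1).tolist()  # [3, 1]
ragged = [row[row > threshold] for row in scores]
try:
    torch.stack(ragged)  # fails: rows have different lengths
    stack_failed = False
except RuntimeError:
    stack_failed = True

# Masking instead keeps the full length 4, so no computation is saved
mask = scores <= threshold  # shape (2, 4)

# Rate pruning: every sequence keeps ceil((1 - rate) * N) tokens
pruning_rate = 0.5
keep = math.ceil((1 - pruning_rate) * scores.size(1))  # 2
idx = torch.topk(scores, keep, dim=1).indices  # shape (2, 2), batchable
```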
The patch length varies little across image inputs. For example, most images fed into pre-trained vision models are 224 × 224, which with 16 × 16 patches gives (224 / 16) × (224 / 16) = 196 patch tokens for every image. Therefore, in computer vision, a fixed number of tokens can be pruned progressively.
However, the token length of speech inputs varies significantly, and speech sequences are much longer than the corresponding text sequences. Therefore, we use a pruning rate, so that longer speech inputs have more tokens pruned and shorter ones have fewer.
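The following sketch contrasts fixed-count and rate-based pruning. Assumptions: a ViT-style 224 × 224 image split into 16 × 16 patches, HuBERT producing roughly 50 frames per second of 16 kHz audio (20 ms frame stride), a rate of 0.5, and an arbitrary fixed budget of 150 pruned tokens. All numbers are illustrative, not the paper's settings, and the helper names are hypothetical.

```python
import math

# Image side (assumed ViT-style input): 224x224, 16x16 patches, channels packed per patch
PATCH_TOKENS = (224 // 16) * (224 // 16)  # 196 tokens for every image


def kept_with_rate(n, pruning_rate):
    """Rate-based pruning, as in the encoder: keep ceil((1 - rate) * n) tokens."""
    return math.ceil((1 - pruning_rate) * n)


def kept_with_fixed_count(n, num_pruned):
    """Fixed-count pruning (CV style): remove the same number of tokens from every input."""
    return max(n - num_pruned, 0)


# Speech side (assumption): about 50 HuBERT frames per second of audio
results = {}
for seconds in (2, 10, 20):
    n = 50 * seconds
    results[seconds] = (n, kept_with_rate(n, 0.5), kept_with_fixed_count(n, 150))
    print(f"{seconds:>2} s: {n:>4} frames -> rate keeps {results[seconds][1]}, "
          f"fixed keeps {results[seconds][2]}")
```

With the fixed budget, a 2 s clip (100 frames) loses every token while a 20 s clip (1000 frames) keeps 850; the rate keeps the same proportion regardless of length.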
Visualization of the input audio; the corresponding audio file is audio/ted_1096_7.wav. The red boxes in this figure show that the speech input contains a lot of redundant information.
Loss curves during retraining. "Conv" denotes the use of the convolution module. Both pruning and convolution are applied at the sixth layer of the speech pre-trained model.
Create a conda environment with PyTorch and install fairseq:
```bash
conda create --name pruning python=3.9
conda activate pruning
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
python setup.py build develop
# If you encounter either of the following errors, reinstall the listed package version:
# numpy error: `np.float` was removed in NumPy 1.24
pip install numpy==1.23.5
# generation error: sacrebleu import error (TOKENIZER)
pip install sacrebleu==1.5.1
```
This repository is built on the fairseq codebase. For basic usage of fairseq, please refer to the fairseq documentation.
- pandas==2.0.3
- sacrebleu==1.5.1
- scikit-learn==1.3.0
- scipy==1.11.1
- sentencepiece==0.1.99
- tensorboard==2.14.0
- torch==2.0.1
- torchaudio==2.0.2
- tqdm==4.65.0
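To confirm that an environment matches the pinned versions, a small check with the standard library's `importlib.metadata` can help. This is a sketch, not a script shipped with the repository; the name `check_versions` is hypothetical, and local build suffixes such as `+cu118` in PyTorch versions are ignored when comparing.

```python
from importlib import metadata

# Versions pinned in this README (numpy comes from the troubleshooting note in the install steps)
PINNED = {
    "numpy": "1.23.5",
    "pandas": "2.0.3",
    "sacrebleu": "1.5.1",
    "scikit-learn": "1.3.0",
    "scipy": "1.11.1",
    "sentencepiece": "0.1.99",
    "tensorboard": "2.14.0",
    "torch": "2.0.1",
    "torchaudio": "2.0.2",
    "tqdm": "4.65.0",
}


def check_versions(pinned, get_version=metadata.version):
    """Return {package: (wanted, found)} for every package that does not match."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        # Drop local version labels such as "+cu118" before comparing
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in check_versions(PINNED).items():
        print(f"{name}: expected {wanted}, found {found or 'not installed'}")
```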
Please download the MuST-C v1 dataset.
Note: The original dataset website appears to hide the download link. However, the dataset can still be downloaded by filling out the dataset request form directly, so we recommend this method.
Make directories to store ST (MuST-C) and datasets. Please specify the target language.
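As a hedged sketch, following the layout used by fairseq's speech-to-text MuST-C example (one `en-<target>` folder per language pair under a data root), the helper below validates the target language against the eight MuST-C v1 targets and creates the folder. The root path and the function name `make_mustc_dir` are hypothetical; adapt them to your own directory layout.

```python
from pathlib import Path

# The eight target languages of MuST-C v1 (English -> X)
MUSTC_V1_TARGETS = ("de", "es", "fr", "it", "nl", "pt", "ro", "ru")


def make_mustc_dir(root, target):
    """Create <root>/en-<target> (hypothetical helper) and return its path."""
    if target not in MUSTC_V1_TARGETS:
        raise ValueError(f"unsupported MuST-C v1 target language: {target!r}")
    pair_dir = Path(root) / f"en-{target}"
    pair_dir.mkdir(parents=True, exist_ok=True)  # idempotent: safe to call twice
    return pair_dir
```

For example, `make_mustc_dir("/path/to/mustc", "de")` would create `/path/to/mustc/en-de`.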
Preprocess the SentencePiece (spm) data.
We use HuBERT as the speech pre-trained model. Before training, please download the HuBERT-Base model.
```bash
# TEXT_DIR=/workspace/s2t/deltalm_data/en-$target/binary
fairseq-train $data_dir --text-data $TEXT_DIR --tgt-lang $target \
--user-dir $USER_DIR \
--config-yaml config_hyper.yaml --train-subset train --valid-subset dev \
--save-dir $SAVE_DIR --num-workers 4 --max-tokens 3000000 --batch-size 32 --max-tokens-text 8192 \
--task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \
--arch hubert_transformer_pruning_layer --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--ddp-backend=legacy_ddp \
--warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \
--layernorm-embedding \
--max-epoch 45 \
--fp16 \
--st-training --mt-finetune \
--hubert-model-path $HU_BERT \
--load-pretrained-mt-encoder-decoder-from $MT_PRETRAINED_MODEL --tensorboard-logdir $SAVE_DIR
```
```bash
# key args: --pruning-rate, --pruning-init-layer
# TEXT_DIR=/workspace/s2t/deltalm_data/en-$target/binary
fairseq-train $data_dir --text-data $TEXT_DIR --tgt-lang $target \
--user-dir $USER_DIR \
--config-yaml config_hyper.yaml --train-subset train --valid-subset dev \
--save-dir $SAVE_DIR --num-workers 4 --max-tokens 3000000 --batch-size 32 --max-tokens-text 8192 \
--task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \
--arch hubert_transformer_pruning_layer --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--ddp-backend=legacy_ddp \
--warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \
--layernorm-embedding \
--max-epoch 45 \
--fp16 \
--st-training --mt-finetune \
--hubert-model-path $HU_BERT \
--load-pretrained-mt-encoder-decoder-from $MT_PRETRAINED_MODEL --tensorboard-logdir $SAVE_DIR --pruning-rate 0.999 --pruning-init-layer 4
```
```bash
# key args: --pruning-max-rate, --pruning-min-rate, --pruning-init-layer, --arch
# TEXT_DIR=/workspace/s2t/deltalm_data/en-$target/binary
fairseq-train $data_dir --text-data $TEXT_DIR --tgt-lang $target \
--user-dir $USER_DIR \
--config-yaml config_hyper.yaml --train-subset train --valid-subset dev \
--save-dir $SAVE_DIR --num-workers 4 --max-tokens 3000000 --batch-size 32 --max-tokens-text 8192 \
--task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \
--arch hubert_transformer_pruning_layer_schedule --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--ddp-backend=legacy_ddp \
--warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \
--layernorm-embedding \
--max-epoch 45 \
--fp16 \
--st-training --mt-finetune \
--hubert-model-path $HU_BERT \
--load-pretrained-mt-encoder-decoder-from $MT_PRETRAINED_MODEL --tensorboard-logdir $SAVE_DIR --pruning-rate 0.999 --pruning-init-layer 4
```
Our code builds on the HuBERT implementation. Thanks for their great contributions!