-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Xinjian Li
committed
Jun 14, 2023
1 parent
3a558a7
commit 0d64039
Showing
20 changed files
with
383 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
import torch.nn as nn | ||
from allosaurus.am.module.frontend.ssl import read_ssl_frontend | ||
from allosaurus.am.module.utils.register import register_arch | ||
from allosaurus.am.module.encoder.transformer import TransformerEncoderLayer | ||
from allosaurus.am.module.utils.phonetics import create_remap_matrix | ||
from phonepiece.inventory import read_inventory | ||
|
||
|
||
@register_arch
class SSLFinetune(nn.Module):
    """Fine-tuning model: an SSL frontend followed by one linear output layer per language.

    Languages listed in ``config.langs`` each get their own ``nn.Linear`` head.
    Any other language seen in ``forward`` reuses the head of the default
    language (``config.langs[0]``) and multiplies its output by a remapping
    matrix built by ``create_remap_matrix``.
    """

    type_ = 'ssl_finetune'

    def __init__(self, config):
        """
        :param config: architecture configuration. This class reads ``hidden_size``,
            ``langs``, ``device`` and ``pos_only`` from it and passes the whole
            object to ``read_ssl_frontend``.
        """
        super().__init__()

        self.config = config
        self.hidden_size = config.hidden_size
        self.lang_output_size = dict()
        self.phone_tensor = None

        # prepare SSL frontend
        self.frontend = read_ssl_frontend(config)

        self.logsoftmax = nn.LogSoftmax(dim=2)

        # The entries are nn.Linear modules, so they belong in a ModuleDict.
        # ParameterDict is meant for Parameters; ModuleDict registers each head
        # as a submodule (parameters(), state_dict(), .to()) and keeps the
        # same "lang2linear.<lang>.<param>" key layout.
        self.lang2linear = nn.ModuleDict()

        self.default_lang = config.langs[0]

        # Cache of remapping matrices for languages without their own head,
        # keyed by language id. Plain dict: these are not trained or saved.
        self.remap_layer = {}

        for lang in config.langs:
            self.prep_language_layer(lang, self.config.device)

    def prep_language_layer(self, lang_id, device=None):
        """Create the linear output head for ``lang_id`` unless it already exists.

        :param lang_id: language id; passed to ``read_inventory`` as given and
            stored in ``lang2linear`` under ``str(lang_id)``
        :param device: device to place the new layer on; ``None`` leaves it on
            the default device
        """
        lang_key = str(lang_id)
        if lang_key in self.lang2linear:
            return

        inventory = read_inventory(lang_id)
        # NOTE(review): one inventory element is excluded from the output size;
        # presumably a special/blank symbol -- confirm against phonepiece.
        num_phoneme = len(inventory.phoneme.elems) - 1

        print("preparing ", lang_key, ' pos_only ', self.config.pos_only, ' inventory ', str(inventory))

        # NOTE(review): the input width is hard-coded to 1024 and does not
        # follow config.hidden_size -- confirm whether the two may diverge.
        layer = nn.Linear(1024, num_phoneme)
        if device is not None:
            layer = layer.to(device)
        self.lang2linear[lang_key] = layer

    def prep_remap_language_layer(self, lang_id, device=None):
        """Build and cache the remapping matrix from the default language to ``lang_id``.

        Does nothing when ``lang_id`` has its own linear head or a matrix has
        already been cached for it.

        :param lang_id: target language id
        :param device: device the remapping matrix is moved to
        """
        if lang_id in self.lang2linear or lang_id in self.remap_layer:
            return

        remap_matrix = create_remap_matrix(self.default_lang, lang_id).to(device)
        print("create remapping matrix from {} to {}".format(self.default_lang, lang_id))
        print("shape: ", remap_matrix.shape)
        self.remap_layer[lang_id] = remap_matrix

    def forward(self, input_tensor, input_lengths, meta=None):
        """Run the frontend and the language-specific output head.

        :param input_tensor: batch of inputs, handed unchanged to the SSL frontend
        :param input_lengths: lengths of the entries of ``input_tensor``, handed
            unchanged to the SSL frontend
        :param meta: dictionary of meta information; ``meta['lang_id']`` selects
            the language. When ``meta`` is None the default language is used.
        :return: dict with ``'output'`` (log-softmax over dim 2 of the head's
            predictions) and ``'output_length'`` (sum of the frontend mask over dim 1)
        """
        # Heads are stored under string keys, so normalise before any lookup.
        default_key = str(self.default_lang)
        lang_id = default_key if meta is None else str(meta['lang_id'])

        self.prep_remap_language_layer(lang_id, input_tensor.device)

        feats, mask = self.frontend.forward(input_tensor, input_lengths)
        output_length = torch.sum(mask, dim=1)

        if lang_id in self.lang2linear:
            predicted = self.lang2linear[lang_id](feats)
        else:
            # No dedicated head: predict with the default language's head and
            # project the result through the cached remapping matrix.
            predicted = self.lang2linear[default_key](feats)
            predicted = torch.matmul(predicted, self.remap_layer[lang_id])

        output_tensor = self.logsoftmax(predicted)

        # return (B,T,H) for gathering
        return {
            'output': output_tensor,
            'output_length': output_length,
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import argparse | ||
|
||
|
||
def print_model_info(model_path):
    """Print the keys of a saved checkpoint, then the shape of every tensor entry.

    The checkpoint is loaded onto the CPU with ``torch.load`` and is expected
    to expose ``keys()`` and ``items()`` (e.g. a state dict). Only top-level
    values that are tensors get a shape line.

    :param model_path: path of the file to load
    """
    cpu = torch.device('cpu')
    checkpoint = torch.load(model_path, map_location=cpu)

    print(f"Keys in the model '{model_path}':")
    for name in checkpoint.keys():
        print(name)

    print("\nTensor shapes in the model:")
    for name, entry in checkpoint.items():
        # Non-tensor entries (counters, nested dicts, ...) have no shape line.
        if not isinstance(entry, torch.Tensor):
            continue
        print(f"Key: {name}, Tensor shape: {entry.shape}")
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: print the keys and tensor shapes of one model file.
    parser = argparse.ArgumentParser()
    # --model is mandatory: without it args.model would be None and the load
    # inside print_model_info would fail with an unrelated-looking error.
    parser.add_argument('-m', '--model', type=str, required=True, help='Path to the model')
    args = parser.parse_args()

    print_model_info(args.model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
model: ssl_finetune | ||
criterion: ctc | ||
feat_size: 1024 | ||
hidden_size: 1024 | ||
ssl: xlsr | ||
ssl_feature: hidden_states | ||
langs: ['eng'] | ||
input_size: 40 | ||
window_size: 0.02 | ||
window_shift: 0.02 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
model: ssl_finetune | ||
criterion: ctc | ||
feat_size: 1024 | ||
hidden_size: 1024 | ||
ssl: xlsr | ||
ssl_feature: hidden_states | ||
langs: ['eng'] | ||
input_size: 40 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
model: ssl_finetune | ||
criterion: ctc | ||
feat_size: 1024 | ||
hidden_size: 1024 | ||
ssl: xlsr | ||
ssl_feature: hidden_states | ||
langs: ['jpn'] | ||
input_size: 40 | ||
window_size: 0.02 | ||
window_shift: 0.02 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
model: raw | ||
sample_rate: 16000 | ||
- max_feature_length: 160000 | ||
+ max_feature_length: 180000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
corpus_ids: [eng_cvtrain] | ||
batch_size: 16 | ||
am: xlsr_finetune_hidden_states | ||
lm: phone | ||
pm: raw | ||
model_path: /ocean/projects/cis210027p/xinjianl/asr2k/allosaurus/allosaurus/data/arch/23042002_finetune_hidden_states | ||
gpu_size: 1 | ||
lr_rate: 0.0001 | ||
optimizer: adamw | ||
epoch: 2000 | ||
grad_clip: 2.0 | ||
report_per_batch: 10 | ||
eval_per_epoch: False | ||
eval_per_second: 3600 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
corpus_ids: [eng_cvtrain] | ||
batch_size: 16 | ||
am: xlsr_finetune_eng | ||
lm: phone | ||
pm: raw | ||
gpu_size: 1 | ||
lr_rate: 0.0001 | ||
optimizer: adamw | ||
epoch: 2000 | ||
grad_clip: 2.0 | ||
report_per_batch: 10 | ||
eval_per_epoch: False | ||
eval_per_second: 3600 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
corpus_ids: [cmn_cvtrain] | ||
batch_size: 16 | ||
am: xlsr_finetune_hidden_states | ||
lm: phone | ||
pm: raw | ||
gpu_size: 1 | ||
lr_rate: 0.0001 | ||
optimizer: adamw | ||
epoch: 2000 | ||
grad_clip: 2.0 | ||
report_per_batch: 10 | ||
eval_per_epoch: False | ||
eval_per_second: 3600 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
corpus_ids: [jpn_cvtrain] | ||
batch_size: 16 | ||
am: xlsr_finetune_jpn | ||
lm: phone | ||
pm: raw | ||
gpu_size: 1 | ||
lr_rate: 0.0001 | ||
optimizer: adamw | ||
epoch: 2000 | ||
grad_clip: 2.0 | ||
report_per_batch: 10 | ||
eval_per_epoch: False | ||
eval_per_second: 3600 |
Oops, something went wrong.