
Commit

add language dependent model
Xinjian Li committed Jun 14, 2023
1 parent 3a558a7 commit 0d64039
Showing 20 changed files with 383 additions and 194 deletions.
39 changes: 25 additions & 14 deletions allosaurus/am/model.py
@@ -6,6 +6,7 @@
import editdistance
from collections import defaultdict
import torch.nn as nn
import torch.cuda


def read_am(config_or_name, overwrite_config=None):
@@ -44,7 +45,13 @@ def __init__(self, model, criterion, config):
self.criterion = criterion

self.config = config
self.device_id = self.config.rank

if "rank" in self.config:
self.device_id = self.config.rank
elif torch.cuda.is_available():
self.device_id = 0
else:
self.device_id = -1

def cuda(self):
self.model = self.model.cuda()
@@ -217,6 +224,16 @@ def test_step(self, sample: dict):
output = result['output'].detach()
output_length = result['output_length']

logits = move_to_ndarray(output)
logits_length = move_to_ndarray(output_length)

assert len(logits) == len(logits_length)
logit_lst = []

for ii in range(len(logits)):
logit = logits[ii][:logits_length[ii]]
logit_lst.append(logit)

emit = 1.0
if 'emit' in sample:
emit = sample['emit']
@@ -230,13 +247,13 @@ def test_step(self, sample: dict):
topk = sample['topk']

if format == 'logit':
return output
return logit_lst
elif format == 'both':
decoded_info = self.decode(output, output_length, topk, emit)
decoded_info = self.decode(logit_lst, topk, emit)
return logit_lst, decoded_info

else:
decoded_info = self.decode(output, output_length, topk, emit)
decoded_info = self.decode(logit_lst, topk, emit)

return decoded_info

@@ -334,19 +351,13 @@ def eval_ter(self, output, output_length, sample):
# return decoded_tokens


def decode(self, output, output_length, topk=1, emit=1.0):

logits = move_to_ndarray(output)
logits_length = move_to_ndarray(output_length)

assert len(logits) == len(logits_length)
def decode(self, logit_lst, topk=1, emit=1.0):

decoded_lst = []

for ii in range(len(logits)):

logit = logits[ii][:logits_length[ii]]
for ii in range(len(logit_lst)):

logit = logit_lst[ii]
emit_frame_idx = []

cur_max_arg = -1
Expand Down Expand Up @@ -383,7 +394,7 @@ def decode(self, output, output_length, topk=1, emit=1.0):

start_timestamp = self.config.window_shift * idx
end_timestamp = self.config.window_shift * next_idx
duration = min(1.0, end_timestamp - start_timestamp)
duration = min(0.2, end_timestamp - start_timestamp)

info = {'start': self.config.window_shift * idx,
'duration': duration,
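Note on the revised test_step/decode split above: the padded batch output is now trimmed to each utterance's true length before decoding, so decode receives a list of per-utterance logit arrays instead of a batch tensor plus a length vector. A minimal sketch of that trimming, assuming move_to_ndarray has already produced numpy arrays of shape (B, T, V) and (B,); the function name here is illustrative, not part of the codebase:

    import numpy as np

    def trim_batch_logits(logits: np.ndarray, logits_length: np.ndarray):
        # Keep only the real frames of each utterance, dropping batch padding.
        assert len(logits) == len(logits_length)
        return [logits[ii][:logits_length[ii]] for ii in range(len(logits))]

    # e.g. logits of shape (2, 10, 40) with lengths [7, 10] yields
    # arrays of shape (7, 40) and (10, 40).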
99 changes: 99 additions & 0 deletions allosaurus/am/module/arch/ssl_finetune.py
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from allosaurus.am.module.frontend.ssl import read_ssl_frontend
from allosaurus.am.module.utils.register import register_arch
from allosaurus.am.module.encoder.transformer import TransformerEncoderLayer
from allosaurus.am.module.utils.phonetics import create_remap_matrix
from phonepiece.inventory import read_inventory


@register_arch
class SSLFinetune(nn.Module):

type_ = 'ssl_finetune'

def __init__(self, config):

super().__init__()

self.config = config

self.hidden_size = config.hidden_size

self.lang_output_size = dict()

self.phone_tensor = None

# prepare SSL frontend
self.frontend = read_ssl_frontend(config)

self.logsoftmax = nn.LogSoftmax(dim=2)

self.lang2linear = nn.ParameterDict()

self.default_lang = config.langs[0]

for lang in config.langs:
self.prep_language_layer(lang, self.config.device)

self.remap_layer = {}

def prep_language_layer(self, lang_id, device=None):

if str(lang_id) not in self.lang2linear:

inventory = read_inventory(lang_id)
num_phoneme = len(inventory.phoneme.elems) - 1

lang_id = str(lang_id)
print("preparing ", lang_id, ' pos_only ', self.config.pos_only, ' inventory ', str(inventory))
self.lang2linear[lang_id] = nn.Linear(1024, num_phoneme).to(device)

def prep_remap_language_layer(self, lang_id, device=None):

if lang_id in self.lang2linear or lang_id in self.remap_layer:
return

remap_matrix = create_remap_matrix(self.default_lang, lang_id).to(device)
print("create remapping matrix from {} to {}".format(self.default_lang, lang_id))
print("shape: ", remap_matrix.shape)
self.remap_layer[lang_id] = remap_matrix

def forward(self, input_tensor, input_lengths, meta=None):
"""
:param input_tensor: a Tensor with shape (B,T,H)
:param input_lengths: a list of lengths for input_tensor; if None, no padding is applied
:param meta: dictionary containing meta information (should contain lang_id in this case)
:return:
"""

#if utt_ids:
#print("utt_ids {} \n target_tensor {}".format(' '.join(utt_ids), target_tensor))
#print("input_lengths {}".format(str(input_lengths)))
#print("target_tensor {}".format(target_tensor))
#print("target_lengths {}".format(target_lengths))

lang_id = str(meta['lang_id'])

self.prep_remap_language_layer(lang_id, input_tensor.device)

feats, mask = self.frontend.forward(input_tensor, input_lengths)

output_length = torch.sum(mask, dim=1)

if lang_id in self.lang2linear:
predicted = self.lang2linear[lang_id](feats)
else:
predicted = self.lang2linear[self.default_lang](feats)
predicted = torch.matmul(predicted, self.remap_layer[lang_id])

output_tensor = self.logsoftmax(predicted)

result = {
'output': output_tensor,
'output_length': output_length,
}

# return (B,T,H) for gathering
return result
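For reference, the routing in forward() is: if the requested lang_id has its own head in lang2linear it is used directly; otherwise the default language's head is applied and its phoneme scores are projected into the target inventory through the precomputed remap matrix. A minimal, hypothetical sketch of that dispatch (shapes, the ModuleDict stand-in for lang2linear, the zero remap matrix and the lang id 'xxx' are all illustrative, not the real frontend or phonepiece output):

    import torch
    import torch.nn as nn

    # Hypothetical shapes: 1024-dim SSL features, 40 phonemes for the default
    # language and 35 for an unseen target language.
    feats = torch.randn(2, 50, 1024)                           # (B, T, H) frontend output
    lang2linear = nn.ModuleDict({'eng': nn.Linear(1024, 40)})  # per-language heads
    remap = torch.zeros(40, 35)                                # placeholder for create_remap_matrix output

    lang_id = 'xxx'                                            # language without its own head
    if lang_id in lang2linear:
        predicted = lang2linear[lang_id](feats)                # language-specific head
    else:
        predicted = lang2linear['eng'](feats)                  # default head ...
        predicted = torch.matmul(predicted, remap)             # ... remapped into the target inventory
    log_probs = torch.log_softmax(predicted, dim=2)            # matches the module's LogSoftmax(dim=2)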
13 changes: 13 additions & 0 deletions allosaurus/am/module/utils/phonetics.py
@@ -4,6 +4,7 @@
from torch.nn.utils.rnn import pad_sequence
from scipy.special import softmax
from allosaurus.utils.tensor import make_pad_mask
from phonepiece.inventory import read_inventory

def feature_tensor_to_embedding(padded_feature, padded_mask, feature_embed):

@@ -66,6 +67,18 @@ def create_feature_list(inventory, pos_only=True):
return feature_lst


def create_remap_matrix(origin_lang_id, target_lang_id):
origin_inventory = read_inventory(origin_lang_id)
target_inventory = read_inventory(target_lang_id)

ylst = target_inventory.phoneme.atoi(target_inventory.remap(origin_inventory.phoneme.elems[:-1]))
xlst = list(range(len(origin_inventory.phoneme)-1))

remap_matrix = np.zeros((len(origin_inventory.phoneme)-1, len(target_inventory.phoneme)-1), dtype=np.float32)
remap_matrix[xlst, ylst] = 1

return torch.from_numpy(remap_matrix)

def create_phone_embedding(inventory, feature_embed, pos_only=True):
"""
build phone embedding from feature embedding
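The remap matrix built by create_remap_matrix is a 0/1 matrix of shape (len(origin phonemes)-1, len(target phonemes)-1): row i has a single 1 in the column of the target phoneme that origin phoneme i remaps to, so multiplying origin-ordered scores by it moves each score into the target inventory's ordering. A toy illustration with made-up indices (no phonepiece inventories involved):

    import numpy as np

    # Suppose a 4-phoneme origin inventory whose phonemes map to target indices [2, 0, 1, 1].
    ylst = [2, 0, 1, 1]
    xlst = list(range(4))

    remap = np.zeros((4, 3), dtype=np.float32)
    remap[xlst, ylst] = 1

    origin_scores = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)  # scores in origin order
    target_scores = origin_scores @ remap
    print(target_scores)  # [0.2, 0.7, 0.1]: target phoneme 1 collects origin phonemes 2 and 3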
6 changes: 3 additions & 3 deletions allosaurus/app.py
@@ -67,14 +67,14 @@ def recognize(self, audio_path, lang_id, output=None, verbose=False, logit=False
sample['meta']['lang_id'] = lang_id
sample['meta']['format'] = 'both'

outputs, decoded_info_lst = self.am.test_step(sample)
logit_lst, decoded_info_lst = self.am.test_step(sample)

for utt_id, decoded_info in zip(sample['utt_ids'], decoded_info_lst):
decoded_info = self.lm.decode(decoded_info, lang_id)
utt_infos.append((utt_id, decoded_info))

for utt_id, utt_logit in zip(sample['utt_ids'], outputs):
utt_logits.append((utt_id, utt_logit.cpu().detach().numpy()))
for utt_id, utt_logit in zip(sample['utt_ids'], logit_lst):
utt_logits.append((utt_id, utt_logit))

utt_infos = self.merge_partial_info(utt_infos)
utt_logits = self.merge_partial_logit(utt_logits)
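With this change, recognize() consumes the already-trimmed per-utterance logit arrays from test_step instead of converting padded tensors itself. A hypothetical call, assuming the package's usual read_recognizer entry point is available on this branch; the audio path is a placeholder:

    from allosaurus.app import read_recognizer

    model = read_recognizer()
    decoded = model.recognize('sample.wav', 'eng')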
2 changes: 1 addition & 1 deletion allosaurus/audio.py
@@ -44,7 +44,7 @@ def read_audio(filename_or_audio, sample_rate=16000):
def read_audio_duration(filename):

if str(filename).endswith('.wav'):
with contextlib.closing(wave.open(filename, 'r')) as f:
with contextlib.closing(wave.open(str(filename), 'r')) as f:
frames = f.getnframes()
rate = f.getframerate()
duration = frames / float(rate)
23 changes: 23 additions & 0 deletions allosaurus/bin/lookup_state.py
@@ -0,0 +1,23 @@
import torch
import argparse


def print_model_info(model_path):
model = torch.load(model_path, map_location=torch.device('cpu'))

print(f"Keys in the model '{model_path}':")
for key in model.keys():
print(key)

print("\nTensor shapes in the model:")
for key, value in model.items():
if isinstance(value, torch.Tensor):
print(f"Key: {key}, Tensor shape: {value.shape}")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, help='Path to the model')
args = parser.parse_args()

print_model_info(args.model)
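The new lookup_state.py helper loads a checkpoint on CPU and prints its keys and tensor shapes, e.g. python -m allosaurus.bin.lookup_state -m <path-to-model>. It can also be used programmatically; the path below is a placeholder:

    from allosaurus.bin.lookup_state import print_model_info

    print_model_info('path/to/model.pt')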
10 changes: 3 additions & 7 deletions allosaurus/corpus.py
@@ -38,16 +38,12 @@ def read_corpus(corpus_path, utt_cnt=None, lang_id=None):

# load record if exists
record = None
if (corpus_path / 'record.txt').exists():
record = read_record(corpus_path / 'record.txt')
elif (corpus_path / 'segments').exists():
record = read_record(corpus_path / 'segments')
if (corpus_path / 'wav.scp').exists():
record = read_record(corpus_path / 'wav.scp')

# load text if exists
text = None
if (corpus_path / 'text.txt').exists():
text = read_text(corpus_path / 'text.txt', lang_id)
elif (corpus_path / 'text').exists():
if (corpus_path / 'text').exists():
text = read_text(corpus_path / 'text', lang_id)

assert (record is not None) or (text is not None), " both text and record are empty!"
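read_corpus now looks only for Kaldi-style 'wav.scp' and 'text' files instead of the earlier 'record.txt'/'segments' and 'text.txt' variants. A sketch of preparing such a corpus directory, assuming the usual one-line-per-utterance "utt_id value" layout (the exact format expected by read_record/read_text is not shown in this diff, and all paths and ids below are placeholders):

    from pathlib import Path

    corpus = Path('data/eng_cvtrain')
    corpus.mkdir(parents=True, exist_ok=True)

    # One utterance per line: id followed by the wav path or transcript.
    (corpus / 'wav.scp').write_text('utt001 /abs/path/utt001.wav\nutt002 /abs/path/utt002.wav\n')
    (corpus / 'text').write_text('utt001 hello world\nutt002 good morning\n')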
10 changes: 10 additions & 0 deletions allosaurus/data/config/am/xlsr_finetune_eng.yml
@@ -0,0 +1,10 @@
model: ssl_finetune
criterion: ctc
feat_size: 1024
hidden_size: 1024
ssl: xlsr
ssl_feature: hidden_states
langs: ['eng']
input_size: 40
window_size: 0.02
window_shift: 0.02
8 changes: 8 additions & 0 deletions allosaurus/data/config/am/xlsr_finetune_hidden_states.yml
@@ -0,0 +1,8 @@
model: ssl_finetune
criterion: ctc
feat_size: 1024
hidden_size: 1024
ssl: xlsr
ssl_feature: hidden_states
langs: ['eng']
input_size: 40
10 changes: 10 additions & 0 deletions allosaurus/data/config/am/xlsr_finetune_jpn.yml
@@ -0,0 +1,10 @@
model: ssl_finetune
criterion: ctc
feat_size: 1024
hidden_size: 1024
ssl: xlsr
ssl_feature: hidden_states
langs: ['jpn']
input_size: 40
window_size: 0.02
window_shift: 0.02
2 changes: 1 addition & 1 deletion allosaurus/data/config/pm/raw.yml
@@ -1,3 +1,3 @@
model: raw
sample_rate: 16000
max_feature_length: 160000
max_feature_length: 180000
14 changes: 14 additions & 0 deletions allosaurus/data/exp/23042002_finetune_hidden_states.yml
@@ -0,0 +1,14 @@
corpus_ids: [eng_cvtrain]
batch_size: 16
am: xlsr_finetune_hidden_states
lm: phone
pm: raw
model_path: /ocean/projects/cis210027p/xinjianl/asr2k/allosaurus/allosaurus/data/arch/23042002_finetune_hidden_states
gpu_size: 1
lr_rate: 0.0001
optimizer: adamw
epoch: 2000
grad_clip: 2.0
report_per_batch: 10
eval_per_epoch: False
eval_per_second: 3600
13 changes: 13 additions & 0 deletions allosaurus/data/exp/23060201_eng.yml
@@ -0,0 +1,13 @@
corpus_ids: [eng_cvtrain]
batch_size: 16
am: xlsr_finetune_eng
lm: phone
pm: raw
gpu_size: 1
lr_rate: 0.0001
optimizer: adamw
epoch: 2000
grad_clip: 2.0
report_per_batch: 10
eval_per_epoch: False
eval_per_second: 3600
13 changes: 13 additions & 0 deletions allosaurus/data/exp/23060202_cmn.yml
@@ -0,0 +1,13 @@
corpus_ids: [cmn_cvtrain]
batch_size: 16
am: xlsr_finetune_hidden_states
lm: phone
pm: raw
gpu_size: 1
lr_rate: 0.0001
optimizer: adamw
epoch: 2000
grad_clip: 2.0
report_per_batch: 10
eval_per_epoch: False
eval_per_second: 3600
13 changes: 13 additions & 0 deletions allosaurus/data/exp/23060203_jpn.yml
@@ -0,0 +1,13 @@
corpus_ids: [jpn_cvtrain]
batch_size: 16
am: xlsr_finetune_jpn
lm: phone
pm: raw
gpu_size: 1
lr_rate: 0.0001
optimizer: adamw
epoch: 2000
grad_clip: 2.0
report_per_batch: 10
eval_per_epoch: False
eval_per_second: 3600
