train.py (forked from buriburisuri/speech-to-text-wavenet)
import sugartensor as tf
from data import SpeechCorpus, voca_size
from model import *
__author__ = '[email protected]'
# set log level to debug
tf.sg_verbosity(10)
#
# hyper parameters
#
batch_size = 16 # total batch size
#
# inputs
#
# corpus input tensor
data = SpeechCorpus(batch_size=batch_size * tf.sg_gpus())
# mfcc feature of audio
inputs = tf.split(data.mfcc, tf.sg_gpus(), axis=0)
# target sentence label
labels = tf.split(data.label, tf.sg_gpus(), axis=0)
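# note: the corpus queue yields batch_size * tf.sg_gpus() examples per step,
# so each tf.split shard above carries batch_size examples for one GPU tower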
# sequence length except zero-padding
seq_len = []
for input_ in inputs:
    seq_len.append(tf.not_equal(input_.sg_sum(axis=2), 0.).sg_int().sg_sum(axis=1))
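# MFCC frames that are entirely zero are padding, so summing the feature axis
# and counting the non-zero frames recovers each utterance's true length,
# which the CTC loss below requires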
# parallel loss tower
@tf.sg_parallel
def get_loss(opt):
    # encode audio feature
    logit = get_logit(opt.input[opt.gpu_index], voca_size=voca_size)
    # CTC loss
    return logit.sg_ctc(target=opt.target[opt.gpu_index], seq_len=opt.seq_len[opt.gpu_index])
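# @tf.sg_parallel builds one copy of this loss (a tower) per visible GPU;
# opt.gpu_index picks that tower's shard of the split input/target tensors,
# and tf.sg_train below combines the per-tower losses into a single update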
#
# train
#
tf.sg_train(lr=0.0001, loss=get_loss(input=inputs, target=labels, seq_len=seq_len),
            ep_size=data.num_batch, max_ep=50)
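# tf.sg_train drives sugartensor's training loop: ep_size steps per epoch for
# max_ep epochs; by default, checkpoints and TensorBoard summaries are written
# under the asset/train directory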