# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts WAV audio files into input features for neural networks.
The models used in this example take in two-dimensional spectrograms as the
input to their neural network portions. For testing and porting purposes it's
useful to be able to generate these spectrograms outside of the full model, so
that on-device implementations using their own FFT and streaming code can be
tested against the version used in training for example. The output is as a
C source file, so it can be easily linked into an embedded test application.
To use this, run:
bazel run tensorflow/examples/speech_commands:wav_to_features -- \
--input_wav=my.wav --output_c_file=my_wav_data.c
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to quantize the output features for eight-bit deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """
  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()
  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
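  # With the default flags (16 kHz sample rate, 1000 ms clips, 30 ms windows,
  # 10 ms strides, 40 feature bins) and "mfcc" preprocessing, the usual
  # speech_commands settings arithmetic gives a feature map of
  # 1 + (16000 - 480) // 160 = 98 timeslices of 40 values each, i.e. 3920
  # feature values per clip.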
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)
  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]
  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]
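  # For example, input_wav='/tmp/My.WAV' yields variable_base='my', so the
  # generated symbols are named g_my_width, g_my_height, and g_my_data.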
  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
    if quantize:
      features_min, features_max = input_data.get_features_range(
          model_settings)
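      # The loop below maps each float feature linearly onto 0..255 and then
      # clamps, so features_min -> 0, features_max -> 255, and a value halfway
      # between them -> round(255 * 0.5) = 128.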
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round(
                (255 * (value - features_min)) / (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {\n' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')
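

# A hypothetical embedded test that links the generated file can pick up the
# arrays through matching C declarations, e.g. (using the "my" basename from
# the sketch above):
#
#   extern const int g_my_width;
#   extern const int g_my_height;
#   extern const float g_my_data[];  /* unsigned char when --quantize is set */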


def main(_):
  # We want to see all the logging messages.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms,
                  FLAGS.window_size_ms, FLAGS.window_stride_ms,
                  FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess,
                  FLAGS.input_wav, FLAGS.output_c_file)
  tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  # Note: argparse's type=bool treats any non-empty string (including
  # "False") as True, so expose --quantize as a store_true flag instead.
  parser.add_argument(
      '--quantize',
      action='store_true',
      default=False,
      help='Whether to quantize the output features for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--input_wav',
      type=str,
      default=None,
      help='Path to the audio WAV file to read')
  parser.add_argument(
      '--output_c_file',
      type=str,
      default=None,
      help='Where to save the generated C source file containing the features')
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
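
# When run outside Bazel, the script can also be invoked directly, assuming
# input_data.py and models.py are importable from the working directory:
#
#   python wav_to_features.py --input_wav=my.wav --output_c_file=my_wav_data.c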