84 lines
3.3 KiB
Python
Executable File
84 lines
3.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# This script trains the network, selectively choosing
# segments from data/random that cause an activation. These
# segments are saved as new wav files under data/not-keyword
# (by default data/not-keyword/generated) and the network is retrained
|
|
|
|
import sys
|
|
sys.path += ['.'] # noqa
|
|
|
|
import argparse
|
|
from glob import glob
|
|
from os import makedirs
|
|
from os.path import join, basename, splitext
|
|
|
|
import wavio
|
|
|
|
from precise.common import *
|
|
|
|
|
|
def _str_to_bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` is unusable with argparse: ``bool(s)`` is True for every
    non-empty string, so ``-s False`` silently meant True. This converter
    accepts the common spellings explicitly and rejects anything else.

    Args:
        value: The raw string given on the command line.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching what bool('') produced before.
    if normalized in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def _parse_args():
    """Build the command-line parser, validate and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', type=int, default=1,
                        help='Number of epochs to train before continuing evaluation')
    parser.add_argument('-j', '--jump', type=int, default=2,
                        help='Number of features to skip while evaluating')
    parser.add_argument('-s', '--skip-trained', type=_str_to_bool, default=True,
                        help='Whether to skip random files that have already been trained on')
    parser.add_argument('-l', '--load', type=_str_to_bool, default=True)
    parser.add_argument('-b', '--save-best', type=_str_to_bool, default=False)
    parser.add_argument('-d', '--out-dir', default='data/not-keyword/generated')

    args = parser.parse_args()
    if args.jump < 1:
        # The scan loop computes counter % args.jump; 0 would raise
        # ZeroDivisionError only after all data had been loaded.
        parser.error('--jump must be at least 1')
    return args


def _load_trained_fns(path):
    """Return the set of file names listed one per line in ``path``.

    Returns an empty set when ``path`` does not exist.
    """
    if not isfile(path):
        return set()
    with open(path, 'r') as f:
        # A set gives O(1) membership tests in the per-file filter.
        return set(f.read().split('\n'))


def _save_activations(fn, full_audio, model, args):
    """Slide the model over ``full_audio`` and write a wav for each activation.

    The model is evaluated on every ``args.jump``-th hop. Whenever its output
    exceeds 0.5, the preceding ``pr.buffer_samples`` of audio are written to
    ``args.out_dir`` as ``<basename-of-fn>-<sample index>.wav``.
    """
    features = vectorize_raw(full_audio[:pr.buffer_samples])
    counter = 0

    start = pr.buffer_samples - pr.buffer_samples % pr.hop_samples
    for i in range(start, len(full_audio), pr.hop_samples):
        print('\r' + str(i * 100. / len(full_audio)) + '%', end='', flush=True)
        window = full_audio[i - pr.window_samples:i]

        vec = vectorize_raw(window)
        assert len(vec) == 1
        # Shift the feature buffer by one frame: drop the oldest, append the newest.
        features = np.concatenate([features[1:], vec])

        counter += 1
        if counter % args.jump != 0:
            # Evaluation is skipped on this hop (previously modelled as out = 0.0).
            continue

        out = model.predict(np.array([features]))[0]
        if out > 0.5:
            name = join(args.out_dir, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
            # The first hop may land before buffer_samples. A negative slice start
            # would wrap to the end of the array and yield an empty segment, so clamp it.
            audio = full_audio[max(0, i - pr.buffer_samples):i]
            audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
            wavio.write(name, audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')


def _retrain(model, validation_data, checkpoint, epochs):
    """Retrain ``model`` on everything loaded from data/ and save it to keyword.net."""
    inputs, outputs = load_data('data')
    # NOTE(review): the first epoch is saved unconditionally, presumably so that
    # keyword.net is refreshed even if the checkpoint (save_best_only) skips the
    # remaining epochs -- confirm.
    model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
    model.save('keyword.net')
    model.fit(inputs, outputs, 5000, epochs - 1, validation_data=validation_data,
              callbacks=[checkpoint])


def main():
    """Mine false activations from data/random and retrain on them.

    For each wav under data/random, the main loop does the following:

    1. Skip the file if it is listed in keyword.trained.txt, when
       --skip-trained is set.
    2. Scan the file with the current model and write each activating
       segment to --out-dir.
    3. Retrain on data/, validating against data/test.
    4. Append the file name to keyword.trained.txt.

    Ctrl+C stops the loop cleanly.
    """
    args = _parse_args()

    trained_fns = _load_trained_fns('keyword.trained.txt')

    # Lazy: each file's audio is only loaded when the loop reaches it.
    random_inputs = (
        (f, load_audio(f))
        for f in glob('data/random/*.wav')
        if not args.skip_trained or f not in trained_fns
    )
    validation_data = load_data('data/test')

    model = create_model('keyword.net', args.load)

    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc',
                                 save_best_only=args.save_best, mode='max')

    makedirs(args.out_dir, exist_ok=True)

    try:
        for fn, full_audio in random_inputs:
            print('Starting file ' + fn + '...')
            _save_activations(fn, full_audio, model, args)
            _retrain(model, validation_data, checkpoint, args.epochs)
            print()

            # Record the file only after retraining finished, so an
            # interrupted file is processed again on the next run.
            with open('keyword.trained.txt', 'a') as f:
                f.write('\n' + fn)

    except KeyboardInterrupt:
        print()
|
|
|
|
# Script entry point: start training only when run directly, never on import.
if __name__ == "__main__":
    main()
|