mycroft-precise/precise/train_feedback.py

84 lines
3.3 KiB
Python
Executable File

#!/usr/bin/env python3
# This script trains the network, selectively choosing
# segments from data/random that cause an activation. These
# segments are saved as new clips under data/not-keyword/generated
# (by default) and the network is retrained.
import sys
sys.path += ['.'] # noqa
import argparse
from glob import glob
from os import makedirs
from os.path import join, basename, splitext
import wavio
from precise.common import *
def _str_to_bool(value):
    """Convert a command-line string such as 'true' or 'no' into a bool.

    argparse passes every option value to ``type`` as a string, and plain
    ``bool`` treats any non-empty string (including 'False') as True, so
    boolean options need an explicit converter.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        True or False according to the spelling of ``value``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized
            spelling of true or false.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, matching what bool('') returned before.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got ' + repr(value))


def main():
    """Scan data/random for false activations and retrain the network on them.

    Each wav file under data/random is fed through the model with a sliding
    window. Whenever the model output exceeds 0.5, the audio buffer that
    caused it is written to ``--out-dir`` and the model is retrained on
    everything under ``data``. Every file that is scanned to completion is
    appended to keyword.trained.txt so later runs can skip it.

    Command-line options are read from ``sys.argv``. A KeyboardInterrupt
    ends the run quietly.
    """
    # isfile is not among this file's explicit imports; import it here so the
    # function does not depend on a star-import providing it.
    from os.path import isfile

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', type=int, default=1, help='Number of epochs to train before continueing evaluation')
    parser.add_argument('-j', '--jump', type=int, default=2, help='Number of features to skip while evaluating')
    # type=bool would turn *any* non-empty string (even 'False') into True.
    parser.add_argument('-s', '--skip-trained', type=_str_to_bool, default=True, help='Whether to skip random files that have already been trained on')
    parser.add_argument('-l', '--load', type=_str_to_bool, default=True)
    parser.add_argument('-b', '--save-best', type=_str_to_bool, default=False)
    parser.add_argument('-d', '--out-dir', default='data/not-keyword/generated')
    args = parser.parse_args()

    # Filenames processed by earlier runs; a set keeps the membership test
    # in the generator below O(1).
    trained_fns = set()
    if isfile('keyword.trained.txt'):
        with open('keyword.trained.txt', 'r') as f:
            trained_fns = {line for line in f.read().split('\n') if line}

    # Lazy generator: each file's audio is loaded only when its turn comes.
    random_inputs = (
        (f, load_audio(f))
        for f in glob('data/random/*.wav')
        if not args.skip_trained or f not in trained_fns
    )
    validation_data = load_data('data/test')
    model = create_model('keyword.net', args.load)

    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')

    makedirs(args.out_dir, exist_ok=True)

    try:
        for fn, full_audio in random_inputs:
            # Seed the feature buffer with the first buffer_samples of audio.
            features = vectorize_raw(full_audio[:pr.buffer_samples])
            counter = 0
            print('Starting file ' + fn + '...')
            # Start at the largest multiple of hop_samples that does not
            # exceed buffer_samples, then advance one hop at a time.
            start = pr.buffer_samples - pr.buffer_samples % pr.hop_samples
            for i in range(start, len(full_audio), pr.hop_samples):
                print('\r' + str(i * 100. / len(full_audio)) + '%', end='', flush=True)
                window = full_audio[i - pr.window_samples:i]
                vec = vectorize_raw(window)
                assert len(vec) == 1
                # Slide the buffer: drop the oldest feature row, append the new one.
                features = np.concatenate([features[1:], vec])

                counter += 1
                # Only run the model on every `jump`-th step.
                if counter % args.jump == 0:
                    out = model.predict(np.array([features]))[0]
                else:
                    out = 0.0

                if out > 0.5:
                    # Activation on random audio: save the buffer that caused
                    # it to the output directory.
                    name = join(args.out_dir, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
                    audio = full_audio[i - pr.buffer_samples:i]
                    # Scale to the int16 range before writing
                    # (assumes load_audio yields floats in [-1, 1] -- TODO confirm).
                    audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
                    wavio.write(name, audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')

                    # Retrain: one epoch followed by an unconditional save,
                    # then the remaining epochs with the checkpoint callback.
                    inputs, outputs = load_data('data')
                    model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
                    model.save('keyword.net')
                    model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
            print()

            # Record the finished file so later runs can skip it.
            with open('keyword.trained.txt', 'a') as f:
                f.write('\n' + fn)
    except KeyboardInterrupt:
        print()
# Run the feedback-training loop only when this file is executed as a script.
if __name__ == '__main__':
    main()