Add more arguments to train_feedback.py

pull/1/head
Matthew D. Scholefield 2017-11-14 19:28:24 -06:00
parent 5e3b2594d4
commit 2741b7afc3
1 changed files with 18 additions and 10 deletions

View File

@ -23,7 +23,10 @@ def main():
parser.add_argument('-s', '--skip-trained', type=bool, default=True, help='Whether to skip random files that have already been trained on')
parser.add_argument('-l', '--load', type=bool, default=True)
parser.add_argument('-b', '--save-best', type=bool, default=False)
parser.add_argument('-d', '--out-dir', default='data/not-keyword/generated')
parser.add_argument('-o', '--out-dir', default='data/not-keyword/generated')
parser.add_argument('-d', '--data-dir', default='data')
parser.add_argument('-r', '--random-data-dir', default='data/random')
parser.add_argument('-m', '--model', default='keyword.net')
args = parser.parse_args()
@ -32,13 +35,13 @@ def main():
with open('keyword.trained.txt', 'r') as f:
trained_fns = f.read().split('\n')
random_inputs = ((f, load_audio(f)) for f in glob('data/random/*.wav') if not args.skip_trained or f not in trained_fns)
validation_data = load_data('data/test')
random_inputs = ((f, load_audio(f)) for f in glob(args.random_data_dir + '/*.wav') if not args.skip_trained or f not in trained_fns)
validation_data = load_data(args.data_dir + '/test')
model = create_model('keyword.net', args.load)
model = create_model(args.model, args.load)
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
makedirs(args.out_dir, exist_ok=True)
@ -62,15 +65,20 @@ def main():
else:
out = 0.0
if out > 0.5:
name = join(args.out_dir, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
folder = args.out_dir
if args.epochs > 0 and (counter // args.jump) % 5 == 0:
folder = join(args.data_dir, 'test', 'not-keyword', 'generated')
name = join(folder, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
audio = full_audio[i - pr.buffer_samples:i]
audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
wavio.write(name, audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')
print('\nSaving to:', name)
inputs, outputs = load_data('data')
model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
model.save('keyword.net')
model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
if args.epochs > 0:
inputs, outputs = load_data(args.data_dir)
model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
model.save('keyword.net')
model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
print()
with open('keyword.trained.txt', 'a') as f: