Add more arguments to train_feedback.py
parent
5e3b2594d4
commit
2741b7afc3
|
@ -23,7 +23,10 @@ def main():
|
|||
parser.add_argument('-s', '--skip-trained', type=bool, default=True, help='Whether to skip random files that have already been trained on')
|
||||
parser.add_argument('-l', '--load', type=bool, default=True)
|
||||
parser.add_argument('-b', '--save-best', type=bool, default=False)
|
||||
parser.add_argument('-d', '--out-dir', default='data/not-keyword/generated')
|
||||
parser.add_argument('-o', '--out-dir', default='data/not-keyword/generated')
|
||||
parser.add_argument('-d', '--data-dir', default='data')
|
||||
parser.add_argument('-r', '--random-data-dir', default='data/random')
|
||||
parser.add_argument('-m', '--model', default='keyword.net')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -32,13 +35,13 @@ def main():
|
|||
with open('keyword.trained.txt', 'r') as f:
|
||||
trained_fns = f.read().split('\n')
|
||||
|
||||
random_inputs = ((f, load_audio(f)) for f in glob('data/random/*.wav') if not args.skip_trained or f not in trained_fns)
|
||||
validation_data = load_data('data/test')
|
||||
random_inputs = ((f, load_audio(f)) for f in glob(args.random_data_dir + '/*.wav') if not args.skip_trained or f not in trained_fns)
|
||||
validation_data = load_data(args.data_dir + '/test')
|
||||
|
||||
model = create_model('keyword.net', args.load)
|
||||
model = create_model(args.model, args.load)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')
|
||||
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
|
||||
|
||||
makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
|
@ -62,15 +65,20 @@ def main():
|
|||
else:
|
||||
out = 0.0
|
||||
if out > 0.5:
|
||||
name = join(args.out_dir, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
|
||||
folder = args.out_dir
|
||||
if args.epochs > 0 and (counter // args.jump) % 5 == 0:
|
||||
folder = join(args.data_dir, 'test', 'not-keyword', 'generated')
|
||||
name = join(folder, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
|
||||
audio = full_audio[i - pr.buffer_samples:i]
|
||||
audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
|
||||
wavio.write(name, audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')
|
||||
print('\nSaving to:', name)
|
||||
|
||||
inputs, outputs = load_data('data')
|
||||
model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
|
||||
model.save('keyword.net')
|
||||
model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
|
||||
if args.epochs > 0:
|
||||
inputs, outputs = load_data(args.data_dir)
|
||||
model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
|
||||
model.save('keyword.net')
|
||||
model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
|
||||
print()
|
||||
|
||||
with open('keyword.trained.txt', 'a') as f:
|
||||
|
|
Loading…
Reference in New Issue