Make training fail gracefully if dimensions are invalid

pull/1/head
Matthew D. Scholefield 2017-12-18 18:51:03 -06:00
parent 64598a0d8d
commit e3cb1f0bf4
1 changed files with 9 additions and 2 deletions

View File

@ -23,10 +23,17 @@ def main():
args = parser.parse_args()
inputs, outputs = load_data(args.data_dir)
validation_data = load_data(args.data_dir + '/test')
val_in, val_out = load_data(args.data_dir + '/test')
print('Inputs shape:', inputs.shape)
print('Outputs shape:', outputs.shape)
print('Test inputs shape:', val_in.shape)
print('Test outputs shape:', val_out.shape)
if (0 in inputs.shape or 0 in outputs.shape or
0 in val_in.shape or 0 in val_out.shape):
print('Not enough data to train')
exit(1)
model = create_model(args.model, args.load)
@ -37,7 +44,7 @@ def main():
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
try:
model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])
model.fit(inputs, outputs, 5000, args.epochs, validation_data=(val_in, val_out), callbacks=[checkpoint])
except KeyboardInterrupt:
print()