Make training fail gracefully if dimensions are invalid
parent
64598a0d8d
commit
e3cb1f0bf4
|
@ -23,10 +23,17 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
inputs, outputs = load_data(args.data_dir)
|
||||
validation_data = load_data(args.data_dir + '/test')
|
||||
val_in, val_out = load_data(args.data_dir + '/test')
|
||||
|
||||
print('Inputs shape:', inputs.shape)
|
||||
print('Outputs shape:', outputs.shape)
|
||||
print('Test inputs shape:', val_in.shape)
|
||||
print('Test outputs shape:', val_out.shape)
|
||||
|
||||
if (0 in inputs.shape or 0 in outputs.shape or
|
||||
0 in val_in.shape or 0 in val_out.shape):
|
||||
print('Not enough data to train')
|
||||
exit(1)
|
||||
|
||||
model = create_model(args.model, args.load)
|
||||
|
||||
|
@ -37,7 +44,7 @@ def main():
|
|||
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
|
||||
|
||||
try:
|
||||
model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])
|
||||
model.fit(inputs, outputs, 5000, args.epochs, validation_data=(val_in, val_out), callbacks=[checkpoint])
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
|
||||
|
|
Loading…
Reference in New Issue