74 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
			
		
		
	
	
			74 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
#!/usr/bin/env python3
 | 
						|
# Copyright (c) 2017 Mycroft AI Inc.
 | 
						|
from prettyparse import create_parser
 | 
						|
 | 
						|
from precise.model import create_model
 | 
						|
from precise.params import inject_params, save_params
 | 
						|
from precise.train_data import TrainData
 | 
						|
 | 
						|
# Command-line specification handed to prettyparse's create_parser in main().
# Lines beginning with ':' appear to declare arguments as
# "[short] name/--long [type] [default]"; the indented lines beneath each
# declaration are that argument's help text.
# NOTE(review): the trailing '...' looks like the spot where
# TrainData.parse_args splices in its own data-source arguments (main() reads
# args.db_file, args.db_folder and args.data_dir, none of which are declared
# here) — confirm against prettyparse / precise.train_data.
usage = '''
    Train a new model on a dataset
    
    :model str
        Keras model file (.net) to load from and save to
    
    :-e --epochs int 10
        Number of epochs to train model for
    
    :-sb --save-best
        Only save the model each epoch if its stats improve
    
    :-nv --no-validation
        Disable accuracy and validation calculation
        to improve speed during training
    
    :-mm --metric-monitor str loss
        Metric used to determine when to save
    
    :-em --extra-metrics
        Add extra metrics during training
    
    ...
'''


def main():
    """Train a Keras model on a dataset, driven by the ``usage`` spec.

    Steps:
      1. Parse the command line via ``TrainData.parse_args``.
      2. Call ``inject_params`` / ``save_params`` on the model path.
      3. Load the training split and, unless ``--no-validation`` is given,
         the test split used as Keras ``validation_data``.
      4. Build the model with ``create_model`` and fit it for ``--epochs``
         epochs.

    A ``ModelCheckpoint`` callback writes ``args.model`` during training.
    With ``--save-best`` it writes only when ``--metric-monitor`` improves.
    Without ``--save-best``, the model is also saved once more after ``fit``
    finishes, is interrupted with Ctrl-C, or raises.

    Exits with status 1 if the training inputs or outputs are empty.
    A Ctrl-C during training is swallowed and ``main`` returns normally.
    """
    import sys

    args = TrainData.parse_args(create_parser(usage))

    inject_params(args.model)
    save_params(args.model)

    data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
    print('Data:', data)
    # The second flag asks load() for the test split as well; skip it when
    # validation is disabled.
    (inputs, outputs), test_data = data.load(True, not args.no_validation)

    print('Inputs shape:', inputs.shape)
    print('Outputs shape:', outputs.shape)

    if test_data:
        print('Test inputs shape:', test_data[0].shape)
        print('Test outputs shape:', test_data[1].shape)

    if 0 in inputs.shape or 0 in outputs.shape:
        print('Not enough data to train')
        # sys.exit rather than the bare exit() builtin: exit() is injected by
        # the `site` module and is missing under `python -S` and in some
        # frozen/bundled interpreters, where it would raise NameError.
        sys.exit(1)

    model = create_model(args.model, args.no_validation, args.extra_metrics)
    model.summary()

    # Imported here (as before) so keras is only pulled in once training
    # is actually about to start.
    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                 save_best_only=args.save_best)

    batch_size = 5000  # fixed; not currently exposed on the command line

    try:
        model.fit(inputs, outputs, batch_size=batch_size, epochs=args.epochs,
                  validation_data=test_data, callbacks=[checkpoint])
    except KeyboardInterrupt:
        print()  # terminate the progress line after Ctrl-C
    finally:
        # With --save-best the checkpoint callback alone owns args.model.
        # An unconditional save here would overwrite the best-scoring model
        # with whatever the final (possibly worse) epoch produced.
        if not args.save_best:
            model.save(args.model)


if __name__ == '__main__':
    main()