mycroft-precise/precise/scripts/train.py

86 lines
2.5 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2018 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from prettyparse import create_parser
from precise.model import create_model
from precise.params import inject_params, save_params
from precise.train_data import TrainData
# Command-line description for this script. The text is handed verbatim to
# prettyparse's create_parser() in main(), so it is runtime data, not a comment.
# NOTE(review): each ':name ...' line presumably declares one argument and the
# line(s) after it give its help text (prettyparse format) - confirm against
# the prettyparse documentation.
# NOTE(review): the trailing '...' line is part of the string; presumably it
# marks where arguments added by TrainData.parse_args() are placed - confirm.
usage = '''
Train a new model on a dataset
:model str
Keras model file (.net) to load from and save to
:-e --epochs int 10
Number of epochs to train model for
:-sb --save-best
Only save the model each epoch if its stats improve
:-nv --no-validation
Disable accuracy and validation calculation
to improve speed during training
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
...
'''
def main():
    """Train the model named on the command line and write it back to disk.

    Steps, in order:

    1. Parse arguments from ``usage`` via ``TrainData.parse_args``.
    2. Inject and then save the parameters associated with ``args.model``.
    3. Load the training data (and the test data unless ``--no-validation``
       was given) and print the array shapes.
    4. Create the model, print its summary and fit it with a
       ``ModelCheckpoint`` callback that writes to ``args.model``.

    A ``KeyboardInterrupt`` during fitting stops training early without a
    traceback.

    Raises:
        SystemExit: with status 1 when the loaded inputs or outputs contain
            a zero-length dimension (nothing to train on).
    """
    args = TrainData.parse_args(create_parser(usage))

    inject_params(args.model)
    save_params(args.model)

    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
    print('Data:', data)

    # Second flag controls whether the test split is loaded at all.
    (inputs, outputs), test_data = data.load(True, not args.no_validation)

    print('Inputs shape:', inputs.shape)
    print('Outputs shape:', outputs.shape)

    if test_data:
        print('Test inputs shape:', test_data[0].shape)
        print('Test outputs shape:', test_data[1].shape)

    if 0 in inputs.shape or 0 in outputs.shape:
        print('Not enough data to train')
        # The bare ``exit`` helper is installed by the ``site`` module for
        # interactive use and may be missing (``python -S``, frozen builds);
        # raising SystemExit is the equivalent that always works.
        raise SystemExit(1)

    model = create_model(args.model, args.no_validation, args.extra_metrics)
    model.summary()

    # Kept as a function-local import, as in the original (presumably to
    # defer loading Keras until it is actually needed - TODO confirm).
    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint(
        args.model,
        monitor=args.metric_monitor,
        save_best_only=args.save_best,
    )

    batch_size = 5000
    try:
        model.fit(
            inputs, outputs, batch_size, args.epochs,
            validation_data=test_data,
            callbacks=[checkpoint],
        )
    except KeyboardInterrupt:
        # Finish the interrupted progress line instead of dumping a traceback.
        print()
    finally:
        # With --save-best the checkpoint callback has already written the
        # best epoch to args.model; saving the final weights here would
        # overwrite it with a possibly worse model.
        if not args.save_best:
            model.save(args.model)
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()