86 lines
2.5 KiB
Python
Executable File
86 lines
2.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2018 Mycroft AI Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from prettyparse import create_parser
|
|
|
|
from precise.model import create_model
|
|
from precise.params import inject_params, save_params
|
|
from precise.train_data import TrainData
|
|
|
|
usage = '''
|
|
Train a new model on a dataset
|
|
|
|
:model str
|
|
Keras model file (.net) to load from and save to
|
|
|
|
:-e --epochs int 10
|
|
Number of epochs to train model for
|
|
|
|
:-sb --save-best
|
|
Only save the model each epoch if its stats improve
|
|
|
|
:-nv --no-validation
|
|
Disable accuracy and validation calculation
|
|
to improve speed during training
|
|
|
|
:-mm --metric-monitor str loss
|
|
Metric used to determine when to save
|
|
|
|
:-em --extra-metrics
|
|
Add extra metrics during training
|
|
|
|
...
|
|
'''
|
|
|
|
|
|
def main():
|
|
args = TrainData.parse_args(create_parser(usage))
|
|
|
|
inject_params(args.model)
|
|
save_params(args.model)
|
|
|
|
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
|
print('Data:', data)
|
|
(inputs, outputs), test_data = data.load(True, not args.no_validation)
|
|
|
|
print('Inputs shape:', inputs.shape)
|
|
print('Outputs shape:', outputs.shape)
|
|
|
|
if test_data:
|
|
print('Test inputs shape:', test_data[0].shape)
|
|
print('Test outputs shape:', test_data[1].shape)
|
|
|
|
if 0 in inputs.shape or 0 in outputs.shape:
|
|
print('Not enough data to train')
|
|
exit(1)
|
|
|
|
model = create_model(args.model, args.no_validation, args.extra_metrics)
|
|
model.summary()
|
|
|
|
from keras.callbacks import ModelCheckpoint
|
|
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
|
save_best_only=args.save_best)
|
|
|
|
try:
|
|
model.fit(inputs, outputs, 5000, args.epochs, validation_data=test_data,
|
|
callbacks=[checkpoint])
|
|
except KeyboardInterrupt:
|
|
print()
|
|
finally:
|
|
model.save(args.model)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|