mycroft-precise/precise/scripts/train.py

#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train a new model on a dataset
:model str
Keras model file (.net) to load from and save to
:-sf --samples-file str -
Loads subset of data from the provided json file
generated with precise-train-sampled
:-is --invert-samples
Loads subset of data not inside --samples-file
:-e --epochs int 10
Number of epochs to train model for
:-s --sensitivity float 0.2
Target sensitivity when training. Higher values cause
more false positives
:-b --batch-size int 5000
Batch size for training
:-sb --save-best
Only save the model each epoch if its stats improve
:-nv --no-validation
Disable accuracy and validation calculation
to improve speed during training
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
:-f --freeze-till int 0
Freeze all weights up to this index (non-inclusive).
Can be negative to wrap from end
...
"""
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from os.path import splitext, isfile
from prettyparse import Usage
from typing import Any, Tuple

from precise.model import create_model, ModelParams
from precise.params import inject_params, save_params
from precise.scripts.base_script import BaseScript
from precise.train_data import TrainData
from precise.util import calc_sample_hash

class TrainScript(BaseScript):
    usage = Usage(__doc__) | TrainData.usage

    def __init__(self, args):
        super().__init__(args)

        if args.invert_samples and not args.samples_file:
            raise ValueError('You must specify --samples-file when using --invert-samples')

        if args.samples_file and not isfile(args.samples_file):
            raise ValueError('No such file: ' + args.samples_file)

        if not 0.0 <= args.sensitivity <= 1.0:
            raise ValueError('sensitivity must be between 0.0 and 1.0')

        # Load the audio params stored alongside the model (if any) and
        # re-save them next to the model file so later runs use the same values
        inject_params(args.model)
        save_params(args.model)

        # A higher --sensitivity lowers the loss bias, trading more false
        # positives for fewer missed activations
        params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics,
                             loss_bias=1.0 - args.sensitivity, freeze_till=args.freeze_till)
        self.model = create_model(args.model, params)
        self.train, self.test = self.load_data(self.args)

        from keras.callbacks import ModelCheckpoint, TensorBoard
        checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                     save_best_only=args.save_best)

        # Persist the epoch counter on disk so training can resume with
        # correct epoch numbering across runs
        epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
        self.epoch = epoch_fiti.read().read(0, int)

        def on_epoch_end(_a, _b):
            self.epoch += 1
            epoch_fiti.write().write(self.epoch, str)

        self.model_base = splitext(self.args.model)[0]

        if args.samples_file:
            self.samples, self.hash_to_ind = self.load_sample_data(args.samples_file, self.train)
        else:
            self.samples = set()
            self.hash_to_ind = {}

        self.callbacks = [
            checkpoint, TensorBoard(
                log_dir=self.model_base + '.logs',
            ), LambdaCallback(on_epoch_end=on_epoch_end)
        ]
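
    # Summary of files produced next to the model: the model file itself is
    # written by ModelCheckpoint, save_params() stores the audio parameters
    # alongside it, and <base>.epoch / <base>.logs (where <base> is the model
    # path without its extension) hold the epoch counter and TensorBoard logs.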

    @staticmethod
    def load_sample_data(filename, train_data) -> Tuple[set, dict]:
        """Reads the chosen sample hashes and maps each hash to a dataset index"""
        samples = Fitipy(filename).read().set()
        hash_to_ind = {
            calc_sample_hash(inp, outp): ind
            for ind, (inp, outp) in enumerate(zip(*train_data))
        }
        # Drop any hashes that no longer match a loaded sample
        for hsh in list(samples):
            if hsh not in hash_to_ind:
                print('Missing hash:', hsh)
                samples.remove(hsh)
        return samples, hash_to_ind

    @staticmethod
    def load_data(args: Any) -> Tuple[tuple, tuple]:
        data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
        print('Data:', data)
        train, test = data.load(True, not args.no_validation)

        print('Inputs shape:', train[0].shape)
        print('Outputs shape:', train[1].shape)

        if test:
            print('Test inputs shape:', test[0].shape)
            print('Test outputs shape:', test[1].shape)

        # Refuse to train on an empty dataset
        if 0 in train[0].shape or 0 in train[1].shape:
            print('Not enough data to train')
            exit(1)

        return train, test
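
    # For reference, the data folder given on the command line conventionally
    # uses the layout described in the project README (directory names matter,
    # wav file names do not):
    #
    #   <folder>/
    #       wake-word/          wav recordings of the wake word
    #       not-wake-word/      wav recordings of anything else
    #       test/
    #           wake-word/
    #           not-wake-word/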

    @property
    def sampled_data(self):
        """Returns (train_inputs, train_outputs), restricted to --samples-file if given"""
        if self.args.samples_file:
            if self.args.invert_samples:
                # Train on everything except the listed samples
                chosen_samples = set(self.hash_to_ind) - self.samples
            else:
                chosen_samples = self.samples
            selected_indices = [self.hash_to_ind[h] for h in chosen_samples]
            return self.train[0][selected_indices], self.train[1][selected_indices]
        else:
            return self.train[0], self.train[1]

    def run(self):
        self.model.summary()
        train_inputs, train_outputs = self.sampled_data

        # Resume from the persisted epoch count; `epochs` is the final epoch
        # index, so this trains for exactly --epochs more epochs
        self.model.fit(
            train_inputs, train_outputs, self.args.batch_size,
            self.epoch + self.args.epochs, validation_data=self.test,
            initial_epoch=self.epoch, callbacks=self.callbacks
        )


main = TrainScript.run_main

if __name__ == '__main__':
    main()
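
# Note: the project's setup.py exposes `main` above as the `precise-train`
# console script, which is the command shown in the example near the top of
# this file.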