Rework train script to use new Trainer class

This allows IncrementalTrainer to share similar behavior
pull/19/head
Matthew Scholefield 2018-07-10 16:02:20 -05:00
parent 4da65d1f37
commit ebd5e09feb
6 changed files with 148 additions and 113 deletions

View File

@ -16,6 +16,15 @@ from typing import *
LOSS_BIAS = 0.9 # [0..1] where 1 is inf bias
def set_loss_bias(bias: float) -> None:
    """
    Set the module-wide bias used by the weighted loss

    Near 1.0 reduces false positives
    Near 0.0 reduces false negatives

    Args:
        bias: New value for LOSS_BIAS, expected within [0.0, 1.0]

    Raises:
        ValueError: If bias falls outside [0.0, 1.0] (this includes NaN)
    """
    global LOSS_BIAS
    # Written as a negated chained comparison so NaN is rejected as well
    if not 0.0 <= bias <= 1.0:
        raise ValueError('bias must be between 0.0 and 1.0, got: {}'.format(bias))
    LOSS_BIAS = bias
def weighted_log_loss(yt, yp) -> Any:
"""
Binary crossentropy with a bias towards false negatives

View File

@ -12,73 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from prettyparse import create_parser
from precise.model import create_model
from precise.params import inject_params, save_params
from precise.train_data import TrainData
usage = '''
Train a new model on a dataset
:model str
Keras model file (.net) to load from and save to
:-e --epochs int 10
Number of epochs to train model for
:-sb --save-best
Only save the model each epoch if its stats improve
:-nv --no-validation
Disable accuracy and validation calculation
to improve speed during training
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
...
'''
from precise.trainer import Trainer
def main():
args = TrainData.parse_args(create_parser(usage))
inject_params(args.model)
save_params(args.model)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
print('Data:', data)
(inputs, outputs), test_data = data.load(True, not args.no_validation)
print('Inputs shape:', inputs.shape)
print('Outputs shape:', outputs.shape)
if test_data:
print('Test inputs shape:', test_data[0].shape)
print('Test outputs shape:', test_data[1].shape)
if 0 in inputs.shape or 0 in outputs.shape:
print('Not enough data to train')
exit(1)
model = create_model(args.model, args.no_validation, args.extra_metrics)
model.summary()
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
try:
model.fit(inputs, outputs, 5000, args.epochs, validation_data=test_data,
callbacks=[checkpoint])
except KeyboardInterrupt:
print()
finally:
model.save(args.model)
Trainer().run()
if __name__ == '__main__':

View File

@ -12,27 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from os import makedirs
from os.path import basename, splitext, isfile, join
from prettyparse import create_parser
from random import random
from typing import *
import numpy as np
from prettyparse import create_parser
from precise.model import create_model
from precise.network_runner import Listener, KerasRunner
from precise.params import inject_params
from precise.params import pr
from precise.train_data import TrainData
from precise.trainer import Trainer
from precise.util import load_audio, save_audio, glob_all
usage = '''
Train a model to inhibit activation by
marking false activations and retraining
:model str
Keras <NAME>.net file to train
:-e --epochs int 1
Number of epochs to train before continuing evaluation
@ -42,18 +39,6 @@ usage = '''
:-c --chunk-size int 2048
Number of samples between testing the neural network
:-b --batch-size int 2048
Batch size used for training
:-sb --save-best
Only save the model each epoch if its stats improve
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
:-nv --no-validation
Disable accuracy and validation calculation
to improve speed during training
@ -85,34 +70,41 @@ def save_trained_fns(trained_fns: list, model_name: str):
f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))
class IncrementalTrainer:
def __init__(self, args):
self.args = args
self.trained_fns = load_trained_fns(args.model)
pr = inject_params(args.model)
class IncrementalTrainer(Trainer):
def __init__(self):
super().__init__(create_parser(usage))
for i in (
join(self.args.folder, 'not-wake-word', 'generated'),
join(self.args.folder, 'test', 'not-wake-word', 'generated')
):
makedirs(i, exist_ok=True)
self.trained_fns = load_trained_fns(self.args.model)
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
from keras.callbacks import ModelCheckpoint
self.checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
data = TrainData.from_tags(args.tags_file, args.tags_folder)
self.tags_data = data.load(True, not args.no_validation)
if not isfile(self.args.model):
create_model(self.args.model, self.args.no_validation, self.args.extra_metrics).save(
self.args.model
)
self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner)
if not isfile(args.model):
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
self.listener = Listener(args.model, args.chunk_size, runner_cls=KerasRunner)
@staticmethod
def load_data(args: Any):
data = TrainData.from_tags(args.tags_file, args.tags_folder)
return data.load(True, not args.no_validation)
def retrain(self):
"""Train for a session, pulling in any new data from the filesystem"""
folder = TrainData.from_folder(self.args.folder)
train_data, test_data = folder.load(True, not self.args.no_validation)
train_data = TrainData.merge(train_data, self.tags_data[0])
test_data = TrainData.merge(test_data, self.tags_data[1])
train_data = TrainData.merge(train_data, self.train)
test_data = TrainData.merge(test_data, self.test)
print()
try:
self.listener.runner.model.fit(*train_data, self.args.batch_size, self.args.epochs,
validation_data=test_data, callbacks=[self.checkpoint])
validation_data=test_data, callbacks=self.callbacks)
finally:
self.listener.runner.model.save(self.args.model)
@ -142,7 +134,7 @@ class IncrementalTrainer:
samples_since_train = 0
self.retrain()
def train_incremental(self):
def run(self):
"""
Begin reading through audio files, saving false
activations and retraining when necessary
@ -161,17 +153,8 @@ class IncrementalTrainer:
def main():
args = TrainData.parse_args(create_parser(usage))
for i in (
join(args.folder, 'not-wake-word', 'generated'),
join(args.folder, 'test', 'not-wake-word', 'generated')
):
makedirs(i, exist_ok=True)
trainer = IncrementalTrainer(args)
try:
trainer.train_incremental()
IncrementalTrainer().run()
except KeyboardInterrupt:
print()

103
precise/trainer.py Normal file
View File

@ -0,0 +1,103 @@
from argparse import ArgumentParser
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from os.path import splitext
from prettyparse import add_to_parser
from typing import Any, Tuple
from precise.functions import set_loss_bias
from precise.model import create_model
from precise.params import inject_params, save_params
from precise.train_data import TrainData
class Trainer:
    """
    Command line driven training session

    Construction parses the command line, loads the training data,
    builds the model and prepares the training callbacks. Calling
    run() then performs the actual training.

    Attributes set by __init__:
        args: Parsed command line arguments
        train: Training data as returned by load_data()
        test: Test data as returned by load_data()
        model: Model returned by create_model()
        epoch: Number of epochs counted so far for this model
        callbacks: Callbacks handed to model.fit()
    """
    usage = '''
        Train a new model on a dataset
        :model str
            Keras model file (.net) to load from and save to
        :-e --epochs int 10
            Number of epochs to train model for
        :-s --sensitivity float 0.2
            Weighted loss bias. Higher values reduce false negatives
        :-b --batch-size int 5000
            Batch size for training
        :-sb --save-best
            Only save the model each epoch if its stats improve
        :-nv --no-validation
            Disable accuracy and validation calculation
            to improve speed during training
        :-mm --metric-monitor str loss
            Metric used to determine when to save
        :-em --extra-metrics
            Add extra metrics during training
        ...
    '''

    def __init__(self, parser=None):
        """
        Args:
            parser: Optional parser that already holds extra arguments.
                    The arguments from Trainer.usage are added to it.
                    A fresh ArgumentParser is used when omitted.
        """
        if parser is None:
            parser = ArgumentParser()
        add_to_parser(parser, self.usage, True)
        self.args = args = TrainData.parse_args(parser)

        if not 0.0 <= args.sensitivity <= 1.0:
            parser.error('sensitivity must be between 0.0 and 1.0')

        inject_params(args.model)
        save_params(args.model)
        # Looked up through self so that subclasses can supply their own load_data
        self.train, self.test = self.load_data(self.args)

        # The loss bias is the inverse of the sensitivity
        set_loss_bias(1.0 - args.sensitivity)
        self.model = create_model(args.model, args.no_validation, args.extra_metrics)
        self.model.summary()

        from keras.callbacks import ModelCheckpoint, TensorBoard
        checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                     save_best_only=args.save_best)

        # The epoch count lives in a "<model name>.epoch" file next to the model
        # NOTE(review): presumably read(0, int) falls back to 0 when the
        # file does not exist yet - confirm against the fitipy docs
        epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
        self.epoch = epoch_fiti.read().read(0, int)

        def on_epoch_end(_epoch, _logs):
            # Both callback arguments are unused; we keep our own counter
            self.epoch += 1
            epoch_fiti.write().write(self.epoch, str)

        self.callbacks = [
            checkpoint, TensorBoard(),
            LambdaCallback(on_epoch_end=on_epoch_end)
        ]

    @staticmethod
    def load_data(args: Any) -> Tuple[tuple, tuple]:
        """
        Load the dataset described by the arguments and print its shapes

        Args:
            args: Parsed arguments providing tags_file, tags_folder,
                  folder and no_validation

        Returns:
            The (train, test) pair produced by TrainData.load()

        Raises:
            SystemExit: With exit code 1 if the training inputs or
                        outputs have a zero-sized dimension
        """
        data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
        print('Data:', data)
        train, test = data.load(True, not args.no_validation)

        print('Inputs shape:', train[0].shape)
        print('Outputs shape:', train[1].shape)

        if test:
            print('Test inputs shape:', test[0].shape)
            print('Test outputs shape:', test[1].shape)

        if 0 in train[0].shape or 0 in train[1].shape:
            print('Not enough data to train')
            # Raised directly since the exit() builtin only exists when
            # the site module has been loaded
            raise SystemExit(1)

        return train, test

    def run(self):
        """
        Train the model for args.epochs further epochs

        Training starts at the stored epoch count. A KeyboardInterrupt
        stops the training and is not propagated to the caller.
        """
        try:
            self.model.fit(
                self.train[0], self.train[1],
                batch_size=self.args.batch_size,
                epochs=self.epoch + self.args.epochs,
                validation_data=self.test, initial_epoch=self.epoch,
                callbacks=self.callbacks
            )
        except KeyboardInterrupt:
            print()

View File

@ -11,7 +11,7 @@ Markdown==2.6.11
numpy==1.14.2
pocketsphinx==0.1.3
-e git+https://github.com/mycroftai/mycroft-precise#egg=precise_runner&subdirectory=runner
prettyparse==0.1.2
prettyparse==0.1.4
protobuf==3.5.2.post1
PyAudio==0.2.11
PyYAML==3.12
@ -24,3 +24,4 @@ termcolor==1.1.0
typing==3.6.4
wavio==0.0.3
Werkzeug==0.14.1
fitipy==0.1.1

View File

@ -48,9 +48,10 @@ setup(
'h5py',
'wavio',
'typing',
'prettyparse',
'prettyparse>=0.1.4',
'precise-runner',
'attrs'
'attrs',
'fitipy'
],
author='Matthew Scholefield',