Rework train script to use new Trainer class
This allows IncrementalTrainer to share similar behavior
pull/19/head
parent
4da65d1f37
commit
ebd5e09feb
|
@ -16,6 +16,15 @@ from typing import *
|
|||
LOSS_BIAS = 0.9 # [0..1] where 1 is inf bias
|
||||
|
||||
|
||||
def set_loss_bias(bias: float):
    """
    Set the module-wide LOSS_BIAS value

    Near 1.0 reduces false positives
    Near 0.0 reduces false negatives

    Args:
        bias: New value for LOSS_BIAS, within [0..1]

    Raises:
        ValueError: If bias lies outside [0..1] (this also rejects NaN)
    """
    global LOSS_BIAS
    # Written as a negated chained comparison so that NaN, which fails
    # every comparison, is rejected along with out-of-range numbers
    if not 0.0 <= bias <= 1.0:
        raise ValueError('Loss bias must be between 0.0 and 1.0, got: {}'.format(bias))
    LOSS_BIAS = bias
|
||||
|
||||
|
||||
def weighted_log_loss(yt, yp) -> Any:
|
||||
"""
|
||||
Binary crossentropy with a bias towards false negatives
|
||||
|
|
|
@ -12,73 +12,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from prettyparse import create_parser
|
||||
|
||||
from precise.model import create_model
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Train a new model on a dataset
|
||||
|
||||
:model str
|
||||
Keras model file (.net) to load from and save to
|
||||
|
||||
:-e --epochs int 10
|
||||
Number of epochs to train model for
|
||||
|
||||
:-sb --save-best
|
||||
Only save the model each epoch if its stats improve
|
||||
|
||||
:-nv --no-validation
|
||||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
|
||||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
...
|
||||
'''
|
||||
from precise.trainer import Trainer
|
||||
|
||||
|
||||
def main():
|
||||
args = TrainData.parse_args(create_parser(usage))
|
||||
|
||||
inject_params(args.model)
|
||||
save_params(args.model)
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
print('Data:', data)
|
||||
(inputs, outputs), test_data = data.load(True, not args.no_validation)
|
||||
|
||||
print('Inputs shape:', inputs.shape)
|
||||
print('Outputs shape:', outputs.shape)
|
||||
|
||||
if test_data:
|
||||
print('Test inputs shape:', test_data[0].shape)
|
||||
print('Test outputs shape:', test_data[1].shape)
|
||||
|
||||
if 0 in inputs.shape or 0 in outputs.shape:
|
||||
print('Not enough data to train')
|
||||
exit(1)
|
||||
|
||||
model = create_model(args.model, args.no_validation, args.extra_metrics)
|
||||
model.summary()
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
|
||||
try:
|
||||
model.fit(inputs, outputs, 5000, args.epochs, validation_data=test_data,
|
||||
callbacks=[checkpoint])
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
finally:
|
||||
model.save(args.model)
|
||||
Trainer().run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -12,27 +12,24 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from os import makedirs
|
||||
from os.path import basename, splitext, isfile, join
|
||||
from prettyparse import create_parser
|
||||
from random import random
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
from prettyparse import create_parser
|
||||
|
||||
from precise.model import create_model
|
||||
from precise.network_runner import Listener, KerasRunner
|
||||
from precise.params import inject_params
|
||||
from precise.params import pr
|
||||
from precise.train_data import TrainData
|
||||
from precise.trainer import Trainer
|
||||
from precise.util import load_audio, save_audio, glob_all
|
||||
|
||||
usage = '''
|
||||
Train a model to inhibit activation by
|
||||
marking false activations and retraining
|
||||
|
||||
:model str
|
||||
Keras <NAME>.net file to train
|
||||
|
||||
:-e --epochs int 1
|
||||
Number of epochs to train before continuing evaluation
|
||||
|
||||
|
@ -42,18 +39,6 @@ usage = '''
|
|||
:-c --chunk-size int 2048
|
||||
Number of samples between testing the neural network
|
||||
|
||||
:-b --batch-size int 2048
|
||||
Batch size used for training
|
||||
|
||||
:-sb --save-best
|
||||
Only save the model each epoch if its stats improve
|
||||
|
||||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
:-nv --no-validation
|
||||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
|
@ -85,34 +70,41 @@ def save_trained_fns(trained_fns: list, model_name: str):
|
|||
f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))
|
||||
|
||||
|
||||
class IncrementalTrainer:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.trained_fns = load_trained_fns(args.model)
|
||||
pr = inject_params(args.model)
|
||||
class IncrementalTrainer(Trainer):
|
||||
def __init__(self):
|
||||
super().__init__(create_parser(usage))
|
||||
|
||||
for i in (
|
||||
join(self.args.folder, 'not-wake-word', 'generated'),
|
||||
join(self.args.folder, 'test', 'not-wake-word', 'generated')
|
||||
):
|
||||
makedirs(i, exist_ok=True)
|
||||
|
||||
self.trained_fns = load_trained_fns(self.args.model)
|
||||
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
self.checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
data = TrainData.from_tags(args.tags_file, args.tags_folder)
|
||||
self.tags_data = data.load(True, not args.no_validation)
|
||||
if not isfile(self.args.model):
|
||||
create_model(self.args.model, self.args.no_validation, self.args.extra_metrics).save(
|
||||
self.args.model
|
||||
)
|
||||
self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner)
|
||||
|
||||
if not isfile(args.model):
|
||||
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
|
||||
self.listener = Listener(args.model, args.chunk_size, runner_cls=KerasRunner)
|
||||
@staticmethod
|
||||
def load_data(args: Any):
|
||||
data = TrainData.from_tags(args.tags_file, args.tags_folder)
|
||||
return data.load(True, not args.no_validation)
|
||||
|
||||
def retrain(self):
|
||||
"""Train for a session, pulling in any new data from the filesystem"""
|
||||
folder = TrainData.from_folder(self.args.folder)
|
||||
train_data, test_data = folder.load(True, not self.args.no_validation)
|
||||
|
||||
train_data = TrainData.merge(train_data, self.tags_data[0])
|
||||
test_data = TrainData.merge(test_data, self.tags_data[1])
|
||||
train_data = TrainData.merge(train_data, self.train)
|
||||
test_data = TrainData.merge(test_data, self.test)
|
||||
print()
|
||||
try:
|
||||
self.listener.runner.model.fit(*train_data, self.args.batch_size, self.args.epochs,
|
||||
validation_data=test_data, callbacks=[self.checkpoint])
|
||||
validation_data=test_data, callbacks=self.callbacks)
|
||||
finally:
|
||||
self.listener.runner.model.save(self.args.model)
|
||||
|
||||
|
@ -142,7 +134,7 @@ class IncrementalTrainer:
|
|||
samples_since_train = 0
|
||||
self.retrain()
|
||||
|
||||
def train_incremental(self):
|
||||
def run(self):
|
||||
"""
|
||||
Begin reading through audio files, saving false
|
||||
activations and retraining when necessary
|
||||
|
@ -161,17 +153,8 @@ class IncrementalTrainer:
|
|||
|
||||
|
||||
def main():
|
||||
args = TrainData.parse_args(create_parser(usage))
|
||||
|
||||
for i in (
|
||||
join(args.folder, 'not-wake-word', 'generated'),
|
||||
join(args.folder, 'test', 'not-wake-word', 'generated')
|
||||
):
|
||||
makedirs(i, exist_ok=True)
|
||||
|
||||
trainer = IncrementalTrainer(args)
|
||||
try:
|
||||
trainer.train_incremental()
|
||||
IncrementalTrainer().run()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
from argparse import ArgumentParser
|
||||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from os.path import splitext
|
||||
from prettyparse import add_to_parser
|
||||
from typing import Any, Tuple
|
||||
|
||||
from precise.functions import set_loss_bias
|
||||
from precise.model import create_model
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.train_data import TrainData
|
||||
|
||||
|
||||
class Trainer:
    """
    Command line driven training session for a Keras model

    The constructor parses arguments, loads the data, creates the model
    and prepares the training callbacks; run() performs the fitting.
    A subclass can supply its own parser (the options in `usage` are
    added to it) and override load_data() and run()
    """
    usage = '''
        Train a new model on a dataset

        :model str
            Keras model file (.net) to load from and save to

        :-e --epochs int 10
            Number of epochs to train model for

        :-s --sensitivity float 0.2
            Weighted loss bias. Higher values increase positives

        :-b --batch-size int 5000
            Batch size for training

        :-sb --save-best
            Only save the model each epoch if its stats improve

        :-nv --no-validation
            Disable accuracy and validation calculation
            to improve speed during training

        :-mm --metric-monitor str loss
            Metric used to determine when to save

        :-em --extra-metrics
            Add extra metrics during training

        ...
    '''

    def __init__(self, parser=None):
        """
        Parse arguments and set up everything needed for training

        Args:
            parser: Optional argument parser to extend with the options
                    from `usage`. A fresh ArgumentParser is used if omitted
        """
        parser = parser or ArgumentParser()
        add_to_parser(parser, self.usage, True)
        self.args = args = TrainData.parse_args(parser)
        if not 0.0 <= args.sensitivity <= 1.0:
            parser.error('sensitivity must be between 0.0 and 1.0')

        inject_params(args.model)
        save_params(args.model)
        self.train, self.test = self.load_data(self.args)

        # Sensitivity is the inverse of the loss bias
        set_loss_bias(1.0 - args.sensitivity)
        self.model = create_model(args.model, args.no_validation, args.extra_metrics)
        self.model.summary()

        from keras.callbacks import ModelCheckpoint, TensorBoard
        checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                     save_best_only=args.save_best)
        # The epoch counter lives next to the model, in <model name>.epoch
        epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
        self.epoch = epoch_fiti.read().read(0, int)

        def on_epoch_end(_epoch, _logs):
            # The callback arguments are unused; the counter kept on self
            # is advanced and written back after every epoch
            self.epoch += 1
            epoch_fiti.write().write(self.epoch, str)

        self.callbacks = [
            checkpoint, TensorBoard(),
            LambdaCallback(on_epoch_end=on_epoch_end)
        ]

    @staticmethod
    def load_data(args: Any) -> Tuple[tuple, tuple]:
        """
        Load training and test data described by the parsed arguments

        Prints the shapes of the loaded arrays and exits the process with
        status 1 if the training inputs or outputs are empty

        Args:
            args: Parsed arguments providing tags_file, tags_folder,
                  folder and no_validation

        Returns:
            The (train, test) pair produced by TrainData.load
        """
        data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
        print('Data:', data)
        train, test = data.load(True, not args.no_validation)

        print('Inputs shape:', train[0].shape)
        print('Outputs shape:', train[1].shape)

        if test:
            print('Test inputs shape:', test[0].shape)
            print('Test outputs shape:', test[1].shape)

        if 0 in train[0].shape or 0 in train[1].shape:
            print('Not enough data to train')
            # Raise SystemExit directly: the exit() builtin is injected by
            # the site module and is not guaranteed to be available
            raise SystemExit(1)

        return train, test

    def run(self):
        """
        Fit the model for args.epochs further epochs

        Epoch numbering continues from the counter restored in the
        constructor. A keyboard interrupt stops training quietly
        """
        try:
            self.model.fit(
                self.train[0], self.train[1], self.args.batch_size, self.epoch + self.args.epochs,
                validation_data=self.test, initial_epoch=self.epoch,
                callbacks=self.callbacks
            )
        except KeyboardInterrupt:
            print()
|
|
@ -11,7 +11,7 @@ Markdown==2.6.11
|
|||
numpy==1.14.2
|
||||
pocketsphinx==0.1.3
|
||||
-e git+https://github.com/mycroftai/mycroft-precise#egg=precise_runner&subdirectory=runner
|
||||
prettyparse==0.1.2
|
||||
prettyparse==0.1.4
|
||||
protobuf==3.5.2.post1
|
||||
PyAudio==0.2.11
|
||||
PyYAML==3.12
|
||||
|
@ -24,3 +24,4 @@ termcolor==1.1.0
|
|||
typing==3.6.4
|
||||
wavio==0.0.3
|
||||
Werkzeug==0.14.1
|
||||
fitipy==0.1.1
|
Loading…
Reference in New Issue