Make train_optimize inherit from Trainer
parent
8145dedefa
commit
d6db6c8ec2
|
@ -3,20 +3,21 @@
|
|||
.cache/
|
||||
/.idea/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
/*.pb
|
||||
/*.params
|
||||
/*.net
|
||||
/*.pbtxt
|
||||
/*.txt
|
||||
/*.wav
|
||||
/other/
|
||||
/.venv/
|
||||
/stats.json
|
||||
/data/*/
|
||||
/wakewords/
|
||||
/logs
|
||||
*.egg-info/
|
||||
|
||||
*.pyc
|
||||
*.pb
|
||||
*.params
|
||||
*.net
|
||||
*.json
|
||||
*.pbtxt
|
||||
*.txt
|
||||
*.wav
|
||||
|
||||
# Data folders
|
||||
/*/wake-word/
|
||||
|
|
|
@ -11,16 +11,31 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import attr
|
||||
from os.path import isfile
|
||||
from typing import *
|
||||
|
||||
from precise.functions import load_keras, false_pos, false_neg, weighted_log_loss
|
||||
from precise.functions import load_keras, false_pos, false_neg, weighted_log_loss, set_loss_bias
|
||||
from precise.params import inject_params, pr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from keras.models import Sequential
|
||||
|
||||
lstm_units = 20
|
||||
|
||||
@attr.s()
|
||||
class ModelParams:
|
||||
"""
|
||||
Attributes:
|
||||
recurrent_units:
|
||||
dropout:
|
||||
extra_metrics: Whether to include false positive and false negative metrics
|
||||
skip_acc: Whether to skip accuracy calculation while training
|
||||
"""
|
||||
recurrent_units = attr.ib(20) # type: int
|
||||
dropout = attr.ib(0.2) # type: float
|
||||
extra_metrics = attr.ib(False) # type: bool
|
||||
skip_acc = attr.ib(False) # type: bool
|
||||
loss_bias = attr.ib(0.7) # type: float
|
||||
|
||||
|
||||
def load_precise_model(model_name: str) -> Any:
|
||||
|
@ -32,19 +47,18 @@ def load_precise_model(model_name: str) -> Any:
|
|||
return load_keras().models.load_model(model_name)
|
||||
|
||||
|
||||
def create_model(model_name: str, skip_acc=False, extra_metrics=False) -> 'Sequential':
|
||||
def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
|
||||
"""
|
||||
Load or create a precise model
|
||||
|
||||
Args:
|
||||
model_name: Name of model
|
||||
skip_acc: Whether to skip accuracy calculation while training
|
||||
extra_metrics: Whether to include false positive and false negative metrics
|
||||
params: Parameters used to create the model
|
||||
|
||||
Returns:
|
||||
model: Loaded Keras model
|
||||
"""
|
||||
if isfile(model_name):
|
||||
if model_name and isfile(model_name):
|
||||
print('Loading from ' + model_name + '...')
|
||||
model = load_precise_model(model_name)
|
||||
else:
|
||||
|
@ -53,11 +67,14 @@ def create_model(model_name: str, skip_acc=False, extra_metrics=False) -> 'Seque
|
|||
from keras.models import Sequential
|
||||
|
||||
model = Sequential()
|
||||
model.add(GRU(lstm_units, activation='linear', input_shape=(pr.n_features, pr.feature_size),
|
||||
dropout=0.3, name='net'))
|
||||
model.add(GRU(
|
||||
params.recurrent_units, activation='linear',
|
||||
input_shape=(pr.n_features, pr.feature_size), dropout=params.dropout, name='net'
|
||||
))
|
||||
model.add(Dense(1, activation='sigmoid'))
|
||||
|
||||
load_keras()
|
||||
metrics = ['accuracy'] + extra_metrics * [false_pos, false_neg]
|
||||
model.compile('rmsprop', weighted_log_loss, metrics=(not skip_acc) * metrics)
|
||||
metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
|
||||
set_loss_bias(params.loss_bias)
|
||||
model.compile('rmsprop', weighted_log_loss, metrics=(not params.skip_acc) * metrics)
|
||||
return model
|
||||
|
|
|
@ -19,7 +19,7 @@ from prettyparse import create_parser
|
|||
from random import random
|
||||
from typing import *
|
||||
|
||||
from precise.model import create_model
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.network_runner import Listener, KerasRunner
|
||||
from precise.params import pr
|
||||
from precise.train_data import TrainData
|
||||
|
@ -80,9 +80,10 @@ class IncrementalTrainer(Trainer):
|
|||
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||
|
||||
if not isfile(self.args.model):
|
||||
create_model(self.args.model, self.args.no_validation, self.args.extra_metrics).save(
|
||||
self.args.model
|
||||
params = ModelParams(
|
||||
skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics
|
||||
)
|
||||
create_model(self.args.model, params).save(self.args.model)
|
||||
self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -12,97 +12,125 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from glob import glob
|
||||
from os import remove
|
||||
|
||||
from os.path import isfile, splitext, join
|
||||
|
||||
import h5py
|
||||
import numpy
|
||||
# Optimizer blackhat
|
||||
from bbopt import BlackBoxOptimizer
|
||||
from keras.layers.core import Dense
|
||||
from keras.layers.recurrent import GRU
|
||||
from keras.models import Sequential
|
||||
from pprint import pprint
|
||||
from typing import *
|
||||
from prettyparse import create_parser
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
|
||||
from precise.functions import weighted_log_loss
|
||||
from precise.params import pr
|
||||
from precise.model import ModelParams, create_model
|
||||
from precise.train_data import TrainData
|
||||
from precise.trainer import Trainer
|
||||
|
||||
usage = '''
|
||||
Use black box optimization to tune model hyperparameters
|
||||
|
||||
:-t --trials-name str -
|
||||
Filename to save hyperparameter optimization trials in
|
||||
'.bbopt.json' will automatically be appended
|
||||
|
||||
:-c --cycles int 20
|
||||
Number of cycles of optimization to run
|
||||
|
||||
:-m --model str .cache/optimized.net
|
||||
Model to load from
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
def false_pos(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.sum(1 - yt)
|
||||
class OptimizeTrainer(Trainer):
|
||||
usage = re.sub(r'.*:model str.*\n.*\n', '', Trainer.usage)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(create_parser(usage))
|
||||
self.bb = BlackBoxOptimizer(file=self.args.trials_name)
|
||||
if not self.test:
|
||||
data = TrainData.from_both(self.args.tags_file, self.args.tags_folder, self.args.folder)
|
||||
_, self.test = data.load(False, True)
|
||||
|
||||
def false_neg(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.sum(0 + yt)
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
for i in list(self.callbacks):
|
||||
if isinstance(i, ModelCheckpoint):
|
||||
self.callbacks.remove(i)
|
||||
|
||||
def process_args(self, args: Any):
|
||||
model_parts = glob(splitext(args.model)[0] + '.*')
|
||||
if len(model_parts) < 5:
|
||||
for name in model_parts:
|
||||
if isfile(name):
|
||||
remove(name)
|
||||
else:
|
||||
rmtree(name)
|
||||
args.trials_name = args.trials_name.replace('.bbopt.json', '').replace('.json', '')
|
||||
if not args.trials_name:
|
||||
if isfile(join('.cache', 'trials.bbopt.json')):
|
||||
remove(join('.cache', 'trials.bbopt.json'))
|
||||
args.trials_name = join('.cache', 'trials')
|
||||
|
||||
def run(self):
|
||||
print('Writing to:', self.args.trials_name + '.bbopt.json')
|
||||
for i in range(self.args.cycles):
|
||||
self.bb.run(backend="random")
|
||||
print("\n= %d = (example #%d)" % (i + 1, len(self.bb.get_data()["examples"]) + 1))
|
||||
|
||||
params = ModelParams(
|
||||
recurrent_units=self.bb.randint("units", 1, 100, guess=50),
|
||||
dropout=self.bb.uniform("dropout", 0.1, 0.9, guess=0.6),
|
||||
extra_metrics=self.args.extra_metrics,
|
||||
skip_acc=self.args.no_validation,
|
||||
loss_bias=1.0 - self.args.sensitivity
|
||||
)
|
||||
print('Testing with:', params)
|
||||
model = create_model(self.args.model, params)
|
||||
model.fit(
|
||||
*self.sampled_data, batch_size=self.args.batch_size,
|
||||
epochs=self.epoch + self.args.epochs,
|
||||
validation_data=self.test * (not self.args.no_validation),
|
||||
callbacks=self.callbacks, initial_epoch=self.epoch,
|
||||
)
|
||||
resp = model.evaluate(*self.test, batch_size=self.args.batch_size)
|
||||
if not isinstance(resp, (list, tuple)):
|
||||
resp = [resp, None]
|
||||
test_loss, test_acc = resp
|
||||
predictions = model.predict(self.test[0], batch_size=self.args.batch_size)
|
||||
|
||||
num_false_positive = numpy.sum(predictions * (1 - self.test[1]) > 0.5)
|
||||
num_false_negative = numpy.sum((1 - predictions) * self.test[1] > 0.5)
|
||||
false_positives = num_false_positive / numpy.sum(self.test[1] < 0.5)
|
||||
false_negatives = num_false_negative / numpy.sum(self.test[1] > 0.5)
|
||||
|
||||
from math import exp
|
||||
param_score = 1.0 / (1.0 + exp((model.count_params() - 11000) / 2000))
|
||||
fitness = param_score * (1.0 - 0.2 * false_negatives - 0.8 * false_positives)
|
||||
|
||||
self.bb.remember({
|
||||
"test loss": test_loss,
|
||||
"test accuracy": test_acc,
|
||||
"false positive%": false_positives,
|
||||
"false negative%": false_negatives,
|
||||
"fitness": fitness
|
||||
})
|
||||
|
||||
print(false_positives)
|
||||
print("False positive: ", false_positives * 100, "%")
|
||||
|
||||
self.bb.maximize(fitness)
|
||||
pprint(self.bb.get_current_run())
|
||||
best_example = self.bb.get_optimal_run()
|
||||
print("\n= BEST = (example #%d)" % self.bb.get_data()["examples"].index(best_example))
|
||||
pprint(best_example)
|
||||
|
||||
|
||||
def main():
|
||||
bb = BlackBoxOptimizer(file=__file__)
|
||||
|
||||
# Loading in data to train
|
||||
data = TrainData.from_both('/home/mikhail/wakewords/wakewords/files/tags.txt',
|
||||
'/home/mikhail/wakewords/wakewords/files',
|
||||
'/home/mikhail/wakewords/wakewords/not-wake-word/generated')
|
||||
(train_inputs, train_outputs), (test_inputs, test_outputs) = data.load()
|
||||
|
||||
test_data = (test_inputs, test_outputs)
|
||||
|
||||
for i in range(5):
|
||||
bb.run(backend="random")
|
||||
|
||||
print("\n= %d = (example #%d)" % (i + 1, len(bb.get_data()["examples"]) + 1))
|
||||
|
||||
shuffle_ids = numpy.arange(len(test_inputs))
|
||||
numpy.random.shuffle(shuffle_ids)
|
||||
(test_inputs, test_outputs) = (test_inputs[shuffle_ids], test_outputs[shuffle_ids])
|
||||
|
||||
model_array = numpy.empty(len(test_data), dtype=int)
|
||||
with h5py.File('tested_models.hdf5', 'w') as f:
|
||||
f.create_dataset('dataset_1', data=model_array)
|
||||
f.close()
|
||||
|
||||
batch_size = bb.randint("batch_size", 1000, 5000, guess=3000)
|
||||
|
||||
model = Sequential()
|
||||
model.add(GRU(units=bb.randint("units", 1, 100, guess=50), activation='linear',
|
||||
input_shape=(pr.n_features, pr.feature_size),
|
||||
dropout=bb.uniform("dropout", 0.1, 0.9, guess=0.6), name='net'))
|
||||
model.add(Dense(1, activation='sigmoid'))
|
||||
|
||||
model.compile('rmsprop', weighted_log_loss, metrics=['accuracy'])
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
|
||||
checkpoint = ModelCheckpoint('tested_models.hdf5', monitor='val_loss',
|
||||
save_best_only=True)
|
||||
|
||||
model.fit(train_inputs, train_outputs, batch_size=batch_size, epochs=100,
|
||||
validation_data=(test_inputs, test_outputs),
|
||||
callbacks=[checkpoint])
|
||||
test_loss, test_acc = model.evaluate(test_inputs, test_outputs)
|
||||
|
||||
predictions = model.predict(test_inputs)
|
||||
num_false_positive = numpy.sum(predictions * (1 - test_outputs) > 0.5)
|
||||
num_false_negative = numpy.sum((1 - predictions) * test_outputs > 0.5)
|
||||
false_positives = num_false_positive / numpy.sum(test_outputs < 0.5)
|
||||
false_negatives = num_false_negative / numpy.sum(test_outputs > 0.5)
|
||||
|
||||
bb.remember({
|
||||
"test loss": test_loss,
|
||||
"test accuracy": test_acc,
|
||||
"false positive%": false_positives,
|
||||
"false negative%": false_negatives
|
||||
})
|
||||
print(false_positives)
|
||||
print("False positive: ", false_positives * 100, "%")
|
||||
bb.minimize(false_positives)
|
||||
pprint(bb.get_current_run())
|
||||
|
||||
best_example = bb.get_optimal_run()
|
||||
print("\n= BEST = (example #%d)" % bb.get_data()["examples"].index(best_example))
|
||||
pprint(best_example)
|
||||
OptimizeTrainer().run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -20,7 +20,7 @@ from os.path import join, isfile
|
|||
from prettyparse import add_to_parser
|
||||
from typing import *
|
||||
|
||||
from precise.util import find_wavs
|
||||
from precise.util import find_wavs, InvalidAudio
|
||||
from precise.vectorization import load_vector, vectorize_inhibit
|
||||
|
||||
|
||||
|
@ -145,7 +145,7 @@ class TrainData:
|
|||
"""Return parsed args from parser, adding options for train data inputs"""
|
||||
extra_usage = '''
|
||||
:folder str
|
||||
Folder to wav files from
|
||||
Folder to load wav files from
|
||||
|
||||
:-tf --tags-folder str {folder}
|
||||
Specify a different folder to load file ids
|
||||
|
@ -195,8 +195,8 @@ class TrainData:
|
|||
try:
|
||||
inputs.append(load_vector(f, vectorizer))
|
||||
outputs.append(np.array([output]))
|
||||
except ValueError:
|
||||
print('Skipping invalid file:', f)
|
||||
except InvalidAudio as e:
|
||||
print('Skipping invalid file:', f, e)
|
||||
|
||||
print('Loading wake-word...')
|
||||
add(kw_files, 1.0)
|
||||
|
|
|
@ -6,7 +6,7 @@ from prettyparse import add_to_parser
|
|||
from typing import Any, Tuple
|
||||
|
||||
from precise.functions import set_loss_bias
|
||||
from precise.model import create_model
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import calc_sample_hash
|
||||
|
@ -54,7 +54,8 @@ class Trainer:
|
|||
def __init__(self, parser=None):
|
||||
parser = parser or ArgumentParser()
|
||||
add_to_parser(parser, self.usage, True)
|
||||
self.args = args = TrainData.parse_args(parser)
|
||||
args = TrainData.parse_args(parser)
|
||||
self.args = args = self.process_args(args) or args
|
||||
|
||||
if args.invert_samples and not args.samples_file:
|
||||
parser.error('You must specify --samples-file when using --invert-samples')
|
||||
|
@ -68,7 +69,8 @@ class Trainer:
|
|||
self.train, self.test = self.load_data(self.args)
|
||||
|
||||
set_loss_bias(1.0 - args.sensitivity)
|
||||
self.model = create_model(args.model, args.no_validation, args.extra_metrics)
|
||||
params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics)
|
||||
self.model = create_model(args.model, params)
|
||||
self.model.summary()
|
||||
|
||||
from keras.callbacks import ModelCheckpoint, TensorBoard
|
||||
|
@ -91,10 +93,14 @@ class Trainer:
|
|||
|
||||
self.callbacks = [
|
||||
checkpoint, TensorBoard(
|
||||
log_dir=self.model_base + '.logs', histogram_freq=10 if self.test else 0
|
||||
log_dir=self.model_base + '.logs',
|
||||
), LambdaCallback(on_epoch_end=on_epoch_end)
|
||||
]
|
||||
|
||||
def process_args(self, args: Any) -> Any:
|
||||
"""Override to modify args"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_sample_data(filename, train_data) -> Tuple[set, dict]:
|
||||
samples = Fitipy(filename).read().set()
|
||||
|
|
|
@ -19,6 +19,11 @@ from typing import *
|
|||
from precise.params import pr
|
||||
|
||||
|
||||
|
||||
class InvalidAudio(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def buffer_to_audio(buffer: bytes) -> np.ndarray:
|
||||
"""Convert a raw mono audio byte string to numpy array of floats"""
|
||||
return np.fromstring(buffer, dtype='<i2').astype(np.float32, order='C') / 32768.0
|
||||
|
@ -42,9 +47,9 @@ def load_audio(file: Any) -> np.ndarray:
|
|||
except EOFError:
|
||||
wav = wavio.Wav(np.array([[]], dtype=np.int16), 16000, 2)
|
||||
if wav.data.dtype != np.int16:
|
||||
raise ValueError('Unsupported data type: ' + str(wav.data.dtype))
|
||||
raise InvalidAudio('Unsupported data type: ' + str(wav.data.dtype))
|
||||
if wav.rate != pr.sample_rate:
|
||||
raise ValueError('Unsupported sample rate: ' + str(wav.rate))
|
||||
raise InvalidAudio('Unsupported sample rate: ' + str(wav.rate))
|
||||
|
||||
data = np.squeeze(wav.data)
|
||||
return data.astype(np.float32) / float(np.iinfo(data.dtype).max)
|
||||
|
|
Loading…
Reference in New Issue