Overhaul of how scripts are written to allow programmatic access
This introduces a way for scripts to be easily called from within Python, with command line arguments passed as function parameters. To support this, prettyparse has been upgraded to the latest version. (pull/102/head)
parent
fb452ca1eb
commit
4cec9b0767
|
@ -12,56 +12,56 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from prettyparse import create_parser
|
||||
from random import randint
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
from prettyparse import Usage
|
||||
from threading import Event
|
||||
|
||||
from precise.pocketsphinx.listener import PocketsphinxListener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import activate_notify
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
|
||||
usage = '''
|
||||
Run Pocketsphinx on microphone audio input
|
||||
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Samples between inferences
|
||||
'''
|
||||
|
||||
session_id, chunk_num = '%09d' % randint(0, 999999999), 0
|
||||
|
||||
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
class PocketsphinxListenScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Run Pocketsphinx on microphone audio input
|
||||
|
||||
def on_activation():
|
||||
activate_notify()
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
def on_prediction(conf):
|
||||
print('!' if conf > 0.5 else '.', end='', flush=True)
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
runner = PreciseRunner(
|
||||
ListenerEngine(
|
||||
PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold, args.chunk_size
|
||||
)
|
||||
), 3, on_activation=on_activation, on_prediction=on_prediction
|
||||
)
|
||||
runner.start()
|
||||
Event().wait() # Wait forever
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Samples between inferences
|
||||
''')
|
||||
|
||||
def run(self):
    """Run Pocketsphinx on microphone audio input until the process is stopped.

    Builds a ``PocketsphinxListener`` from the parsed arguments, wraps it in a
    ``ListenerEngine`` and hands that to a ``PreciseRunner``. Each prediction
    prints ``!`` when its confidence is above 0.5 and ``.`` otherwise; each
    activation calls ``activate_notify``. The method never returns normally.
    """
    def report_prediction(conf):
        # One character per prediction, flushed so it shows up immediately.
        marker = '.'
        if conf > 0.5:
            marker = '!'
        print(marker, end='', flush=True)

    def notify():
        activate_notify()

    opts = self.args
    listener = PocketsphinxListener(
        opts.key_phrase, opts.dict_file, opts.hmm_folder, opts.threshold, opts.chunk_size
    )
    engine = ListenerEngine(listener)
    # NOTE(review): the positional 3 is presumably the runner's trigger level - confirm
    runner = PreciseRunner(engine, 3, on_activation=notify, on_prediction=report_prediction)
    runner.start()

    # Nothing else holds a reference to this event, so waiting on it blocks forever.
    Event().wait()
|
||||
|
||||
|
||||
main = PocketsphinxListenScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -13,80 +13,101 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import wave
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from subprocess import check_output, PIPE
|
||||
|
||||
from precise.pocketsphinx.listener import PocketsphinxListener
|
||||
from precise.scripts.test import show_stats, Stats
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.scripts.test import Stats
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Test a dataset using Pocketsphinx
|
||||
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't show the names of files that failed
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
class PocketsphinxTestScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Test a dataset using Pocketsphinx
|
||||
|
||||
def eval_file(filename) -> float:
|
||||
transcription = check_output(
|
||||
['pocketsphinx_continuous', '-kws_threshold', '1e-20', '-keyphrase', 'hey my craft',
|
||||
'-infile', filename], stderr=PIPE)
|
||||
return float(bool(transcription) and not transcription.isspace())
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
def test_pocketsphinx(listener: PocketsphinxListener, data_files) -> Stats:
|
||||
def run_test(filenames, name):
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't show the names of files that failed
|
||||
|
||||
...
|
||||
''') | TrainData.usage
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.listener = PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold
|
||||
)
|
||||
|
||||
self.outputs = []
|
||||
self.targets = []
|
||||
self.filenames = []
|
||||
|
||||
def get_stats(self):
|
||||
return Stats(self.outputs, self.targets, self.filenames)
|
||||
|
||||
def run(self):
|
||||
args = self.args
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
print('Data:', data)
|
||||
|
||||
ww_files, nww_files = data.train_files if args.use_train else data.test_files
|
||||
self.run_test(ww_files, 'Wake Word', 1.0)
|
||||
self.run_test(nww_files, 'Not Wake Word', 0.0)
|
||||
stats = self.get_stats()
|
||||
if not self.args.no_filenames:
|
||||
fp_files = stats.calc_filenames(False, True, 0.5)
|
||||
fn_files = stats.calc_filenames(False, False, 0.5)
|
||||
print('=== False Positives ===')
|
||||
print('\n'.join(fp_files))
|
||||
print()
|
||||
print('=== False Negatives ===')
|
||||
print('\n'.join(fn_files))
|
||||
print()
|
||||
print(stats.counts_str(0.5))
|
||||
print()
|
||||
print('===', name, '===')
|
||||
negatives, positives = [], []
|
||||
for filename in filenames:
|
||||
print(stats.summary_str(0.5))
|
||||
|
||||
def eval_file(self, filename) -> float:
|
||||
transcription = check_output(
|
||||
['pocketsphinx_continuous', '-kws_threshold', '1e-20', '-keyphrase', 'hey my craft',
|
||||
'-infile', filename], stderr=PIPE)
|
||||
return float(bool(transcription) and not transcription.isspace())
|
||||
|
||||
def run_test(self, test_files, label_name, label):
|
||||
print()
|
||||
print('===', label_name, '===')
|
||||
for test_file in test_files:
|
||||
try:
|
||||
with wave.open(filename) as wf:
|
||||
with wave.open(test_file) as wf:
|
||||
frames = wf.readframes(wf.getnframes())
|
||||
except (OSError, EOFError):
|
||||
print('?', end='', flush=True)
|
||||
continue
|
||||
out = listener.found_wake_word(frames)
|
||||
{False: negatives, True: positives}[out].append(filename)
|
||||
|
||||
out = int(self.listener.found_wake_word(frames))
|
||||
self.outputs.append(out)
|
||||
self.targets.append(label)
|
||||
self.filenames.append(test_file)
|
||||
print('!' if out else '.', end='', flush=True)
|
||||
print()
|
||||
return negatives, positives
|
||||
|
||||
false_neg, true_pos = run_test(data_files[0], 'Wake Word')
|
||||
true_neg, false_pos = run_test(data_files[1], 'Not Wake Word')
|
||||
return Stats(false_pos, false_neg, true_pos, true_neg)
|
||||
|
||||
|
||||
def main():
|
||||
args = TrainData.parse_args(create_parser(usage))
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
data_files = data.train_files if args.use_train else data.test_files
|
||||
listener = PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold
|
||||
)
|
||||
|
||||
print('Data:', data)
|
||||
stats = test_pocketsphinx(listener, data_files)
|
||||
show_stats(stats, not args.no_filenames)
|
||||
|
||||
main = PocketsphinxTestScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from itertools import chain
|
||||
from math import sqrt
|
||||
|
||||
import numpy as np
|
||||
|
@ -20,40 +19,15 @@ import os
|
|||
from glob import glob
|
||||
from os import makedirs
|
||||
from os.path import join, dirname, abspath, splitext
|
||||
from pip._vendor.distlib._backport import shutil
|
||||
from prettyparse import create_parser
|
||||
import shutil
|
||||
from prettyparse import Usage
|
||||
from random import random
|
||||
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import load_audio
|
||||
from precise.util import save_audio
|
||||
|
||||
usage = '''
|
||||
Create a duplicate dataset with added noise
|
||||
|
||||
:folder str
|
||||
Folder containing source dataset
|
||||
|
||||
:-tg --tags-file str -
|
||||
Tags file to optionally load from
|
||||
|
||||
:noise_folder str
|
||||
Folder with wav files containing noise to be added
|
||||
|
||||
:output_folder str
|
||||
Folder to write the duplicate generated dataset
|
||||
|
||||
:-if --inflation-factor int 1
|
||||
The number of noisy samples generated per single source sample
|
||||
|
||||
:-nl --noise-ratio-low float 0.0
|
||||
Minimum random ratio of noise to sample. 1.0 is all noise, no sample sound
|
||||
|
||||
:-nh --noise-ratio-high float 0.4
|
||||
Maximum random ratio of noise to sample. 1.0 is all noise, no sample sound
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
class NoiseData:
|
||||
def __init__(self, noise_folder: str):
|
||||
|
@ -92,42 +66,72 @@ class NoiseData:
|
|||
return noise_ratio * adjusted_noise + (1.0 - noise_ratio) * audio
|
||||
|
||||
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
args.tags_file = abspath(args.tags_file) if args.tags_file else None
|
||||
args.folder = abspath(args.folder)
|
||||
args.output_folder = abspath(args.output_folder)
|
||||
noise_min, noise_max = args.noise_ratio_low, args.noise_ratio_high
|
||||
class AddNoiseScript(BaseScript):
|
||||
usage = Usage(
|
||||
"""
|
||||
Create a duplicate dataset with added noise
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.folder, args.folder)
|
||||
noise_data = NoiseData(args.noise_folder)
|
||||
print('Data:', data)
|
||||
:folder str
|
||||
Folder containing source dataset
|
||||
|
||||
def translate_filename(source: str, n=0) -> str:
|
||||
assert source.startswith(args.folder)
|
||||
relative_file = source[len(args.folder):].strip(os.path.sep)
|
||||
if n > 0:
|
||||
base, ext = splitext(relative_file)
|
||||
relative_file = base + '.' + str(n) + ext
|
||||
return join(args.output_folder, relative_file)
|
||||
:-tg --tags-file str -
|
||||
Tags file to optionally load from
|
||||
|
||||
all_filenames = sum(data.train_files + data.test_files, [])
|
||||
for i, filename in enumerate(all_filenames):
|
||||
print('{0:.2%} \r'.format(i / (len(all_filenames) - 1)), end='', flush=True)
|
||||
:noise_folder str
|
||||
Folder with wav files containing noise to be added
|
||||
|
||||
audio = load_audio(filename)
|
||||
for n in range(args.inflation_factor):
|
||||
altered = noise_data.noised_audio(audio, noise_min + (noise_max - noise_min) * random())
|
||||
output_filename = translate_filename(filename, n)
|
||||
:output_folder str
|
||||
Folder to write the duplicate generated dataset
|
||||
|
||||
makedirs(dirname(output_filename), exist_ok=True)
|
||||
save_audio(output_filename, altered)
|
||||
:-if --inflation-factor int 1
|
||||
The number of noisy samples generated per single source sample
|
||||
|
||||
print('Done!')
|
||||
:-nl --noise-ratio-low float 0.0
|
||||
Minimum random ratio of noise to sample. 1.0 is all noise, no sample sound
|
||||
|
||||
if args.tags_file and args.tags_file.startswith(args.folder):
|
||||
shutil.copy2(args.tags_file, translate_filename(args.tags_file))
|
||||
:-nh --noise-ratio-high float 0.4
|
||||
Maximum random ratio of noise to sample. 1.0 is all noise, no sample sound
|
||||
""",
|
||||
tags_file=lambda args: abspath(args.tags_file) if args.tags_file else None,
|
||||
folder=lambda args: abspath(args.folder),
|
||||
output_folder=lambda args: abspath(args.output_folder)
|
||||
)
|
||||
|
||||
def run(self):
    """Write a noisy duplicate of every sample in the source dataset.

    For each train and test file, ``inflation_factor`` noisy variants are
    generated with a noise ratio drawn uniformly between ``noise_ratio_low``
    and ``noise_ratio_high``, and saved under ``output_folder`` at the same
    relative path (variants after the first get a ``.<n>`` suffix before the
    extension). If the tags file lives inside the source folder it is copied
    to the matching location in the output folder.
    """
    args = self.args
    noise_min, noise_max = args.noise_ratio_low, args.noise_ratio_high

    data = TrainData.from_both(args.tags_file, args.folder, args.folder)
    noise_data = NoiseData(args.noise_folder)
    print('Data:', data)

    def translate_filename(source: str, n=0) -> str:
        """Map a path under the source folder to its path under the output folder."""
        assert source.startswith(args.folder)
        relative_file = source[len(args.folder):].strip(os.path.sep)
        if n > 0:
            base, ext = splitext(relative_file)
            relative_file = base + '.' + str(n) + ext
        return join(args.output_folder, relative_file)

    all_filenames = sum(data.train_files + data.test_files, [])
    # Progress runs from 0% at the first file to 100% at the last one. Clamp
    # the denominator so a dataset with a single file does not divide by zero.
    last_index = max(1, len(all_filenames) - 1)
    for i, filename in enumerate(all_filenames):
        print('{0:.2%} \r'.format(i / last_index), end='', flush=True)

        audio = load_audio(filename)
        for n in range(args.inflation_factor):
            altered = noise_data.noised_audio(audio, noise_min + (noise_max - noise_min) * random())
            output_filename = translate_filename(filename, n)

            makedirs(dirname(output_filename), exist_ok=True)
            save_audio(output_filename, altered)

    print('Done!')

    if args.tags_file and args.tags_file.startswith(args.folder):
        shutil.copy2(args.tags_file, translate_filename(args.tags_file))
|
||||
|
||||
|
||||
main = AddNoiseScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
from abc import abstractmethod
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from prettyparse import Usage
|
||||
|
||||
|
||||
class BaseScript:
    """A class to standardize the way scripts are defined

    Subclasses set ``usage`` to a prettyparse ``Usage`` describing their
    arguments and implement ``run``. A script can then be started from the
    command line with ``run_main`` or built programmatically with ``create``.
    """
    usage = Usage()

    def __init__(self, args):
        """Store the parsed arguments namespace on the instance."""
        self.args = args

    @staticmethod
    def _default_value(arg_data):
        """Return the value an omitted argument should take.

        Args:
            arg_data: one entry of ``Usage.arguments``; only its ``default``,
                ``type`` and ``action`` keys are read here.
        """
        default = arg_data.get('default')
        typ = arg_data.get('type')
        action = arg_data.get('action', '')
        if not typ and action.startswith('store_'):
            if default is None:
                # Same implicit defaults as argparse: a store_false flag is
                # True until given; every other store_* action yields False.
                return action == 'store_false'
            return bool(default)
        if default is None or not typ:
            # A missing default stays None; converting it would produce
            # str(None) == 'None' or raise TypeError for int/float.
            return default
        return typ(default)

    @classmethod
    def create(cls, **args):
        """Build the script from keyword arguments instead of ``sys.argv``.

        Each argument declared in ``cls.usage`` is taken from ``args`` when
        given, otherwise from its declared default. Keyword arguments that do
        not match a declared argument are ignored.

        Raises:
            TypeError: a required argument (no default and not an option
                flag) was not supplied.
        """
        values = {}
        for arg_name, arg_data in cls.usage.arguments.items():
            if arg_name in args:
                values[arg_name] = args.pop(arg_name)
                continue
            # NOTE(review): '_0' looks like the first token of the argument
            # declaration (flag or positional name) - confirm against prettyparse
            if 'default' not in arg_data and arg_name and not arg_data['_0'].startswith('-'):
                raise TypeError('Calling script without required "{}" argument.'.format(arg_name))
            values[arg_name] = cls._default_value(arg_data)
        # Use the returned namespace, exactly as run_main does.
        return cls(cls.usage.render_args(Namespace(**values)))

    @abstractmethod
    def run(self):
        """Execute the script; subclasses override this."""
        pass

    @classmethod
    def run_main(cls):
        """Parse command line arguments, construct the script and run it.

        A ``ValueError`` raised by the constructor is reported as an argument
        error. A ``KeyboardInterrupt`` during ``run`` ends the script quietly
        after printing a newline.
        """
        parser = ArgumentParser()
        cls.usage.apply(parser)
        args = cls.usage.render_args(parser.parse_args())

        try:
            script = cls(args)
        except ValueError as e:
            parser.error('Error parsing args: ' + str(e))
            # ArgumentParser.error() normally exits by itself; this only
            # matters if a parser's error() has been overridden to return.
            raise SystemExit(1)

        try:
            script.run()
        except KeyboardInterrupt:
            print()
|
|
@ -12,76 +12,79 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
from math import atan, tan, log, exp, sqrt, pi
|
||||
from math import sqrt
|
||||
|
||||
import math
|
||||
|
||||
from functools import partial
|
||||
from os.path import basename, splitext
|
||||
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
|
||||
usage = '''
|
||||
Update the threshold values of a model for a dataset.
|
||||
This makes the sensitivity more accurate and linear
|
||||
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to adjust
|
||||
class CalcThresholdScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Update the threshold values of a model for a dataset.
|
||||
This makes the sensitivity more accurate and linear
|
||||
|
||||
:input_file str
|
||||
Input stats file that was outputted from precise-graph
|
||||
|
||||
:-k --model-key str -
|
||||
Custom model name to use from the stats.json
|
||||
|
||||
:-s --smoothing float 1.2
|
||||
Amount of extra smoothing to apply
|
||||
|
||||
:-c --center float 0.2
|
||||
Decoded threshold that is mapped to 0.5. Proportion of
|
||||
false negatives at sensitivity=0.5
|
||||
'''
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to adjust
|
||||
|
||||
:input_file str
|
||||
Input stats file that was outputted from precise-graph
|
||||
|
||||
:-k --model-key str -
|
||||
Custom model name to use from the stats.json
|
||||
|
||||
:-s --smoothing float 1.2
|
||||
Amount of extra smoothing to apply
|
||||
|
||||
:-c --center float 0.2
|
||||
Decoded threshold that is mapped to 0.5. Proportion of
|
||||
false negatives at sensitivity=0.5
|
||||
''')
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
def run(self):
    """Compute threshold statistics from a stats file and save them to the model params.

    Loads the per-model stats stored under the ``data`` key of
    ``args.input_file``, selects the entry named by ``args.model_key`` (or
    the model's base filename), and stores the mean and smoothed standard
    deviation of the logit of the positive samples' outputs as
    ``threshold_config`` in the model's params.

    Raises:
        SystemExit: the model name is not present in the stats file.
    """
    args = self.args
    import numpy as np

    # The 'data' entry is a pickled dict wrapped in an object array, which
    # NumPy only loads with allow_pickle=True. Unpickling can execute
    # arbitrary code, so only pass stats files from a trusted source.
    saved_stats = np.load(args.input_file, allow_pickle=True)['data'].item()
    model_data = {
        name: Stats.from_np_dict(data) for name, data in saved_stats.items()
    }
    model_name = args.model_key or basename(splitext(args.model)[0])

    if model_name not in model_data:
        print("Could not find model '{}' in saved models in stats file: {}".format(model_name, list(model_data)))
        raise SystemExit(1)

    stats = model_data[model_name]

    # Outputs of exactly 0 or 1 have an infinite logit, so drop them.
    save_spots = (stats.outputs != 0) & (stats.outputs != 1)
    if save_spots.sum() == 0:
        print('No data (or all NaN)')
        return

    stats.outputs = stats.outputs[save_spots]
    stats.targets = stats.targets[save_spots]
    # Logit (inverse sigmoid) of the network outputs.
    inv = -np.log(1 / stats.outputs - 1)

    pos = np.extract(stats.targets > 0.5, inv)
    if pos.size == 0:
        # The mean of an empty array is NaN; do not write that into the params.
        print('No positive samples to calculate threshold from')
        return

    pos_mu = pos.mean().item()
    pos_std = sqrt(np.mean((pos - pos_mu) ** 2)) * args.smoothing

    print('Peak: {:.2f} mu, {:.2f} std'.format(pos_mu, pos_std))
    pr = inject_params(args.model)
    pr.__dict__.update(threshold_config=(
        (pos_mu, pos_std),
    ))
    save_params(args.model)
    print('Saved params to {}.params'.format(args.model))
|
||||
|
||||
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
import numpy as np
|
||||
|
||||
model_data = {
|
||||
name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
|
||||
}
|
||||
model_name = args.model_key or basename(splitext(args.model)[0])
|
||||
|
||||
if model_name not in model_data:
|
||||
print("Could not find model '{}' in saved models in stats file: {}".format(model_name, list(model_data)))
|
||||
raise SystemExit(1)
|
||||
|
||||
stats = model_data[model_name]
|
||||
|
||||
save_spots = (stats.outputs != 0) & (stats.outputs != 1)
|
||||
if save_spots.sum() == 0:
|
||||
print('No data (or all NaN)')
|
||||
return
|
||||
|
||||
stats.outputs = stats.outputs[save_spots]
|
||||
stats.targets = stats.targets[save_spots]
|
||||
inv = -np.log(1 / stats.outputs - 1)
|
||||
|
||||
pos = np.extract(stats.targets > 0.5, inv)
|
||||
pos_mu = pos.mean().item()
|
||||
pos_std = sqrt(np.mean((pos - pos_mu) ** 2)) * args.smoothing
|
||||
|
||||
print('Peak: {:.2f} mu, {:.2f} std'.format(pos_mu, pos_std))
|
||||
pr = inject_params(args.model)
|
||||
pr.__dict__.update(threshold_config=(
|
||||
(pos_mu, pos_std),
|
||||
))
|
||||
save_params(args.model)
|
||||
print('Saved params to {}.params'.format(args.model))
|
||||
|
||||
main = CalcThresholdScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -16,52 +16,13 @@ from select import select
|
|||
from sys import stdin
|
||||
from termios import tcsetattr, tcgetattr, TCSADRAIN
|
||||
|
||||
import pyaudio
|
||||
import tty
|
||||
import wave
|
||||
from os.path import isfile
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from pyaudio import PyAudio
|
||||
|
||||
usage = '''
|
||||
Record audio samples for use with precise
|
||||
|
||||
:-w --width int 2
|
||||
Sample width of audio
|
||||
|
||||
:-r --rate int 16000
|
||||
Sample rate of audio
|
||||
|
||||
:-c --channels int 1
|
||||
Number of audio channels
|
||||
'''
|
||||
|
||||
|
||||
def key_pressed():
|
||||
return select([stdin], [], [], 0) == ([stdin], [], [])
|
||||
|
||||
|
||||
def termios_wrapper(main):
|
||||
global orig_settings
|
||||
orig_settings = tcgetattr(stdin)
|
||||
try:
|
||||
hide_input()
|
||||
main()
|
||||
finally:
|
||||
tcsetattr(stdin, TCSADRAIN, orig_settings)
|
||||
|
||||
|
||||
def show_input():
|
||||
tcsetattr(stdin, TCSADRAIN, orig_settings)
|
||||
|
||||
|
||||
def hide_input():
|
||||
tty.setcbreak(stdin.fileno())
|
||||
|
||||
|
||||
orig_settings = None
|
||||
|
||||
RECORD_KEY = ' '
|
||||
EXIT_KEY_CODE = 27
|
||||
from precise.scripts.base_script import BaseScript
|
||||
|
||||
|
||||
def record_until(p, should_return, args):
|
||||
|
@ -88,74 +49,103 @@ def save_audio(name, data, args):
|
|||
wf.close()
|
||||
|
||||
|
||||
def next_name(name):
|
||||
name += '.wav'
|
||||
pos, num_digits = None, None
|
||||
try:
|
||||
pos = name.index('#')
|
||||
num_digits = name.count('#')
|
||||
except ValueError:
|
||||
print("Name must contain at least one # to indicate where to put the number.")
|
||||
raise
|
||||
class CollectScript(BaseScript):
|
||||
RECORD_KEY = ' '
|
||||
EXIT_KEY_CODE = 27
|
||||
|
||||
def get_name(i):
|
||||
nonlocal name, pos
|
||||
return name[:pos] + str(i).zfill(num_digits) + name[pos + num_digits:]
|
||||
usage = Usage('''
|
||||
Record audio samples for use with precise
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
if not isfile(get_name(i)):
|
||||
break
|
||||
i += 1
|
||||
:-w --width int 2
|
||||
Sample width of audio
|
||||
|
||||
return get_name(i)
|
||||
:-r --rate int 16000
|
||||
Sample rate of audio
|
||||
|
||||
:-c --channels int 1
|
||||
Number of audio channels
|
||||
''')
|
||||
usage.add_argument('file_label', nargs='?', help='File label (Ex. recording-##)')
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.orig_settings = tcgetattr(stdin)
|
||||
self.p = PyAudio()
|
||||
|
||||
def key_pressed(self):
|
||||
return select([stdin], [], [], 0) == ([stdin], [], [])
|
||||
|
||||
def show_input(self):
|
||||
tcsetattr(stdin, TCSADRAIN, self.orig_settings)
|
||||
|
||||
def hide_input(self):
|
||||
tty.setcbreak(stdin.fileno())
|
||||
|
||||
def next_name(self, name):
    """Return the first unused ``.wav`` filename produced from a ``#`` template.

    The first run of consecutive ``#`` characters in ``name`` is the counter
    field. It is replaced by 0, 1, 2, ... zero-padded to the width of the
    run, until a filename is found that does not exist yet. Any later ``#``
    characters are kept literally.

    Args:
        name: filename template without extension, e.g. ``recording-##``.

    Returns:
        The generated filename, including the ``.wav`` extension.

    Raises:
        ValueError: ``name`` contains no ``#``.
    """
    name += '.wav'
    pos = name.find('#')
    if pos < 0:
        print("Name must contain at least one # to indicate where to put the number.")
        raise ValueError('substring not found')

    # Width of the counter field: only the '#' characters adjacent to the
    # first one, so other parts of the name are never overwritten.
    num_digits = len(name) - pos - len(name[pos:].lstrip('#'))
    prefix, suffix = name[:pos], name[pos + num_digits:]

    def get_name(i):
        return prefix + str(i).zfill(num_digits) + suffix

    i = 0
    while isfile(get_name(i)):
        i += 1

    return get_name(i)
|
||||
|
||||
def wait_to_continue(self):
    """Read keys from stdin until the user chooses to record or to exit.

    Returns:
        True when ``RECORD_KEY`` is read; False when the character whose
        code is ``EXIT_KEY_CODE`` is read or when stdin reaches end of file.
        All other characters are ignored.
    """
    while True:
        c = stdin.read(1)
        if not c:
            # read() returns '' at EOF; ord('') would raise TypeError and
            # retrying would spin forever, so treat EOF as a request to exit.
            return False
        if c == self.RECORD_KEY:
            return True
        if ord(c) == self.EXIT_KEY_CODE:
            return False
|
||||
|
||||
def record_until_key(self):
|
||||
def should_return():
|
||||
return self.key_pressed() and stdin.read(1) == self.RECORD_KEY
|
||||
|
||||
return record_until(self.p, should_return, self.args)
|
||||
|
||||
def _run(self):
|
||||
args = self.args
|
||||
self.show_input()
|
||||
args.file_label = args.file_label or input("File label (Ex. recording-##): ")
|
||||
args.file_label = args.file_label + ('' if '#' in args.file_label else '-##')
|
||||
self.hide_input()
|
||||
|
||||
while True:
|
||||
print('Press space to record (esc to exit)...')
|
||||
|
||||
if not self.wait_to_continue():
|
||||
break
|
||||
|
||||
print('Recording...')
|
||||
d = self.record_until_key()
|
||||
name = self.next_name(args.file_label)
|
||||
save_audio(name, d, args)
|
||||
print('Saved as ' + name)
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.hide_input()
|
||||
self._run()
|
||||
finally:
|
||||
tcsetattr(stdin, TCSADRAIN, self.orig_settings)
|
||||
self.p.terminate()
|
||||
|
||||
|
||||
def wait_to_continue():
|
||||
while True:
|
||||
c = stdin.read(1)
|
||||
if c == RECORD_KEY:
|
||||
return True
|
||||
elif ord(c) == EXIT_KEY_CODE:
|
||||
return False
|
||||
|
||||
|
||||
def record_until_key(p, args):
|
||||
def should_return():
|
||||
return key_pressed() and stdin.read(1) == RECORD_KEY
|
||||
|
||||
return record_until(p, should_return, args)
|
||||
|
||||
|
||||
def _main():
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('file_label', nargs='?', help='File label (Ex. recording-##)')
|
||||
args = parser.parse_args()
|
||||
show_input()
|
||||
args.file_label = args.file_label or input("File label (Ex. recording-##): ")
|
||||
args.file_label = args.file_label + ('' if '#' in args.file_label else '-##')
|
||||
hide_input()
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
|
||||
while True:
|
||||
print('Press space to record (esc to exit)...')
|
||||
|
||||
if not wait_to_continue():
|
||||
break
|
||||
|
||||
print('Recording...')
|
||||
d = record_until_key(p, args)
|
||||
name = next_name(args.file_label)
|
||||
save_audio(name, d, args)
|
||||
print('Saved as ' + name)
|
||||
|
||||
p.terminate()
|
||||
|
||||
|
||||
def main():
|
||||
termios_wrapper(_main)
|
||||
|
||||
main = CollectScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -15,73 +15,76 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from os.path import split, isfile
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from shutil import copyfile
|
||||
|
||||
usage = '''
|
||||
Convert wake word model from Keras to TensorFlow
|
||||
|
||||
:model str
|
||||
Input Keras model (.net)
|
||||
|
||||
:-o --out str {model}.pb
|
||||
Custom output TensorFlow protobuf filename
|
||||
'''
|
||||
from precise.scripts.base_script import BaseScript
|
||||
|
||||
|
||||
def convert(model_path: str, out_file: str):
|
||||
"""
|
||||
Converts an HD5F file from Keras to a .pb for use with TensorFlow
|
||||
class ConvertScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Convert wake word model from Keras to TensorFlow
|
||||
|
||||
Args:
|
||||
model_path: location of Keras model
|
||||
out_file: location to write protobuf
|
||||
"""
|
||||
print('Converting', model_path, 'to', out_file, '...')
|
||||
:model str
|
||||
Input Keras model (.net)
|
||||
|
||||
import tensorflow as tf
|
||||
from precise.model import load_precise_model
|
||||
from keras import backend as K
|
||||
:-o --out str {model}.pb
|
||||
Custom output TensorFlow protobuf filename
|
||||
''')
|
||||
|
||||
out_dir, filename = split(out_file)
|
||||
out_dir = out_dir or '.'
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
def run(self):
|
||||
args = self.args
|
||||
model_name = args.model.replace('.net', '')
|
||||
self.convert(args.model, args.out.format(model=model_name))
|
||||
|
||||
K.set_learning_phase(0)
|
||||
model = load_precise_model(model_path)
|
||||
def convert(self, model_path: str, out_file: str):
|
||||
"""
|
||||
Converts an HD5F file from Keras to a .pb for use with TensorFlow
|
||||
|
||||
out_name = 'net_output'
|
||||
tf.identity(model.output, name=out_name)
|
||||
print('Output node name:', out_name)
|
||||
print('Output folder:', out_dir)
|
||||
Args:
|
||||
model_path: location of Keras model
|
||||
out_file: location to write protobuf
|
||||
"""
|
||||
print('Converting', model_path, 'to', out_file, '...')
|
||||
|
||||
sess = K.get_session()
|
||||
import tensorflow as tf
|
||||
from precise.model import load_precise_model
|
||||
from keras import backend as K
|
||||
|
||||
# Write the graph in human readable
|
||||
tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True)
|
||||
print('Saved readable graph to:', filename + 'txt')
|
||||
out_dir, filename = split(out_file)
|
||||
out_dir = out_dir or '.'
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Write the graph in binary .pb file
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import graph_io
|
||||
K.set_learning_phase(0)
|
||||
model = load_precise_model(model_path)
|
||||
|
||||
cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name])
|
||||
graph_io.write_graph(cgraph, out_dir, filename, as_text=False)
|
||||
out_name = 'net_output'
|
||||
tf.identity(model.output, name=out_name)
|
||||
print('Output node name:', out_name)
|
||||
print('Output folder:', out_dir)
|
||||
|
||||
if isfile(model_path + '.params'):
|
||||
copyfile(model_path + '.params', out_file + '.params')
|
||||
sess = K.get_session()
|
||||
|
||||
print('Saved graph to:', filename)
|
||||
# Write the graph in human readable
|
||||
tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True)
|
||||
print('Saved readable graph to:', filename + 'txt')
|
||||
|
||||
del sess
|
||||
# Write the graph in binary .pb file
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import graph_io
|
||||
|
||||
cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name])
|
||||
graph_io.write_graph(cgraph, out_dir, filename, as_text=False)
|
||||
|
||||
if isfile(model_path + '.params'):
|
||||
copyfile(model_path + '.params', out_file + '.params')
|
||||
|
||||
print('Saved graph to:', filename)
|
||||
|
||||
del sess
|
||||
|
||||
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
|
||||
model_name = args.model.replace('.net', '')
|
||||
convert(args.model, args.out.format(model=model_name))
|
||||
|
||||
main = ConvertScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -15,50 +15,56 @@
|
|||
import sys
|
||||
|
||||
import os
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise import __version__
|
||||
from precise.network_runner import Listener
|
||||
|
||||
usage = '''
|
||||
stdin should be a stream of raw int16 audio, written in
|
||||
groups of CHUNK_SIZE samples. If no CHUNK_SIZE is given
|
||||
it will read until EOF. For every chunk, an inference
|
||||
will be given via stdout as a float string, one per line
|
||||
|
||||
:model_name str
|
||||
Keras or TensorFlow model to read from
|
||||
|
||||
...
|
||||
'''
|
||||
from precise.scripts.base_script import BaseScript
|
||||
|
||||
|
||||
def main():
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
stdout = sys.stdout
|
||||
sys.stdout = sys.stderr
|
||||
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('-v', '--version', action='version', version=__version__)
|
||||
parser.add_argument('chunk_size', type=int, nargs='?', default=-1,
|
||||
help='Number of bytes to read before making a prediction.'
|
||||
'Higher values are less computationally expensive')
|
||||
def add_audio_pipe_to_parser(parser):
    """Rewrite the parser's usage line to show audio being piped in via stdin.

    Takes the usage text argparse generates for ``parser``, removes its
    leading ``usage: `` label and appends `` < audio.wav``. The result is
    stored in ``parser.usage``.

    Args:
        parser: An ``argparse.ArgumentParser``; its ``usage`` attribute is
            overwritten in place.
    """
    label = 'usage: '
    usage_line = parser.format_usage().strip()
    # Drop only the leading label. A blanket str.replace would also remove any
    # later 'usage: ' text that is part of the program name or a metavar.
    if usage_line.startswith(label):
        usage_line = usage_line[len(label):]
    parser.usage = usage_line + ' < audio.wav'
|
||||
args = parser.parse_args()
|
||||
|
||||
if sys.stdin.isatty():
|
||||
parser.error('Please pipe audio via stdin using < audio.wav')
|
||||
|
||||
listener = Listener(args.model_name, args.chunk_size)
|
||||
class EngineScript(BaseScript):
|
||||
usage = Usage('''
|
||||
stdin should be a stream of raw int16 audio, written in
|
||||
groups of CHUNK_SIZE samples. If no CHUNK_SIZE is given
|
||||
it will read until EOF. For every chunk, an inference
|
||||
will be given via stdout as a float string, one per line
|
||||
|
||||
try:
|
||||
while True:
|
||||
conf = listener.update(sys.stdin.buffer)
|
||||
stdout.buffer.write((str(conf) + '\n').encode('ascii'))
|
||||
stdout.buffer.flush()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
pass
|
||||
:model_name str
|
||||
Keras or TensorFlow model to read from
|
||||
|
||||
...
|
||||
''')
|
||||
usage.add_argument('-v', '--version', action='version', version=__version__)
|
||||
usage.add_argument('chunk_size', type=int, nargs='?', default=-1,
|
||||
help='Number of bytes to read before making a prediction. '
|
||||
'Higher values are less computationally expensive')
|
||||
usage.add_customizer(add_audio_pipe_to_parser)
|
||||
|
||||
def __init__(self, args):
    """Pass ``args`` to the base script and refuse to run on an interactive stdin.

    Raises:
        ValueError: If ``sys.stdin`` is a terminal instead of piped input.
    """
    super().__init__(args)
    stdin_is_terminal = sys.stdin.isatty()
    if stdin_is_terminal:
        raise ValueError('Please pipe audio via stdin using < audio.wav')
|
||||
|
||||
def run(self):
    """Feed stdin audio to the listener and write one confidence per line.

    The original stdout is kept for the prediction stream only. While the
    loop runs, ``sys.stdout`` points at stderr so that anything printed by
    other code does not get mixed into the predictions. The loop ends when
    ``EOFError`` or ``KeyboardInterrupt`` is raised; ``sys.stdout`` is put
    back afterwards so a caller invoking this script from Python is not
    left with a redirected stdout.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    stdout = sys.stdout
    sys.stdout = sys.stderr
    try:
        listener = Listener(self.args.model_name, self.args.chunk_size)

        try:
            while True:
                conf = listener.update(sys.stdin.buffer)
                # Write bytes straight to the saved stream and flush per line
                # so a consumer on the other end of the pipe sees each result.
                stdout.buffer.write((str(conf) + '\n').encode('ascii'))
                stdout.buffer.flush()
        except (EOFError, KeyboardInterrupt):
            pass
    finally:
        # Undo the redirection even if the listener fails to load
        sys.stdout = stdout
|
||||
|
||||
|
||||
main = EngineScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -14,99 +14,110 @@
|
|||
# limitations under the License.
|
||||
import json
|
||||
from os.path import isfile, isdir
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params
|
||||
from precise.pocketsphinx.listener import PocketsphinxListener
|
||||
from precise.pocketsphinx.scripts.test import test_pocketsphinx
|
||||
from precise.pocketsphinx.scripts.test import PocketsphinxTestScript
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Evaluate a list of models on a dataset
|
||||
|
||||
:-u --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-t --threshold floaat 0.5
|
||||
Network output to be considered an activation
|
||||
|
||||
:-pw --pocketsphinx-wake-word str -
|
||||
Optional wake word used to
|
||||
generate a Pocketsphinx data point
|
||||
|
||||
:-pd --pocketsphinx-dict str -
|
||||
Optional word dictionary used to
|
||||
generate a Pocketsphinx data point
|
||||
Format: wake-word.yy-mm-dd.dict
|
||||
|
||||
:-pf --pocketsphinx-folder str -
|
||||
Optional hmm folder used to
|
||||
generate a Pocketsphinx data point.
|
||||
|
||||
:-pth --pocketsphinx-threshold float 1e-90
|
||||
Optional threshold used to
|
||||
generate a Pocketsphinx data point
|
||||
|
||||
:-o --output str stats.json
|
||||
Output json file
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
class EvalScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Evaluate a list of models on a dataset
|
||||
|
||||
:-u --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-t --threshold float 0.5
|
||||
Network output to be considered an activation
|
||||
|
||||
:-pw --pocketsphinx-wake-word str -
|
||||
Optional wake word used to
|
||||
generate a Pocketsphinx data point
|
||||
|
||||
:-pd --pocketsphinx-dict str -
|
||||
Optional word dictionary used to
|
||||
generate a Pocketsphinx data point
|
||||
Format = wake-word.yy-mm-dd.dict
|
||||
|
||||
:-pf --pocketsphinx-folder str -
|
||||
Optional hmm folder used to
|
||||
generate a Pocketsphinx data point.
|
||||
|
||||
:-pth --pocketsphinx-threshold float 1e-90
|
||||
Optional threshold used to
|
||||
generate a Pocketsphinx data point
|
||||
|
||||
:-o --output str stats.json
|
||||
Output json file
|
||||
|
||||
...
|
||||
''')
|
||||
usage.add_argument('models', nargs='*',
|
||||
help='List of model filenames in format: wake-word.yy-mm-dd.net')
|
||||
usage |= TrainData.usage
|
||||
|
||||
def __init__(self, args):
    """Validate the Pocketsphinx options and record whether they are in use.

    Sets ``self.is_pocketsphinx`` to the truthiness of
    ``args.pocketsphinx_dict``.

    Raises:
        ValueError: If only some of the dict / folder / wake word options
            are given, or if the given dict file or hmm folder is missing.
    """
    super().__init__(args)
    supplied = {
        bool(args.pocketsphinx_dict),
        bool(args.pocketsphinx_folder),
        bool(args.pocketsphinx_wake_word),
    }
    # Exactly one distinct value means all three are set or all three are unset
    if len(supplied) != 1:
        raise ValueError('Must pass all or no Pocketsphinx arguments')
    self.is_pocketsphinx = bool(args.pocketsphinx_dict)

    if not self.is_pocketsphinx:
        return
    if not isfile(args.pocketsphinx_dict):
        raise ValueError('No such file: ' + args.pocketsphinx_dict)
    if not isdir(args.pocketsphinx_folder):
        raise ValueError('No such folder: ' + args.pocketsphinx_folder)
|
||||
|
||||
def run(self):
    """Evaluate Pocketsphinx (when configured) and each listed model, then write the metrics as JSON.

    Metrics are keyed by the Pocketsphinx dict path and by each model
    filename, and are written to ``args.output``.
    """
    opts = self.args
    dataset = TrainData.from_both(opts.tags_file, opts.tags_folder, opts.folder)
    if opts.use_train:
        file_groups = dataset.train_files
    else:
        file_groups = dataset.test_files
    print('Data:', dataset)

    results = {}

    if self.is_pocketsphinx:
        ps_script = PocketsphinxTestScript.create(
            key_phrase=opts.pocketsphinx_wake_word,
            dict_file=opts.pocketsphinx_dict,
            hmm_folder=opts.pocketsphinx_folder,
            threshold=opts.pocketsphinx_threshold,
        )
        # The file groups unpack into wake-word files and not-wake-word files
        wake_files, not_wake_files = file_groups
        ps_script.run_test(wake_files, 'Wake Word', 1.0)
        ps_script.run_test(not_wake_files, 'Not Wake Word', 0.0)
        ps_stats = ps_script.get_stats()
        results[opts.pocketsphinx_dict] = ps_stats.to_dict(opts.threshold)

    for model_name in opts.models:
        print('Calculating', model_name + '...')
        inject_params(model_name)

        # Data is loaded again for every model, after its params are injected
        train, test = dataset.load(opts.use_train, not opts.use_train)
        inputs, targets = train if opts.use_train else test
        runner_cls = Listener.find_runner(model_name)
        predictions = runner_cls(model_name).predict(inputs)

        model_stats = Stats(predictions, targets, sum(file_groups, []))

        print('----', model_name, '----')
        print(model_stats.counts_str())
        print()
        print(model_stats.summary_str())
        print()
        results[model_name] = model_stats.to_dict(opts.threshold)

    print('Writing to:', opts.output)
    with open(opts.output, 'w') as out_file:
        json.dump(results, out_file)
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='*',
|
||||
help='List of model filenames in format: wake-word.yy-mm-dd.net')
|
||||
args = TrainData.parse_args(parser)
|
||||
if not (
|
||||
bool(args.pocketsphinx_dict) ==
|
||||
bool(args.pocketsphinx_folder) ==
|
||||
bool(args.pocketsphinx_wake_word)
|
||||
):
|
||||
parser.error('Must pass all or no Pocketsphinx arguments')
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
data_files = data.train_files if args.use_train else data.test_files
|
||||
print('Data:', data)
|
||||
|
||||
metrics = {}
|
||||
|
||||
if args.pocketsphinx_dict and args.pocketsphinx_folder and args.pocketsphinx_wake_word:
|
||||
if not isfile(args.pocketsphinx_dict):
|
||||
parser.error('No such file: ' + args.pocketsphinx_dict)
|
||||
if not isdir(args.pocketsphinx_folder):
|
||||
parser.error('No such folder: ' + args.pocketsphinx_folder)
|
||||
listener = PocketsphinxListener(
|
||||
args.pocketsphinx_wake_word, args.pocketsphinx_dict,
|
||||
args.pocketsphinx_folder, args.pocketsphinx_threshold
|
||||
)
|
||||
stats = test_pocketsphinx(listener, data_files)
|
||||
metrics[args.pocketsphinx_dict] = stats_to_dict(stats)
|
||||
|
||||
for model_name in args.models:
|
||||
print('Calculating', model_name + '...')
|
||||
inject_params(model_name)
|
||||
|
||||
train, test = data.load(args.use_train, not args.use_train)
|
||||
inputs, targets = train if args.use_train else test
|
||||
predictions = Listener.find_runner(model_name)(model_name).predict(inputs)
|
||||
|
||||
stats = Stats(predictions, targets, sum(data_files, []))
|
||||
|
||||
print('----', model_name, '----')
|
||||
print(stats.counts_str())
|
||||
print()
|
||||
print(stats.summary_str())
|
||||
print()
|
||||
metrics[model_name] = stats.to_dict(args.threshold)
|
||||
|
||||
print('Writing to:', args.output)
|
||||
with open(args.output, 'w') as f:
|
||||
json.dump(metrics, f)
|
||||
|
||||
main = EvalScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -12,47 +12,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from os.path import basename, splitext
|
||||
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params, pr
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.threshold_decoder import ThresholdDecoder
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Show ROC curves for a series of models
|
||||
|
||||
...
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't print out the names of files that failed
|
||||
|
||||
:-r --resolution int 100
|
||||
Number of points to generate
|
||||
|
||||
:-p --power float 3.0
|
||||
Power of point distribution
|
||||
|
||||
:-l --labels
|
||||
Print labels attached to each point
|
||||
|
||||
:-o --output-file str -
|
||||
File to write data instead of displaying it
|
||||
|
||||
:-i --input-file str -
|
||||
File to read data from and visualize
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
def get_thresholds(points=100, power=3) -> list:
|
||||
"""Run a function with a series of thresholds between 0 and 1"""
|
||||
|
@ -108,53 +80,86 @@ def calc_stats(model_files, loader, use_train, filenames):
|
|||
return model_data
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
args = TrainData.parse_args(parser)
|
||||
if not args.models and not args.input_file and args.folder:
|
||||
args.input_file = args.folder
|
||||
if bool(args.models) == bool(args.input_file):
|
||||
parser.error('Please specify either a list of models or an input file')
|
||||
class GraphScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Show ROC curves for a series of models
|
||||
|
||||
if not args.output_file:
|
||||
load_plt() # Error early if matplotlib not installed
|
||||
import numpy as np
|
||||
...
|
||||
|
||||
if args.models:
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
print('Data:', data)
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
loader = CachedDataLoader(partial(
|
||||
data.load, args.use_train, not args.use_train, shuffle=False
|
||||
))
|
||||
model_data = calc_stats(args.models, loader, args.use_train, filenames)
|
||||
else:
|
||||
model_data = {
|
||||
name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
|
||||
}
|
||||
for name, stats in model_data.items():
|
||||
print('=== {} ===\n{}\n\n{}\n'.format(name, stats.counts_str(), stats.summary_str()))
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
if args.output_file:
|
||||
np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
|
||||
else:
|
||||
plt = load_plt()
|
||||
decoder = ThresholdDecoder(pr.threshold_config, pr.threshold_center)
|
||||
thresholds = [decoder.encode(i) for i in np.linspace(0.0, 1.0, args.resolution)[1:-1]]
|
||||
for model_name, stats in model_data.items():
|
||||
x = [stats.false_positives(i) for i in thresholds]
|
||||
y = [stats.false_negatives(i) for i in thresholds]
|
||||
plt.plot(x, y, marker='x', linestyle='-', label=model_name)
|
||||
if args.labels:
|
||||
for x, y, threshold in zip(x, y, thresholds):
|
||||
plt.annotate('{:.4f}'.format(threshold), (x, y))
|
||||
:-nf --no-filenames
|
||||
Don't print out the names of files that failed
|
||||
|
||||
plt.legend()
|
||||
plt.xlabel('False Positives')
|
||||
plt.ylabel('False Negatives')
|
||||
plt.show()
|
||||
:-r --resolution int 100
|
||||
Number of points to generate
|
||||
|
||||
:-p --power float 3.0
|
||||
Power of point distribution
|
||||
|
||||
:-l --labels
|
||||
Print labels attached to each point
|
||||
|
||||
:-o --output-file str -
|
||||
File to write data instead of displaying it
|
||||
|
||||
:-i --input-file str -
|
||||
File to read data from and visualize
|
||||
|
||||
...
|
||||
''')
|
||||
usage.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
usage |= TrainData.usage
|
||||
|
||||
def __init__(self, args):
    """Resolve the input source and check that exactly one was provided.

    When neither models nor an input file are given but ``args.folder`` is,
    the folder value is used as the input file.

    Raises:
        ValueError: If both or neither of models / input file are set.
    """
    super().__init__(args)
    has_models = bool(args.models)
    if not (has_models or args.input_file) and args.folder:
        args.input_file = args.folder
    if has_models == bool(args.input_file):
        raise ValueError('Please specify either a list of models or an input file')

    if args.output_file:
        return
    load_plt()  # Error early if matplotlib not installed
|
||||
|
||||
def run(self):
    """Compute or load per-model stats, then save them or plot the curves.

    With ``args.models`` the stats are calculated from the dataset;
    otherwise they are read from ``args.input_file``. With
    ``args.output_file`` the stats are saved via ``np.savez``; otherwise
    false positives are plotted against false negatives for each model.
    """
    args = self.args
    if args.models:
        data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
        print('Data:', data)
        filenames = sum(data.train_files if args.use_train else data.test_files, [])
        loader = CachedDataLoader(partial(
            data.load, args.use_train, not args.use_train, shuffle=False
        ))
        model_data = calc_stats(args.models, loader, args.use_train, filenames)
    else:
        # The archive stores a single dict as a pickled object array (see the
        # np.savez call below); NumPy >= 1.16.3 refuses to load that unless
        # allow_pickle=True is passed explicitly.
        # NOTE: unpickling can execute arbitrary code, so only load files
        # that were produced by this script or come from a trusted source.
        saved = np.load(args.input_file, allow_pickle=True)['data'].item()
        model_data = {
            name: Stats.from_np_dict(data) for name, data in saved.items()
        }
        for name, stats in model_data.items():
            print('=== {} ===\n{}\n\n{}\n'.format(name, stats.counts_str(), stats.summary_str()))

    if args.output_file:
        np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
    else:
        plt = load_plt()
        decoder = ThresholdDecoder(pr.threshold_config, pr.threshold_center)
        # Both endpoints of the linspace are dropped before encoding
        thresholds = [decoder.encode(i) for i in np.linspace(0.0, 1.0, args.resolution)[1:-1]]
        for model_name, stats in model_data.items():
            x = [stats.false_positives(i) for i in thresholds]
            y = [stats.false_negatives(i) for i in thresholds]
            plt.plot(x, y, marker='x', linestyle='-', label=model_name)
            if args.labels:
                # Separate names keep the x / y series lists from being rebound
                for point_x, point_y, threshold in zip(x, y, thresholds):
                    plt.annotate('{:.4f}'.format(threshold), (point_x, point_y))

        plt.legend()
        plt.xlabel('False Positives')
        plt.ylabel('False Negatives')
        plt.show()
|
||||
|
||||
|
||||
main = GraphScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -14,85 +14,86 @@
|
|||
# limitations under the License.
|
||||
import numpy as np
|
||||
from os.path import join
|
||||
from prettyparse import create_parser
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
from prettyparse import Usage
|
||||
from random import randint
|
||||
from shutil import get_terminal_size
|
||||
from threading import Event
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import save_audio, buffer_to_audio, activate_notify
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
|
||||
usage = '''
|
||||
Run a model on microphone audio input
|
||||
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to run
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Samples between inferences
|
||||
|
||||
:-l --trigger-level int 3
|
||||
Number of activated chunks to cause an activation
|
||||
|
||||
:-s --sensitivity float 0.5
|
||||
Network output required to be considered activated
|
||||
|
||||
:-b --basic-mode
|
||||
Report using . or ! rather than a visual representation
|
||||
|
||||
:-d --save-dir str -
|
||||
Folder to save false positives
|
||||
|
||||
:-p --save-prefix str -
|
||||
Prefix for saved filenames
|
||||
'''
|
||||
|
||||
session_id, chunk_num = '%09d' % randint(0, 999999999), 0
|
||||
|
||||
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
class ListenScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Run a model on microphone audio input
|
||||
|
||||
def on_activation():
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to run
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Samples between inferences
|
||||
|
||||
:-l --trigger-level int 3
|
||||
Number of activated chunks to cause an activation
|
||||
|
||||
:-s --sensitivity float 0.5
|
||||
Network output required to be considered activated
|
||||
|
||||
:-b --basic-mode
|
||||
Report using . or ! rather than a visual representation
|
||||
|
||||
:-d --save-dir str -
|
||||
Folder to save false positives
|
||||
|
||||
:-p --save-prefix str -
|
||||
Prefix for saved filenames
|
||||
''')
|
||||
|
||||
def __init__(self, args):
    """Build the listener, engine and runner used for microphone input.

    Also creates a zero-filled audio buffer sized by the listener's
    ``pr.buffer_samples`` and a random nine-digit session id with a chunk
    counter starting at 0.
    """
    super().__init__(args)
    listener = Listener(args.model, args.chunk_size)
    self.listener = listener
    self.audio_buffer = np.zeros(listener.pr.buffer_samples, dtype=float)

    engine = ListenerEngine(listener, args.chunk_size)
    self.engine = engine
    # Route the engine's predictions through this script's own method
    engine.get_prediction = self.get_prediction

    self.runner = PreciseRunner(
        engine, args.trigger_level, sensitivity=args.sensitivity,
        on_activation=self.on_activation, on_prediction=self.on_prediction
    )
    self.session_id = '%09d' % randint(0, 999999999)
    self.chunk_num = 0
|
||||
|
||||
def on_activation(self):
|
||||
activate_notify()
|
||||
|
||||
if args.save_dir:
|
||||
global chunk_num
|
||||
nm = join(args.save_dir, args.save_prefix + session_id + '.' + str(chunk_num) + '.wav')
|
||||
save_audio(nm, audio_buffer)
|
||||
if self.args.save_dir:
|
||||
nm = join(self.args.save_dir, self.args.save_prefix + self.session_id + '.' + str(self.chunk_num) + '.wav')
|
||||
save_audio(nm, self.audio_buffer)
|
||||
print()
|
||||
print('Saved to ' + nm + '.')
|
||||
chunk_num += 1
|
||||
self.chunk_num += 1
|
||||
|
||||
def on_prediction(conf):
|
||||
if args.basic_mode:
|
||||
def on_prediction(self, conf):
|
||||
if self.args.basic_mode:
|
||||
print('!' if conf > 0.7 else '.', end='', flush=True)
|
||||
else:
|
||||
max_width = 80
|
||||
width = min(get_terminal_size()[0], max_width)
|
||||
units = int(round(conf * width))
|
||||
bar = 'X' * units + '-' * (width - units)
|
||||
cutoff = round((1.0 - args.sensitivity) * width)
|
||||
cutoff = round((1.0 - self.args.sensitivity) * width)
|
||||
print(bar[:cutoff] + bar[cutoff:].replace('X', 'x'))
|
||||
|
||||
listener = Listener(args.model, args.chunk_size)
|
||||
audio_buffer = np.zeros(listener.pr.buffer_samples, dtype=float)
|
||||
|
||||
def get_prediction(chunk):
|
||||
nonlocal audio_buffer
|
||||
def get_prediction(self, chunk):
|
||||
audio = buffer_to_audio(chunk)
|
||||
audio_buffer = np.concatenate((audio_buffer[len(audio):], audio))
|
||||
return listener.update(chunk)
|
||||
self.audio_buffer = np.concatenate((self.audio_buffer[len(audio):], audio))
|
||||
return self.listener.update(chunk)
|
||||
|
||||
engine = ListenerEngine(listener, args.chunk_size)
|
||||
engine.get_prediction = get_prediction
|
||||
runner = PreciseRunner(engine, args.trigger_level, sensitivity=args.sensitivity,
|
||||
on_activation=on_activation, on_prediction=on_prediction)
|
||||
runner.start()
|
||||
Event().wait() # Wait forever
|
||||
def run(self):
    """Start the runner, then block the calling thread indefinitely."""
    self.runner.start()
    never_set = Event()
    never_set.wait()  # Wait forever
|
||||
|
||||
|
||||
main = ListenScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -16,30 +16,14 @@ import attr
|
|||
import numpy as np
|
||||
from glob import glob
|
||||
from os.path import join, basename
|
||||
from prettyparse import create_parser
|
||||
from precise_runner.runner import TriggerDetector
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import pr, inject_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import load_audio
|
||||
from precise.vectorization import vectorize_raw
|
||||
from precise_runner.runner import TriggerDetector
|
||||
|
||||
usage = '''
|
||||
Simulate listening to long chunks of audio to find
|
||||
unbiased false positive metrics
|
||||
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to test
|
||||
|
||||
:folder str
|
||||
Folder with a set of long wav files to test against
|
||||
|
||||
:-c --chunk_size int 4096
|
||||
Number of samples between tests
|
||||
|
||||
:-t --threshold float 0.5
|
||||
Network output required to be considered an activation
|
||||
'''
|
||||
|
||||
|
||||
@attr.s()
|
||||
|
@ -80,9 +64,26 @@ class Metric:
|
|||
)
|
||||
|
||||
|
||||
class Simulator:
|
||||
def __init__(self):
|
||||
self.args = create_parser(usage).parse_args()
|
||||
class SimulateScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Simulate listening to long chunks of audio to find
|
||||
unbiased false positive metrics
|
||||
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to test
|
||||
|
||||
:folder str
|
||||
Folder with a set of long wav files to test against
|
||||
|
||||
:-c --chunk_size int 4096
|
||||
Number of samples between tests
|
||||
|
||||
:-t --threshold float 0.5
|
||||
Network output required to be considered an activation
|
||||
''')
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
inject_params(self.args.model)
|
||||
self.runner = Listener.find_runner(self.args.model)(self.args.model)
|
||||
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||
|
@ -127,9 +128,7 @@ class Simulator:
|
|||
print(total.info_string('Total'))
|
||||
|
||||
|
||||
def main():
|
||||
Simulator().run()
|
||||
|
||||
main = SimulateScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -12,60 +12,62 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Test a model against a dataset
|
||||
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to test
|
||||
|
||||
:-u --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't print out the names of files that failed
|
||||
|
||||
:-t --threshold float 0.5
|
||||
Network output required to be considered an activation
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
class TestScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Test a model against a dataset
|
||||
|
||||
def main():
|
||||
args = TrainData.parse_args(create_parser(usage))
|
||||
:model str
|
||||
Either Keras (.net) or TensorFlow (.pb) model to test
|
||||
|
||||
inject_params(args.model)
|
||||
:-u --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
train, test = data.load(args.use_train, not args.use_train, shuffle=False)
|
||||
inputs, targets = train if args.use_train else test
|
||||
:-nf --no-filenames
|
||||
Don't print out the names of files that failed
|
||||
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
predictions = Listener.find_runner(args.model)(args.model).predict(inputs)
|
||||
stats = Stats(predictions, targets, filenames)
|
||||
:-t --threshold float 0.5
|
||||
Network output required to be considered an activation
|
||||
|
||||
print('Data:', data)
|
||||
...
|
||||
''') | TrainData.usage
|
||||
|
||||
if not args.no_filenames:
|
||||
fp_files = stats.calc_filenames(False, True, args.threshold)
|
||||
fn_files = stats.calc_filenames(False, False, args.threshold)
|
||||
print('=== False Positives ===')
|
||||
print('\n'.join(fp_files))
|
||||
def run(self):
|
||||
args = self.args
|
||||
inject_params(args.model)
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
train, test = data.load(args.use_train, not args.use_train, shuffle=False)
|
||||
inputs, targets = train if args.use_train else test
|
||||
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
predictions = Listener.find_runner(args.model)(args.model).predict(inputs)
|
||||
stats = Stats(predictions, targets, filenames)
|
||||
|
||||
print('Data:', data)
|
||||
|
||||
if not args.no_filenames:
|
||||
fp_files = stats.calc_filenames(False, True, args.threshold)
|
||||
fn_files = stats.calc_filenames(False, False, args.threshold)
|
||||
print('=== False Positives ===')
|
||||
print('\n'.join(fp_files))
|
||||
print()
|
||||
print('=== False Negatives ===')
|
||||
print('\n'.join(fn_files))
|
||||
print()
|
||||
print(stats.counts_str(args.threshold))
|
||||
print()
|
||||
print('=== False Negatives ===')
|
||||
print('\n'.join(fn_files))
|
||||
print()
|
||||
print(stats.counts_str(args.threshold))
|
||||
print()
|
||||
print(stats.summary_str(args.threshold))
|
||||
print(stats.summary_str(args.threshold))
|
||||
|
||||
|
||||
main = TestScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -12,21 +12,21 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from argparse import ArgumentParser
|
||||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from os.path import splitext, isfile
|
||||
from prettyparse import add_to_parser
|
||||
from prettyparse import Usage
|
||||
from typing import Any, Tuple
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import calc_sample_hash
|
||||
|
||||
|
||||
class Trainer:
|
||||
usage = '''
|
||||
class TrainScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Train a new model on a dataset
|
||||
|
||||
:model str
|
||||
|
@ -60,26 +60,23 @@ class Trainer:
|
|||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
|
||||
:-f --freeze-till int 0
|
||||
Freeze all weights up to this index (non-inclusive).
|
||||
Can be negative to wrap from end
|
||||
|
||||
...
|
||||
'''
|
||||
''') | TrainData.usage
|
||||
|
||||
def __init__(self, parser=None):
|
||||
parser = parser or ArgumentParser()
|
||||
add_to_parser(parser, self.usage, True)
|
||||
args = TrainData.parse_args(parser)
|
||||
self.args = args = self.process_args(args) or args
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
if args.invert_samples and not args.samples_file:
|
||||
parser.error('You must specify --samples-file when using --invert-samples')
|
||||
raise ValueError('You must specify --samples-file when using --invert-samples')
|
||||
if args.samples_file and not isfile(args.samples_file):
|
||||
parser.error('No such file: ' + (args.invert_samples or args.samples_file))
|
||||
raise ValueError('No such file: ' + (args.invert_samples or args.samples_file))
|
||||
if not 0.0 <= args.sensitivity <= 1.0:
|
||||
parser.error('sensitivity must be between 0.0 and 1.0')
|
||||
raise ValueError('sensitivity must be between 0.0 and 1.0')
|
||||
|
||||
inject_params(args.model)
|
||||
save_params(args.model)
|
||||
|
@ -94,7 +91,7 @@ class Trainer:
|
|||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
self.epoch = epoch_fiti.read().read(0, int)
|
||||
|
||||
def on_epoch_end(a, b):
|
||||
def on_epoch_end(_a, _b):
|
||||
self.epoch += 1
|
||||
epoch_fiti.write().write(self.epoch, str)
|
||||
|
||||
|
@ -112,10 +109,6 @@ class Trainer:
|
|||
), LambdaCallback(on_epoch_end=on_epoch_end)
|
||||
]
|
||||
|
||||
def process_args(self, args: Any) -> Any:
|
||||
"""Override to modify args"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_sample_data(filename, train_data) -> Tuple[set, dict]:
|
||||
samples = Fitipy(filename).read().set()
|
||||
|
@ -163,20 +156,15 @@ class Trainer:
|
|||
|
||||
def run(self):
|
||||
self.model.summary()
|
||||
try:
|
||||
train_inputs, train_outputs = self.sampled_data
|
||||
self.model.fit(
|
||||
train_inputs, train_outputs, self.args.batch_size,
|
||||
self.epoch + self.args.epochs, validation_data=self.test,
|
||||
initial_epoch=self.epoch, callbacks=self.callbacks
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
train_inputs, train_outputs = self.sampled_data
|
||||
self.model.fit(
|
||||
train_inputs, train_outputs, self.args.batch_size,
|
||||
self.epoch + self.args.epochs, validation_data=self.test,
|
||||
initial_epoch=self.epoch, callbacks=self.callbacks
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
Trainer().run()
|
||||
|
||||
main = TrainScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -20,66 +20,66 @@ from contextlib import suppress
|
|||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from os.path import splitext, join, basename
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from random import random, shuffle
|
||||
from typing import *
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import pr, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import load_audio, glob_all, save_audio, chunk_audio
|
||||
|
||||
usage = '''
|
||||
Train a model on infinitely generated batches
|
||||
|
||||
:model str
|
||||
Keras .net model file to load from and write to
|
||||
class TrainGeneratedScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Train a model on infinitely generated batches
|
||||
|
||||
:-e --epochs int 100
|
||||
Number of epochs to train on
|
||||
:model str
|
||||
Keras .net model file to load from and write to
|
||||
|
||||
:-b --batch-size int 200
|
||||
Number of samples in each batch
|
||||
|
||||
:-t --steps-per-epoch int 100
|
||||
Number of steps that are considered an epoch
|
||||
:-e --epochs int 100
|
||||
Number of epochs to train on
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Number of audio samples between generating a training sample
|
||||
:-b --batch-size int 200
|
||||
Number of samples in each batch
|
||||
|
||||
:-r --random-data-folder str data/random
|
||||
Folder with properly encoded wav files of
|
||||
random audio that should not cause an activation
|
||||
:-t --steps-per-epoch int 100
|
||||
Number of steps that are considered an epoch
|
||||
|
||||
:-s --sensitivity float 0.2
|
||||
Weighted loss bias. Higher values decrease increase positives
|
||||
:-c --chunk-size int 2048
|
||||
Number of audio samples between generating a training sample
|
||||
|
||||
:-sb --save-best
|
||||
Only save the model each epoch if its stats improve
|
||||
:-r --random-data-folder str data/random
|
||||
Folder with properly encoded wav files of
|
||||
random audio that should not cause an activation
|
||||
|
||||
:-nv --no-validation
|
||||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
:-s --sensitivity float 0.2
|
||||
Weighted loss bias. Higher values decrease increase positives
|
||||
|
||||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
:-sb --save-best
|
||||
Only save the model each epoch if its stats improve
|
||||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
:-nv --no-validation
|
||||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
|
||||
:-p --save-prob float 0.0
|
||||
Probability of saving audio into debug/ww and debug/nww folders
|
||||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
|
||||
...
|
||||
'''
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
:-p --save-prob float 0.0
|
||||
Probability of saving audio into debug/ww and debug/nww folders
|
||||
|
||||
class GeneratedTrainer:
|
||||
...
|
||||
''') | TrainData.usage
|
||||
"""A trainer that runs on generated data by overlaying wakewords on background audio"""
|
||||
def __init__(self):
|
||||
parser = create_parser(usage)
|
||||
self.args = args = TrainData.parse_args(parser)
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||
self.vals_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||
|
||||
|
@ -96,7 +96,7 @@ class GeneratedTrainer:
|
|||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
self.epoch = epoch_fiti.read().read(0, int)
|
||||
|
||||
def on_epoch_end(a, b):
|
||||
def on_epoch_end(_a, _b):
|
||||
self.epoch += 1
|
||||
epoch_fiti.write().write(self.epoch, str)
|
||||
|
||||
|
@ -228,7 +228,8 @@ class GeneratedTrainer:
|
|||
_, test_data = self.data.load(train=False, test=True)
|
||||
try:
|
||||
self.model.fit_generator(
|
||||
self.samples_to_batches(self.generate_samples(), self.args.batch_size), steps_per_epoch=self.args.steps_per_epoch,
|
||||
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
|
||||
steps_per_epoch=self.args.steps_per_epoch,
|
||||
epochs=self.epoch + self.args.epochs, validation_data=test_data,
|
||||
callbacks=self.callbacks, initial_epoch=self.epoch
|
||||
)
|
||||
|
@ -237,12 +238,7 @@ class GeneratedTrainer:
|
|||
save_params(self.args.model)
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
GeneratedTrainer().run()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
|
||||
main = TrainGeneratedScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -15,40 +15,17 @@
|
|||
import numpy as np
|
||||
from os import makedirs
|
||||
from os.path import basename, splitext, isfile, join
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from random import random
|
||||
from typing import *
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.network_runner import Listener, KerasRunner
|
||||
from precise.params import pr
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.scripts.train import Trainer
|
||||
from precise.util import load_audio, save_audio, glob_all, chunk_audio
|
||||
|
||||
usage = '''
|
||||
Train a model to inhibit activation by
|
||||
marking false activations and retraining
|
||||
|
||||
:-e --epochs int 1
|
||||
Number of epochs to train before continuing evaluation
|
||||
|
||||
:-ds --delay-samples int 10
|
||||
Number of false activations to save before re-training
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Number of samples between testing the neural network
|
||||
|
||||
:-r --random-data-folder str data/random
|
||||
Folder with properly encoded wav files of
|
||||
random audio that should not cause an activation
|
||||
|
||||
:-th --threshold float 0.5
|
||||
Network output to be considered activated
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
def load_trained_fns(model_name: str) -> list:
|
||||
progress_file = model_name.replace('.net', '') + '.trained.txt'
|
||||
|
@ -64,9 +41,32 @@ def save_trained_fns(trained_fns: list, model_name: str):
|
|||
f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))
|
||||
|
||||
|
||||
class IncrementalTrainer(Trainer):
|
||||
def __init__(self):
|
||||
super().__init__(create_parser(usage))
|
||||
class TrainIncrementalScript(TrainScript):
|
||||
usage = Usage('''
|
||||
Train a model to inhibit activation by
|
||||
marking false activations and retraining
|
||||
|
||||
:-e --epochs int 1
|
||||
Number of epochs to train before continuing evaluation
|
||||
|
||||
:-ds --delay-samples int 10
|
||||
Number of false activations to save before re-training
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Number of samples between testing the neural network
|
||||
|
||||
:-r --random-data-folder str data/random
|
||||
Folder with properly encoded wav files of
|
||||
random audio that should not cause an activation
|
||||
|
||||
:-th --threshold float 0.5
|
||||
Network output to be considered activated
|
||||
|
||||
...
|
||||
''') | TrainScript.usage
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
for i in (
|
||||
join(self.args.folder, 'not-wake-word', 'generated'),
|
||||
|
@ -153,12 +153,7 @@ class IncrementalTrainer(Trainer):
|
|||
save_trained_fns(self.trained_fns, self.args.model)
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
IncrementalTrainer().run()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
|
||||
main = TrainIncrementalScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -12,45 +12,41 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from glob import glob
|
||||
from os import remove
|
||||
|
||||
from os.path import isfile, splitext, join
|
||||
|
||||
import numpy
|
||||
# Optimizer blackhat
|
||||
from bbopt import BlackBoxOptimizer
|
||||
from glob import glob
|
||||
from os import remove
|
||||
from os.path import isfile, splitext, join
|
||||
from pprint import pprint
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
|
||||
from precise.model import ModelParams, create_model
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.scripts.train import Trainer
|
||||
|
||||
usage = '''
|
||||
Use black box optimization to tune model hyperparameters
|
||||
|
||||
:-t --trials-name str -
|
||||
Filename to save hyperparameter optimization trials in
|
||||
'.bbopt.json' will automatically be appended
|
||||
|
||||
:-c --cycles int 20
|
||||
Number of cycles of optimization to run
|
||||
|
||||
:-m --model str .cache/optimized.net
|
||||
Model to load from
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
class OptimizeTrainer(Trainer):
|
||||
usage = re.sub(r'.*:model str.*\n.*\n', '', Trainer.usage)
|
||||
class TrainOptimizeScript(TrainScript):
|
||||
Usage('''
|
||||
Use black box optimization to tune model hyperparameters
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(create_parser(usage))
|
||||
:-t --trials-name str -
|
||||
Filename to save hyperparameter optimization trials in
|
||||
'.bbopt.json' will automatically be appended
|
||||
|
||||
:-c --cycles int 20
|
||||
Number of cycles of optimization to run
|
||||
|
||||
:-m --model str .cache/optimized.net
|
||||
Model to load from
|
||||
|
||||
...
|
||||
''') | TrainScript.usage
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.bb = BlackBoxOptimizer(file=self.args.trials_name)
|
||||
if not self.test:
|
||||
data = TrainData.from_both(self.args.tags_file, self.args.tags_folder, self.args.folder)
|
||||
|
@ -128,9 +124,7 @@ class OptimizeTrainer(Trainer):
|
|||
pprint(best_example)
|
||||
|
||||
|
||||
def main():
|
||||
OptimizeTrainer().run()
|
||||
|
||||
main = TrainOptimizeScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -15,36 +15,35 @@
|
|||
from itertools import islice
|
||||
|
||||
from fitipy import Fitipy
|
||||
from prettyparse import create_parser
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.scripts.train import Trainer
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.util import calc_sample_hash
|
||||
|
||||
usage = '''
|
||||
Train a model, sampling data points with the highest loss from a larger dataset
|
||||
|
||||
:-c --cycles int 200
|
||||
Number of sampling cycles of size {epoch} to run
|
||||
|
||||
:-n --num-sample-chunk int 50
|
||||
Number of new samples to introduce at a time between training cycles
|
||||
|
||||
:-sf --samples-file str -
|
||||
Json file to write selected samples to.
|
||||
Default: {model_base}.samples.json
|
||||
|
||||
:-is --invert-samples
|
||||
Unused parameter
|
||||
...
|
||||
'''
|
||||
class TrainSampledScript(TrainScript):
|
||||
usage = Usage('''
|
||||
Train a model, sampling data points with the highest loss from a larger dataset
|
||||
|
||||
:-c --cycles int 200
|
||||
Number of sampling cycles of size {epoch} to run
|
||||
|
||||
class SampledTrainer(Trainer):
|
||||
def __init__(self):
|
||||
parser = create_parser(usage)
|
||||
super().__init__(parser)
|
||||
:-n --num-sample-chunk int 50
|
||||
Number of new samples to introduce at a time between training cycles
|
||||
|
||||
:-sf --samples-file str -
|
||||
Json file to write selected samples to.
|
||||
Default = {model_base}.samples.json
|
||||
|
||||
:-is --invert-samples
|
||||
Unused parameter
|
||||
...
|
||||
''') | TrainScript.usage
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
if self.args.invert_samples:
|
||||
parser.error('--invert-samples should be left blank')
|
||||
raise ValueError('--invert-samples should be left blank')
|
||||
self.args.samples_file = (self.args.samples_file or '{model_base}.samples.json').format(
|
||||
model_base=self.model_base
|
||||
)
|
||||
|
@ -90,9 +89,7 @@ class SampledTrainer(Trainer):
|
|||
)
|
||||
|
||||
|
||||
def main():
|
||||
SampledTrainer().run()
|
||||
|
||||
main = TrainSampledScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -13,11 +13,10 @@
|
|||
# limitations under the License.
|
||||
import json
|
||||
import numpy as np
|
||||
from argparse import ArgumentParser
|
||||
from glob import glob
|
||||
from hashlib import md5
|
||||
from os.path import join, isfile
|
||||
from prettyparse import add_to_parser
|
||||
from prettyparse import Usage
|
||||
from pyache import Pyache
|
||||
from typing import *
|
||||
|
||||
|
@ -27,6 +26,20 @@ from precise.vectorization import vectorize_delta, vectorize
|
|||
|
||||
class TrainData:
|
||||
"""Class to handle loading of wave data from categorized folders and tagged text files"""
|
||||
usage = Usage('''
|
||||
:folder str
|
||||
Folder to load wav files from
|
||||
|
||||
:-tf --tags-folder str {folder}
|
||||
Specify a different folder to load file ids
|
||||
in tags file from
|
||||
|
||||
:-tg --tags-file str -
|
||||
Text file to load tags from where each line is
|
||||
<file_id> TAB (wake-word|not-wake-word) and
|
||||
{folder}/<file_id>.wav exists
|
||||
|
||||
''', tags_folder=lambda args: args.tags_folder.format(folder=args.folder))
|
||||
|
||||
def __init__(self, train_files: Tuple[List[str], List[str]],
|
||||
test_files: Tuple[List[str], List[str]]):
|
||||
|
@ -144,28 +157,6 @@ class TrainData:
|
|||
def merge(data_a: tuple, data_b: tuple) -> tuple:
|
||||
return np.concatenate((data_a[0], data_b[0])), np.concatenate((data_a[1], data_b[1]))
|
||||
|
||||
@staticmethod
|
||||
def parse_args(parser: ArgumentParser) -> Any:
|
||||
"""Return parsed args from parser, adding options for train data inputs"""
|
||||
extra_usage = '''
|
||||
:folder str
|
||||
Folder to load wav files from
|
||||
|
||||
:-tf --tags-folder str {folder}
|
||||
Specify a different folder to load file ids
|
||||
in tags file from
|
||||
|
||||
:-tg --tags-file str -
|
||||
Text file to load tags from where each line is
|
||||
<file_id> TAB (wake-word|not-wake-word) and
|
||||
{folder}/<file_id>.wav exists
|
||||
|
||||
'''
|
||||
add_to_parser(parser, extra_usage)
|
||||
args = parser.parse_args()
|
||||
args.tags_folder = args.tags_folder.format(folder=args.folder)
|
||||
return args
|
||||
|
||||
def __repr__(self) -> str:
|
||||
string = '<TrainData wake_words={kws} not_wake_words={nkws}' \
|
||||
' test_wake_words={test_kws} test_not_wake_words={test_nkws}>'
|
||||
|
@ -203,6 +194,7 @@ class TrainData:
|
|||
def on_loop():
|
||||
on_loop.i += 1
|
||||
print('\r{0:.2%} '.format(on_loop.i / len(filenames)), end='', flush=True)
|
||||
|
||||
on_loop.i = 0
|
||||
|
||||
new_inputs = cache.load(filenames, on_loop=on_loop)
|
||||
|
|
Loading…
Reference in New Issue