New precise-collect script and minor fixes

- Fix precise-convert script
- Fix pyinstaller spec
- Rename the keyword folder to wake-word for consistency with other labels
pull/1/head
Matthew D. Scholefield 2018-02-23 13:45:17 -06:00
parent cc32959e29
commit ce5f93369c
13 changed files with 202 additions and 39 deletions

35
.gitignore vendored
View File

@ -1,23 +1,24 @@
dist/
build/
.cache/
.idea/
/dist/
/build/
/.cache/
/.idea/
__pycache__/
*.egg-info/
*.pb
*.params
*.net
*.pbtxt
*.txt
other/
.venv/
stats.json
/*.egg-info/
/*.pb
/*.params
/*.net
/*.pbtxt
/*.txt
/*.wav
/other/
/.venv/
/stats.json
!requirements.txt
# jenkins branch
wake-words.txt
all/
mycroft-precise/
models/
/wake-words.txt
/all/
/mycroft-precise/
/models/

View File

@ -2,7 +2,7 @@
block_cipher = None
a = Analysis(['precise/stream.py'],
a = Analysis(['precise/scripts/stream.py'],
pathex=['.'],
binaries=[],
datas=[],

View File

@ -17,7 +17,7 @@ def load_precise_model(model_name: str) -> Any:
return load_keras().models.load_model(model_name)
def create_model(model_name: str, skip_acc: bool = False) -> Any:
def create_model(model_name: str, skip_acc=False, extra_metrics=False) -> Any:
"""
Load or create a precise model
@ -42,6 +42,6 @@ def create_model(model_name: str, skip_acc: bool = False) -> Any:
model.add(Dense(1, activation='sigmoid'))
load_keras()
metrics = ['accuracy', false_pos, false_neg]
metrics = ['accuracy'] + extra_metrics * [false_pos, false_neg]
model.compile('rmsprop', weighted_log_loss, metrics=(not skip_acc) * metrics)
return model

View File

View File

146
precise/scripts/collect.py Normal file
View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright (c) 2017 Mycroft AI Inc.
import tty
import wave
from os.path import isfile
from select import select
from sys import stdin
from termios import tcsetattr, tcgetattr, TCSADRAIN
from prettyparse import create_parser
import pyaudio
# Usage specification handed to prettyparse's create_parser() in _main().
# The parsed result is read as args.width, args.rate and args.channels.
# NOTE(review): presumably each ":" line declares an option (flags, type,
# default) followed by its help text -- confirm against prettyparse docs.
usage = '''
Record audio samples for use with precise
:-w --width int 2
Sample width of audio
:-r --rate int 16000
Sample rate of audio
:-c --channels int 2
Number of audio channels
'''
def key_pressed():
    """Returns whether stdin has input waiting, without blocking"""
    # A timeout of 0 makes select() poll instead of wait
    ready_to_read, _, _ = select([stdin], [], [], 0)
    return ready_to_read == [stdin]
def termios_wrapper(main):
    """
    Run main() with terminal input hidden and restore the terminal afterwards

    The current terminal attributes of stdin are stored in the module-level
    orig_settings before main is called. They are written back whether main
    returns normally or raises.

    Args:
        main: Callable taking no arguments
    """
    global orig_settings
    orig_settings = tcgetattr(stdin)
    try:
        hide_input()
        main()
    finally:
        # show_input() writes orig_settings back to stdin
        show_input()
def show_input():
    """Write the terminal attributes saved in orig_settings back to stdin"""
    stdin_fd = stdin.fileno()
    tcsetattr(stdin_fd, TCSADRAIN, orig_settings)
def hide_input():
    """Put the terminal attached to stdin into cbreak mode"""
    stdin_fd = stdin.fileno()
    tty.setcbreak(stdin_fd)
# Terminal attributes of stdin saved by termios_wrapper(); None until it runs
orig_settings = None
# Character that starts a recording and, when read again, stops it
RECORD_KEY = ' '
# Character code that ends the prompt loop (27 is the escape key)
EXIT_KEY_CODE = 27
def record_until(p, should_return, args):
    """
    Record audio from an input stream until should_return() is true

    Args:
        p: Audio interface providing open() and get_format_from_width()
            (a pyaudio.PyAudio instance in this script)
        should_return: Callable taking no arguments; polled before every
            chunk is read, recording stops once it returns a true value
        args: Object with width, channels and rate attributes

    Returns:
        bytes: All recorded chunks joined together
    """
    chunk_size = 1024
    stream = p.open(format=p.get_format_from_width(args.width), channels=args.channels,
                    rate=args.rate, input=True, frames_per_buffer=chunk_size)
    frames = []
    try:
        while not should_return():
            frames.append(stream.read(chunk_size))
    finally:
        # Release the stream even if polling or reading raises
        stream.stop_stream()
        stream.close()
    return b''.join(frames)
def save_audio(name, data, args):
    """
    Write raw audio frames to a wav file

    Args:
        name: Path of the wav file to create
        data: Raw frame bytes to write
        args: Object with channels, width and rate attributes
    """
    # The context manager closes the file even if a setter or the write raises
    with wave.open(name, 'wb') as wf:
        wf.setnchannels(args.channels)
        wf.setsampwidth(args.width)
        wf.setframerate(args.rate)
        wf.writeframes(data)
def next_name(name):
    """
    Return the first unused wav file name generated from a numbered template

    '.wav' is appended to name and the first run of '#' characters is
    replaced with a zero-padded counter, starting at 0, until a name is
    found that does not exist as a file.

    Args:
        name: Template such as 'recording-##' (without extension)

    Returns:
        str: File name that does not currently exist

    Raises:
        ValueError: If name contains no '#'
    """
    name += '.wav'
    try:
        pos = name.index('#')
    except ValueError:
        print("Name must contain at least one # to indicate where to put the number.")
        raise

    # Only the contiguous run starting at pos is the placeholder. Counting
    # every '#' in the name would cut into text after the placeholder when
    # another '#' appears later on.
    num_digits = len(name) - pos - len(name[pos:].lstrip('#'))

    def get_name(i):
        return name[:pos] + str(i).zfill(num_digits) + name[pos + num_digits:]

    i = 0
    while isfile(get_name(i)):
        i += 1
    return get_name(i)
def wait_to_continue():
    """
    Block until the user chooses to record or to exit

    Reads stdin one character at a time and ignores every character other
    than RECORD_KEY and the one whose code is EXIT_KEY_CODE.

    Returns:
        bool: True if RECORD_KEY was read, False if the exit key was read
            or stdin reached end of file
    """
    while True:
        c = stdin.read(1)
        if not c:
            # End of file: read() returns '' and ord('') would raise
            # TypeError, so treat a closed stdin as a request to exit
            return False
        if c == RECORD_KEY:
            return True
        if ord(c) == EXIT_KEY_CODE:
            return False
def record_until_key(p, args):
    """
    Record audio until RECORD_KEY is read from stdin

    Args:
        p: Audio interface passed through to record_until()
        args: Parsed arguments passed through to record_until()

    Returns:
        Whatever record_until() returns for the recording
    """
    def record_key_hit():
        # Only read when a character is waiting so recording is never blocked
        if not key_pressed():
            return False
        return stdin.read(1) == RECORD_KEY

    return record_until(p, record_key_hit, args)
def _main():
    """
    Prompt for a name template, then record and save clips until exit

    Each pass waits for the record key, records until it is pressed again
    and saves the audio under the next unused name from the template.

    Raises:
        ValueError: If the entered name contains no '#'
    """
    args = create_parser(usage).parse_args()

    # Restore the saved terminal settings while the name is typed
    show_input()
    audio_name = input("Audio name (Ex. recording-##): ")
    hide_input()

    # Reject a template without '#' now, before anything is recorded, so a
    # finished recording is never thrown away by the error
    next_name(audio_name)

    p = pyaudio.PyAudio()
    try:
        while True:
            print('Press space to record (esc to exit)...')
            if not wait_to_continue():
                break

            print('Recording...')
            d = record_until_key(p, args)
            name = next_name(audio_name)
            save_audio(name, d, args)
            print('Saved as ' + name)
    finally:
        # Release the audio interface even if recording or saving fails
        p.terminate()
def main():
    """Entry point that runs _main() inside termios_wrapper()"""
    termios_wrapper(_main)
if __name__ == '__main__':
main()

View File

@ -8,7 +8,7 @@ from shutil import copyfile
from prettyparse import create_parser
usage = '''
Convert keyword model from Keras to TensorFlow
Convert wake word model from Keras to TensorFlow
:model str
Input Keras model (.net)
@ -65,8 +65,12 @@ def convert(model_path: str, out_file: str):
del sess
if __name__ == '__main__':
def main():
args = create_parser(usage).parse_args()
model_name = args.model.replace('.net', '')
convert(args.model, args.out.format(model=model_name))
if __name__ == '__main__':
main()

View File

@ -25,6 +25,9 @@ usage = '''
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
...
'''
@ -50,7 +53,7 @@ def main():
print('Not enough data to train')
exit(1)
model = create_model(args.model, args.no_validation)
model = create_model(args.model, args.no_validation, args.extra_metrics)
model.summary()
from keras.callbacks import ModelCheckpoint

View File

@ -40,6 +40,9 @@ usage = '''
:-mm --metric-monitor str loss
Metric used to determine when to save
:-em --extra-metrics
Add extra metrics during training
:-nv --no-validation
Disable accuracy and validation calculation
to improve speed during training
@ -85,7 +88,7 @@ class IncrementalTrainer:
self.db_data = data.load(True, not args.no_validation)
if not isfile(args.model):
create_model(args.model, args.no_validation).save(args.model)
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
self.listener = Listener(args.model, args.chunk_size, runner_cls=KerasRunner)
def retrain(self):
@ -118,7 +121,7 @@ class IncrementalTrainer:
if conf > 0.5:
samples_since_train += 1
name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
name = join(self.args.data_dir, 'test' if save_test else '', 'not-keyword',
name = join(self.args.data_dir, 'test' if save_test else '', 'not-wake-word',
'generated', name)
save_audio(name, audio_buffer)
print()
@ -152,8 +155,8 @@ def main():
args = TrainData.parse_args(create_parser(usage))
for i in (
join(args.data_dir, 'not-keyword', 'generated'),
join(args.data_dir, 'test', 'not-keyword', 'generated')
join(args.data_dir, 'not-wake-word', 'generated'),
join(args.data_dir, 'test', 'not-wake-word', 'generated')
):
makedirs(i, exist_ok=True)

View File

@ -23,14 +23,14 @@ class TrainData:
"""
Load a set of data from a structured folder in the following format:
{prefix}/
keyword/
wake-word/
*.wav
not-keyword/
not-wake-word/
*.wav
test/
keyword/
wake-word/
*.wav
not-keyword/
not-wake-word/
*.wav
"""
return cls(find_wavs(folder), find_wavs(join(folder, 'test')))
@ -99,7 +99,7 @@ class TrainData:
return self.__load(self.__load_files, train, test)
def load_inhibit(self, train=True, test=True) -> tuple:
"""Generate data with inhibitory inputs created from keyword samples"""
"""Generate data with inhibitory inputs created from wake word samples"""
def loader(kws: list, nkws: list):
from precise.params import pr
@ -162,10 +162,10 @@ class TrainData:
inputs.extend(load_vector(f, vectorizer) for f in filenames)
outputs.extend(np.array([output]) for _ in filenames)
print('Loading keyword...')
print('Loading wake-word...')
add(kw_files, 1.0)
print('Loading not-keyword...')
print('Loading not-wake-word...')
add(nkw_files, 0.0)
from precise.params import pr

View File

@ -52,5 +52,5 @@ def glob_all(folder: str, filt: str) -> List[str]:
def find_wavs(folder: str) -> Tuple[List[str], List[str]]:
"""Finds keyword and not-keyword wavs in folder"""
return glob_all(folder + '/keyword', '*.wav'), glob_all(folder + '/not-keyword', '*.wav')
"""Finds wake-word and not-wake-word wavs in folder"""
return glob_all(folder + '/wake-word', '*.wav'), glob_all(folder + '/not-wake-word', '*.wav')

View File

@ -40,7 +40,7 @@ def vectorize(audio: np.ndarray) -> np.ndarray:
def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
"""
Returns an array of inputs generated from the
keyword audio that shouldn't cause an activation
wake word audio that shouldn't cause an activation
"""
def samp(x):

View File

@ -1,15 +1,21 @@
#!/usr/bin/env python3
from setuptools import setup, find_packages
from setuptools import setup
from precise import __version__
setup(
name='mycroft-precise',
version=__version__,
packages=find_packages(),
packages=[
'precise',
'precise.scripts',
'precise.pocketsphinx',
'precise.pocketsphinx.scripts'
],
entry_points={
'console_scripts': [
'precise-collect=precise.scripts.collect:main',
'precise-convert=precise.scripts.convert:main',
'precise-eval=precise.scripts.eval:main',
'precise-record=precise.scripts.record:main',