New precise-collect script and minor fixes
- Fix precise-convert script - Fix pyinstaller spec - Change keyword folder to wake-word. This is more consistent with other labelspull/1/head
parent
cc32959e29
commit
ce5f93369c
|
|
@ -1,23 +1,24 @@
|
|||
dist/
|
||||
build/
|
||||
.cache/
|
||||
.idea/
|
||||
/dist/
|
||||
/build/
|
||||
/.cache/
|
||||
/.idea/
|
||||
__pycache__/
|
||||
*.egg-info/
|
||||
*.pb
|
||||
*.params
|
||||
*.net
|
||||
*.pbtxt
|
||||
*.txt
|
||||
other/
|
||||
.venv/
|
||||
stats.json
|
||||
/*.egg-info/
|
||||
/*.pb
|
||||
/*.params
|
||||
/*.net
|
||||
/*.pbtxt
|
||||
/*.txt
|
||||
/*.wav
|
||||
/other/
|
||||
/.venv/
|
||||
/stats.json
|
||||
|
||||
!requirements.txt
|
||||
|
||||
# jenkins branch
|
||||
wake-words.txt
|
||||
all/
|
||||
mycroft-precise/
|
||||
models/
|
||||
/wake-words.txt
|
||||
/all/
|
||||
/mycroft-precise/
|
||||
/models/
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
block_cipher = None
|
||||
|
||||
|
||||
a = Analysis(['precise/stream.py'],
|
||||
a = Analysis(['precise/scripts/stream.py'],
|
||||
pathex=['.'],
|
||||
binaries=[],
|
||||
datas=[],
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def load_precise_model(model_name: str) -> Any:
|
|||
return load_keras().models.load_model(model_name)
|
||||
|
||||
|
||||
def create_model(model_name: str, skip_acc: bool = False) -> Any:
|
||||
def create_model(model_name: str, skip_acc=False, extra_metrics=False) -> Any:
|
||||
"""
|
||||
Load or create a precise model
|
||||
|
||||
|
|
@ -42,6 +42,6 @@ def create_model(model_name: str, skip_acc: bool = False) -> Any:
|
|||
model.add(Dense(1, activation='sigmoid'))
|
||||
|
||||
load_keras()
|
||||
metrics = ['accuracy', false_pos, false_neg]
|
||||
metrics = ['accuracy'] + extra_metrics * [false_pos, false_neg]
|
||||
model.compile('rmsprop', weighted_log_loss, metrics=(not skip_acc) * metrics)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -0,0 +1,146 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
import tty
|
||||
import wave
|
||||
from os.path import isfile
|
||||
from select import select
|
||||
from sys import stdin
|
||||
from termios import tcsetattr, tcgetattr, TCSADRAIN
|
||||
from prettyparse import create_parser
|
||||
|
||||
import pyaudio
|
||||
|
||||
usage = '''
|
||||
Record audio samples for use with precise
|
||||
|
||||
:-w --width int 2
|
||||
Sample width of audio
|
||||
|
||||
:-r --rate int 16000
|
||||
Sample rate of audio
|
||||
|
||||
:-c --channels int 2
|
||||
Number of audio channels
|
||||
'''
|
||||
|
||||
|
||||
def key_pressed():
|
||||
return select([stdin], [], [], 0) == ([stdin], [], [])
|
||||
|
||||
|
||||
def termios_wrapper(main):
|
||||
global orig_settings
|
||||
orig_settings = tcgetattr(stdin)
|
||||
try:
|
||||
hide_input()
|
||||
main()
|
||||
finally:
|
||||
tcsetattr(stdin, TCSADRAIN, orig_settings)
|
||||
|
||||
|
||||
def show_input():
|
||||
tcsetattr(stdin, TCSADRAIN, orig_settings)
|
||||
|
||||
|
||||
def hide_input():
|
||||
tty.setcbreak(stdin.fileno())
|
||||
|
||||
|
||||
orig_settings = None
|
||||
|
||||
RECORD_KEY = ' '
|
||||
EXIT_KEY_CODE = 27
|
||||
|
||||
|
||||
def record_until(p, should_return, args):
|
||||
chunk_size = 1024
|
||||
stream = p.open(format=p.get_format_from_width(args.width), channels=args.channels, rate=args.rate,
|
||||
input=True, frames_per_buffer=chunk_size)
|
||||
|
||||
frames = []
|
||||
while not should_return():
|
||||
frames.append(stream.read(chunk_size))
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
|
||||
return b''.join(frames)
|
||||
|
||||
|
||||
def save_audio(name, data, args):
|
||||
wf = wave.open(name, 'wb')
|
||||
wf.setnchannels(args.channels)
|
||||
wf.setsampwidth(args.width)
|
||||
wf.setframerate(args.rate)
|
||||
wf.writeframes(data)
|
||||
wf.close()
|
||||
|
||||
|
||||
def next_name(name):
|
||||
name += '.wav'
|
||||
pos, num_digits = None, None
|
||||
try:
|
||||
pos = name.index('#')
|
||||
num_digits = name.count('#')
|
||||
except ValueError:
|
||||
print("Name must contain at least one # to indicate where to put the number.")
|
||||
raise
|
||||
|
||||
def get_name(i):
|
||||
nonlocal name, pos
|
||||
return name[:pos] + str(i).zfill(num_digits) + name[pos + num_digits:]
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
if not isfile(get_name(i)):
|
||||
break
|
||||
i += 1
|
||||
|
||||
return get_name(i)
|
||||
|
||||
|
||||
def wait_to_continue():
|
||||
while True:
|
||||
c = stdin.read(1)
|
||||
if c == RECORD_KEY:
|
||||
return True
|
||||
elif ord(c) == EXIT_KEY_CODE:
|
||||
return False
|
||||
|
||||
|
||||
def record_until_key(p, args):
|
||||
def should_return():
|
||||
return key_pressed() and stdin.read(1) == RECORD_KEY
|
||||
|
||||
return record_until(p, should_return, args)
|
||||
|
||||
|
||||
def _main():
|
||||
args = create_parser(usage).parse_args()
|
||||
show_input()
|
||||
audio_name = input("Audio name (Ex. recording-##): ")
|
||||
hide_input()
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
|
||||
while True:
|
||||
print('Press space to record (esc to exit)...')
|
||||
|
||||
if not wait_to_continue():
|
||||
break
|
||||
|
||||
print('Recording...')
|
||||
d = record_until_key(p, args)
|
||||
name = next_name(audio_name)
|
||||
save_audio(name, d, args)
|
||||
print('Saved as ' + name)
|
||||
|
||||
p.terminate()
|
||||
|
||||
|
||||
def main():
|
||||
termios_wrapper(_main)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -8,7 +8,7 @@ from shutil import copyfile
|
|||
from prettyparse import create_parser
|
||||
|
||||
usage = '''
|
||||
Convert keyword model from Keras to TensorFlow
|
||||
Convert wake word model from Keras to TensorFlow
|
||||
|
||||
:model str
|
||||
Input Keras model (.net)
|
||||
|
|
@ -65,8 +65,12 @@ def convert(model_path: str, out_file: str):
|
|||
del sess
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main():
|
||||
args = create_parser(usage).parse_args()
|
||||
|
||||
model_name = args.model.replace('.net', '')
|
||||
convert(args.model, args.out.format(model=model_name))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ usage = '''
|
|||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
|
@ -50,7 +53,7 @@ def main():
|
|||
print('Not enough data to train')
|
||||
exit(1)
|
||||
|
||||
model = create_model(args.model, args.no_validation)
|
||||
model = create_model(args.model, args.no_validation, args.extra_metrics)
|
||||
model.summary()
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ usage = '''
|
|||
:-mm --metric-monitor str loss
|
||||
Metric used to determine when to save
|
||||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
:-nv --no-validation
|
||||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
|
|
@ -85,7 +88,7 @@ class IncrementalTrainer:
|
|||
self.db_data = data.load(True, not args.no_validation)
|
||||
|
||||
if not isfile(args.model):
|
||||
create_model(args.model, args.no_validation).save(args.model)
|
||||
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
|
||||
self.listener = Listener(args.model, args.chunk_size, runner_cls=KerasRunner)
|
||||
|
||||
def retrain(self):
|
||||
|
|
@ -118,7 +121,7 @@ class IncrementalTrainer:
|
|||
if conf > 0.5:
|
||||
samples_since_train += 1
|
||||
name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
|
||||
name = join(self.args.data_dir, 'test' if save_test else '', 'not-keyword',
|
||||
name = join(self.args.data_dir, 'test' if save_test else '', 'not-wake-word',
|
||||
'generated', name)
|
||||
save_audio(name, audio_buffer)
|
||||
print()
|
||||
|
|
@ -152,8 +155,8 @@ def main():
|
|||
args = TrainData.parse_args(create_parser(usage))
|
||||
|
||||
for i in (
|
||||
join(args.data_dir, 'not-keyword', 'generated'),
|
||||
join(args.data_dir, 'test', 'not-keyword', 'generated')
|
||||
join(args.data_dir, 'not-wake-word', 'generated'),
|
||||
join(args.data_dir, 'test', 'not-wake-word', 'generated')
|
||||
):
|
||||
makedirs(i, exist_ok=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ class TrainData:
|
|||
"""
|
||||
Load a set of data from a structured folder in the following format:
|
||||
{prefix}/
|
||||
keyword/
|
||||
wake-word/
|
||||
*.wav
|
||||
not-keyword/
|
||||
not-wake-word/
|
||||
*.wav
|
||||
test/
|
||||
keyword/
|
||||
wake-word/
|
||||
*.wav
|
||||
not-keyword/
|
||||
not-wake-word/
|
||||
*.wav
|
||||
"""
|
||||
return cls(find_wavs(folder), find_wavs(join(folder, 'test')))
|
||||
|
|
@ -99,7 +99,7 @@ class TrainData:
|
|||
return self.__load(self.__load_files, train, test)
|
||||
|
||||
def load_inhibit(self, train=True, test=True) -> tuple:
|
||||
"""Generate data with inhibitory inputs created from keyword samples"""
|
||||
"""Generate data with inhibitory inputs created from wake word samples"""
|
||||
|
||||
def loader(kws: list, nkws: list):
|
||||
from precise.params import pr
|
||||
|
|
@ -162,10 +162,10 @@ class TrainData:
|
|||
inputs.extend(load_vector(f, vectorizer) for f in filenames)
|
||||
outputs.extend(np.array([output]) for _ in filenames)
|
||||
|
||||
print('Loading keyword...')
|
||||
print('Loading wake-word...')
|
||||
add(kw_files, 1.0)
|
||||
|
||||
print('Loading not-keyword...')
|
||||
print('Loading not-wake-word...')
|
||||
add(nkw_files, 0.0)
|
||||
|
||||
from precise.params import pr
|
||||
|
|
|
|||
|
|
@ -52,5 +52,5 @@ def glob_all(folder: str, filt: str) -> List[str]:
|
|||
|
||||
|
||||
def find_wavs(folder: str) -> Tuple[List[str], List[str]]:
|
||||
"""Finds keyword and not-keyword wavs in folder"""
|
||||
return glob_all(folder + '/keyword', '*.wav'), glob_all(folder + '/not-keyword', '*.wav')
|
||||
"""Finds wake-word and not-wake-word wavs in folder"""
|
||||
return glob_all(folder + '/wake-word', '*.wav'), glob_all(folder + '/not-wake-word', '*.wav')
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def vectorize(audio: np.ndarray) -> np.ndarray:
|
|||
def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns an array of inputs generated from the
|
||||
keyword audio that shouldn't cause an activation
|
||||
wake word audio that shouldn't cause an activation
|
||||
"""
|
||||
|
||||
def samp(x):
|
||||
|
|
|
|||
10
setup.py
10
setup.py
|
|
@ -1,15 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import setup
|
||||
|
||||
from precise import __version__
|
||||
|
||||
setup(
|
||||
name='mycroft-precise',
|
||||
version=__version__,
|
||||
packages=find_packages(),
|
||||
packages=[
|
||||
'precise',
|
||||
'precise.scripts',
|
||||
'precise.pocketsphinx',
|
||||
'precise.pocketsphinx.scripts'
|
||||
],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'precise-collect=precise.scripts.collect:main',
|
||||
'precise-convert=precise.scripts.convert:main',
|
||||
'precise-eval=precise.scripts.eval:main',
|
||||
'precise-record=precise.scripts.record:main',
|
||||
|
|
|
|||
Loading…
Reference in New Issue