Add RNN wakeword listener

commit 431843a016
@@ -0,0 +1,11 @@
dist/
build/
cache/
.idea/
__pycache__/
*.pb
*.params
*.net
*.pbtxt
keyword.trained.txt
other/
@@ -0,0 +1,2 @@
*
!.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore

@@ -0,0 +1,2 @@
*
!.gitignore
File diff suppressed because one or more lines are too long

@@ -0,0 +1,30 @@
# -*- mode: python -*-
block_cipher = None


a = Analysis(['precise/stream.py'],
             pathex=['.'],
             binaries=[],
             datas=[],
             hiddenimports=[],
             hookspath=[],
             runtime_hooks=[],
             excludes=['PySide', 'PyQt4', 'PyQt5', 'matplotlib'],
             win_no_prefer_redirects=False,
             win_private_assemblies=False,
             cipher=block_cipher)

pyz = PYZ(a.pure, a.zipped_data,
          cipher=block_cipher)

exe = EXE(pyz,
          a.scripts,
          a.binaries,
          a.zipfiles,
          a.datas,
          name='precise_stream',
          debug=False,
          strip=True,
          upx=True,
          runtime_tmpdir=None,
          console=True)

@@ -0,0 +1,123 @@
# Python 3
import numpy as np
from os.path import isfile
from precise.params import ListenerParams


pr = ListenerParams(window_t=0.2, hop_t=0.1, buffer_t=1.5,
                    sample_rate=16000, sample_depth=2,
                    n_mfcc=13, n_filt=20, n_fft=512)

lstm_units = 20


def vectorize_raw(audio):
    """Turns audio into feature vectors, without clipping for length"""
    from speechpy.main import mfcc
    return mfcc(audio, pr.sample_rate, pr.window_t, pr.hop_t, pr.n_mfcc, pr.n_filt, pr.n_fft)


def vectorize(audio):
    """
    Args:
        audio (array<float>): Audio verified to be of `sample_rate`

    Returns:
        array<float>: Vector representation of audio
    """
    if len(audio) > pr.max_samples:
        audio = audio[-pr.max_samples:]
    features = vectorize_raw(audio)
    if len(features) < pr.n_features:
        features = np.concatenate([np.zeros((pr.n_features - len(features), len(features[0]))), features])
    if len(features) > pr.n_features:
        features = features[-pr.n_features:]

    return features


def load_vector(name, vectorizer=vectorize):
    """Loads and caches a vector input from a wav or npy file"""
    import os

    save_name = name if name.endswith('.npy') else os.path.join('cache', vectorizer.__name__ + name + '.npy')

    if os.path.isfile(save_name):
        return np.load(save_name)

    print('Loading ' + name + '...')
    os.makedirs(os.path.dirname(save_name), exist_ok=True)

    vec = vectorizer(load_audio(name))
    np.save(save_name, vec)
    return vec


def load_audio(file):
    """
    Args:
        file (any): Audio filename or file object
    Returns:
        array<float>: Audio samples scaled to the range -1..1
    """
    import wavio
    wav = wavio.read(file)
    if wav.data.dtype != np.int16:
        raise ValueError('Unsupported data type: ' + str(wav.data.dtype))
    if wav.rate != pr.sample_rate:
        raise ValueError('Unsupported sample rate: ' + str(wav.rate))

    data = np.squeeze(wav.data)
    return data.astype(np.float32) / float(np.iinfo(data.dtype).max)


def to_np(x):
    """list<np.array> to np.array"""
    arr = np.empty((len(x),) + x[0].shape)
    for i in range(len(x)):
        arr[i] = x[i]
    return arr


def find_wavs(folder):
    """Finds keyword and not-keyword wavs in folder"""
    from glob import glob
    return glob(folder + '/keyword/*.wav'), glob(folder + '/not-keyword/**/*.wav', recursive=True)


def load_data(prefix):
    inputs = []
    outputs = []

    def add(filenames, output):
        nonlocal inputs, outputs
        inputs += [load_vector(f) for f in filenames]
        outputs += [np.array(output)] * len(filenames)

    kww, nkw = find_wavs(prefix)

    print('Loading keyword...')
    add(kww, 1.0)

    print('Loading not-keyword...')
    add(nkw, 0.0)

    return to_np(inputs), to_np(outputs)


def create_model(model_name, should_load):
    if isfile(model_name) and should_load:
        print('Loading from ' + model_name + '...')
        from keras.models import load_model
        model = load_model(model_name)
    else:
        from keras.layers.core import Dense
        from keras.layers.recurrent import GRU
        from keras.models import Sequential

        model = Sequential()
        model.add(GRU(lstm_units, activation='linear', input_shape=(pr.n_features, pr.feature_size), dropout=0.2, name='net'))
        model.add(Dense(1, activation='sigmoid'))

    model.compile('rmsprop', 'mse', metrics=['accuracy'])
    return model
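
As a point of reference, the helpers above imply a fixed input geometry: `vectorize` always returns a `(pr.n_features, pr.feature_size)` matrix, which is exactly the `input_shape` that `create_model` declares for the GRU layer. A minimal sketch that checks this, assuming numpy, speechpy and keras are installed:

import numpy as np
from precise.common import pr, vectorize, create_model

audio = np.zeros(pr.max_samples, dtype=np.float32)  # 1.5 s of silence at 16 kHz
feats = vectorize(audio)
assert feats.shape == (pr.n_features, pr.feature_size)

model = create_model('keyword.net', should_load=False)  # fresh, untrained model
print(float(model.predict(feats[np.newaxis])))  # single confidence in 0..1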

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
#
# Attribution: This script was adapted from https://github.com/amir-abdi/keras_to_tensorflow
#
# Copyright (c) 2017 Mycroft AI Inc.

import sys
sys.path += ['.']

import argparse
import os
from os.path import split, isfile

from shutil import copyfile


def convert(model_path, out_file):
    """
    Converts an HDF5 file from Keras to a .pb for use with TensorFlow

    Args:
        model_path (str): location of Keras model
        out_file (str): location to write protobuf
    """
    print('Converting', model_path, 'to', out_file, '...')

    import tensorflow as tf
    from keras.models import load_model
    from keras import backend as K

    out_dir, filename = split(out_file)
    out_dir = out_dir or '.'
    os.makedirs(out_dir, exist_ok=True)

    K.set_learning_phase(0)
    model = load_model(model_path)

    out_name = 'net_output'
    tf.identity(model.output, name=out_name)
    print('Output node name:', out_name)
    print('Output folder:', out_dir)

    sess = K.get_session()

    # Write the graph in human-readable text form
    tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True)
    print('Saved readable graph to:', filename + 'txt')

    # Write the graph as a binary .pb file
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io

    cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name])
    graph_io.write_graph(cgraph, out_dir, filename, as_text=False)

    if isfile(model_path + '.params'):
        copyfile(model_path + '.params', out_file + '.params')

    print('Saved graph to:', filename)

    del sess


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert keyword model from Keras to TensorFlow')
    parser.add_argument('--model', '-m', default='keyword.net', help='Input Keras model', type=argparse.FileType())
    parser.add_argument('--out', '-o', default='keyword.pb', help='Output TensorFlow protobuf')
    args = parser.parse_args()

    convert(args.model.name, args.out)
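
A minimal way to exercise the converter, assuming a trained keyword.net (and an optional keyword.net.params) sits in the working directory; the `precise.convert` module path is taken from the entry points declared in setup.py below:

from precise.convert import convert

convert('keyword.net', 'keyword.pb')
# Expected outputs: keyword.pb (frozen graph), keyword.pbtxt (readable text
# graph) and, if keyword.net.params exists, a copied keyword.pb.params.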

@@ -0,0 +1,24 @@
# Python 3

from collections import namedtuple
from math import floor


def _make_cls():
    cls = namedtuple('ListenerParams', 'window_t hop_t buffer_t sample_rate sample_depth n_mfcc n_filt n_fft')

    def add_prop(name, fn):
        setattr(cls, name, property(fn))
    import numpy as np

    add_prop('buffer_samples', lambda s: s.hop_samples * (int(np.round(s.sample_rate * s.buffer_t)) // s.hop_samples))
    add_prop('window_samples', lambda s: int(np.round(s.sample_rate * s.window_t)))
    add_prop('hop_samples', lambda s: int(np.round(s.sample_rate * s.hop_t)))

    add_prop('n_features', lambda s: 1 + int(floor((s.buffer_samples - s.window_samples) / s.hop_samples)))
    add_prop('feature_size', lambda s: s.n_mfcc)
    add_prop('max_samples', lambda s: int(s.buffer_t * s.sample_rate))
    return cls


ListenerParams = _make_cls()
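
For reference, the derived properties work out as follows for the defaults used in precise/common.py (a worked sketch; the values follow exactly from the lambdas above):

from precise.params import ListenerParams

pr = ListenerParams(window_t=0.2, hop_t=0.1, buffer_t=1.5,
                    sample_rate=16000, sample_depth=2,
                    n_mfcc=13, n_filt=20, n_fft=512)

print(pr.window_samples)  # 3200  = round(16000 * 0.2)
print(pr.hop_samples)     # 1600  = round(16000 * 0.1)
print(pr.buffer_samples)  # 24000 = 1600 * (round(16000 * 1.5) // 1600)
print(pr.n_features)      # 14    = 1 + floor((24000 - 3200) / 1600)
print(pr.feature_size)    # 13    = n_mfcc
print(pr.max_samples)     # 24000 = int(1.5 * 16000)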

@@ -0,0 +1,33 @@
#!/usr/bin/env python3

import sys
sys.path += ['.']  # noqa

from subprocess import Popen, PIPE
from pyaudio import PyAudio, paInt16
from precise.common import pr


def main():
    pa = PyAudio()
    stream = pa.open(pr.sample_rate, 1, paInt16, True, frames_per_buffer=1024)

    proc = Popen(['python3', 'precise/stream.py', 'keyword.pb', '1024'], stdin=PIPE, stdout=PIPE)

    print('Listening...')
    try:
        while True:
            proc.stdin.write(stream.read(1024))
            proc.stdin.flush()

            prob = float(proc.stdout.readline())
            print('!' if prob > 0.5 else '.', end='', flush=True)
    except KeyboardInterrupt:
        print()
    finally:
        stream.stop_stream()
        pa.terminate()


if __name__ == '__main__':
    main()

@@ -0,0 +1,129 @@
#!/usr/bin/env python3

import sys
sys.path += ['.']

from io import StringIO

from precise.common import vectorize, pr

silent = False
should_save = False
enable_sound = False
save_prefix = ''
num_to_save = -1
save_act = False

act_lev = 0
was_active = False


while len(sys.argv) > 1:
    a = sys.argv[1]
    del sys.argv[1]
    if a == 'save':
        should_save = True
        while len(sys.argv) > 1:
            a = sys.argv[1]
            if a.isdigit():
                num_to_save = int(a)
            elif a == 'sound':
                enable_sound = True
            elif a == 'silent':
                silent = True
            else:
                break
            del sys.argv[1]  # consume the recognized option

if silent:
    _stdout = sys.stdout
    sys.stdout = StringIO()  # capture any output

import wave
from random import randint

session_id = randint(0, 1000)
chunk_id = 0

import pyaudio
from os.path import join
from pyaudio import PyAudio
import numpy as np


class PreciseRecognizer:
    def __init__(self):
        from keras.models import load_model
        self.model = load_model('keyword.net')

    @staticmethod
    def buffer_to_audio(buffer):
        """Convert a raw mono audio byte string to numpy array of floats"""
        return np.fromstring(buffer, dtype='<i2').astype(np.float32, order='C') / 32768.0

    def found_wake_word(self, raw_data):
        inp = vectorize(self.buffer_to_audio(raw_data))
        return self.model.predict_on_batch(inp[np.newaxis]) >= 0.5


CHANNELS = 1
CHUNK_SIZE = 1024
RATE = pr.sample_rate
WIDTH = 2  # Int16
FORMAT = pyaudio.get_format_from_width(WIDTH)
BUFFER_LEN = WIDTH * CHANNELS * pr.max_samples

p = PyAudio()
recognizer = PreciseRecognizer()
stream = p.open(RATE, CHANNELS, FORMAT, True, frames_per_buffer=CHUNK_SIZE)
buffer = b'\0' * BUFFER_LEN


def save(buffer, debug=False):
    if not should_save:
        return

    global chunk_id, num_to_save
    nm = join('data', 'not-keyword',
              save_prefix + str(session_id) + '.' + str(chunk_id) + '.wav')
    chunk_id += 1
    num_to_save -= 1

    with wave.open(nm, 'w') as wf:
        wf.setnchannels(CHANNELS)
        wf.setframerate(RATE)
        wf.setsampwidth(WIDTH)
        wf.writeframes(buffer)
    if debug:
        print('Saved to ' + nm + '.')


start_delay = 40
print('Filling buffer...')
try:
    while True:
        chunk = stream.read(CHUNK_SIZE)
        buffer = buffer[max(0, len(buffer) - BUFFER_LEN + len(chunk)):] + chunk

        if start_delay > 0:
            start_delay -= 1
            continue
        found = recognizer.found_wake_word(buffer)

        if found:
            if not was_active:
                was_active = True
                if silent:
                    sys.stdout = _stdout
                    print(':activate:')
                    sys.stdout = StringIO()
                else:
                    print('Activate!')
                save(buffer, debug=True)
        else:
            if was_active:
                was_active = False
            print('.', end='', flush=True)
        if num_to_save == 0:
            break

finally:
    stream.stop_stream()
    p.terminate()

@@ -0,0 +1,110 @@
#!/usr/bin/env python3

import sys
sys.path += ['.']  # noqa

import os
import json
import numpy as np
from json.decoder import JSONDecodeError
from speechpy.main import mfcc
from precise.params import ListenerParams


def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph


def buffer_to_audio(buffer):
    """Convert a raw mono audio byte string to numpy array of floats"""
    return np.fromstring(buffer, dtype='<i2').astype(np.float32, order='C') / 32768.0


class NetworkRunner:
    def __init__(self, model_name):
        self.graph = load_graph(model_name)

        self.inp_var = self.graph.get_operation_by_name('import/net_input').outputs[0]
        self.out_var = self.graph.get_operation_by_name('import/net_output').outputs[0]

        self.sess = tf.Session(graph=self.graph)

    def run(self, inp):
        return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]


class Listener:
    def __init__(self, model_name, chunk_size):
        self.buffer = np.array([])
        self.pr = self._load_params(model_name)
        self.features = np.zeros((self.pr.n_features, self.pr.feature_size))
        self.read_size = -1 if chunk_size == -1 else self.pr.sample_depth * chunk_size
        self.runner = NetworkRunner(model_name)

    def _load_params(self, model_name):
        try:
            with open(model_name + '.params') as f:
                return ListenerParams(**json.load(f))
        except (OSError, JSONDecodeError, TypeError):
            from precise.common import pr
            return pr

    def update(self, stream):
        chunk = stream.read(self.read_size)
        if len(chunk) == 0:
            raise EOFError

        chunk_audio = buffer_to_audio(chunk)
        self.buffer = np.concatenate((self.buffer, chunk_audio))

        if len(self.buffer) >= self.pr.window_samples:
            remaining = self.pr.window_samples - (
                self.pr.hop_samples - (len(self.buffer) - self.pr.window_samples) % self.pr.hop_samples)
            new_features = mfcc(self.buffer, self.pr.sample_rate, self.pr.window_t, self.pr.hop_t, self.pr.n_mfcc, self.pr.n_filt, self.pr.n_fft)

            self.features = np.concatenate([self.features[len(new_features):], new_features])
            self.buffer = self.buffer[-remaining:]

        return self.runner.run(self.features)


def main():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    stdout = sys.stdout
    sys.stdout = sys.stderr

    if sys.stdin.isatty() or len(sys.argv) > 3 or len(sys.argv) == 1 or (len(sys.argv) == 3 and not sys.argv[2].isdigit()):
        print('Usage:', sys.argv[0], 'MODEL_NAME [CHUNK_SIZE] < audio.wav')
        print('    stdin should be a stream of raw int16 audio,')
        print('    written in groups of CHUNK_SIZE samples.')
        print()
        print('    If no CHUNK_SIZE is given it will read until EOF.')
        print()
        print('    For every chunk, an inference will be given')
        print('    via stdout as a float string, one per line')
        sys.exit(1)

    global tf
    import tensorflow
    tf = tensorflow

    listener = Listener(sys.argv[1], int(sys.argv[2]) if 2 < len(sys.argv) else -1)

    try:
        while True:
            conf = listener.update(sys.stdin.buffer)
            stdout.buffer.write((str(conf) + '\n').encode('ascii'))
            stdout.buffer.flush()
    except (EOFError, KeyboardInterrupt):
        pass


if __name__ == '__main__':
    main()
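
Besides the live microphone test script earlier in this commit, the streamer can be exercised offline by piping raw int16 audio into it one fixed-size chunk at a time. A sketch, assuming a converted keyword.pb and a 16 kHz mono 16-bit recording named test.wav (both file names are assumptions):

import wave
from subprocess import Popen, PIPE

CHUNK = 1024  # samples per write; must match the CHUNK_SIZE argument
proc = Popen(['python3', 'precise/stream.py', 'keyword.pb', str(CHUNK)],
             stdin=PIPE, stdout=PIPE)

with wave.open('test.wav') as wf:
    while True:
        chunk = wf.readframes(CHUNK)
        if len(chunk) < CHUNK * 2:  # stop before the last partial chunk (2 bytes/sample)
            break
        proc.stdin.write(chunk)
        proc.stdin.flush()
        print(float(proc.stdout.readline()))  # one confidence per chunk
proc.stdin.close()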

@@ -0,0 +1,59 @@
#!/usr/bin/env python3

import sys

sys.path += ['.']  # noqa

from argparse import ArgumentParser

from precise.common import *


def main():
    parser = ArgumentParser()
    parser.add_argument('-m', '--model', help='Keras model to load (default: keyword.net)',
                        default='keyword.net')
    parser.add_argument('-d', '--data-dir',
                        help='Directory to load test data from (default: data/test)',
                        default='data/test')
    parser.set_defaults(load=True, save_best=True)
    args = parser.parse_args()

    from keras.models import load_model

    filenames = sum(find_wavs(args.data_dir), [])
    inputs, outputs = load_data(args.data_dir)
    predictions = load_model(args.model).predict(inputs)

    num_correct = 0
    false_pos, false_neg = [], []
    for name, correct, prediction in zip(filenames, outputs, predictions):
        if prediction < 0.5 < correct:
            false_neg += [name]
        elif prediction > 0.5 > correct:
            false_pos += [name]
        else:
            num_correct += 1

    def prc(a, b):  # Rounded percent
        return round(100.0 * a / b, 2)

    print('=== False Positives ===')
    for i in false_pos:
        print(i)
    print()
    print('=== False Negatives ===')
    for i in false_neg:
        print(i)
    print()
    print('=== Summary ===')
    total = num_correct + len(false_pos) + len(false_neg)
    print(num_correct, "out of", total)
    print(prc(num_correct, total), "%")
    print()
    print(prc(len(false_pos), total), "% false positives")
    print(prc(len(false_neg), total), "% false negatives")


if __name__ == '__main__':
    main()

@@ -0,0 +1,42 @@
#!/usr/bin/env python3

import sys
sys.path += ['.']  # noqa

import argparse
import json

from precise.common import *


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', type=int, default=10)
    parser.add_argument('-l', '--load', dest='load', action='store_true')
    parser.add_argument('-nl', '--no-load', dest='load', action='store_false')
    parser.add_argument('-b', '--save-best', dest='save_best', action='store_true')
    parser.add_argument('-nb', '--no-save-best', dest='save_best', action='store_false')
    parser.set_defaults(load=True, save_best=True)
    args = parser.parse_args()

    inputs, outputs = load_data('data')
    validation_data = load_data('data/test')

    print('Inputs shape:', inputs.shape)
    print('Outputs shape:', outputs.shape)

    model = create_model('keyword.net', args.load)

    with open('keyword.net.params', 'w') as f:
        json.dump(pr._asdict(), f)

    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')

    try:
        model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])
    except KeyboardInterrupt:
        print()


if __name__ == '__main__':
    main()
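
The keyword.net.params side file written above is just the ListenerParams fields serialized as JSON; precise/stream.py later reads MODEL_NAME + '.params' to rebuild the same parameters at inference time, falling back to the defaults in precise/common.py if the file is missing or malformed. With the default settings it contains roughly:

{"window_t": 0.2, "hop_t": 0.1, "buffer_t": 1.5, "sample_rate": 16000,
 "sample_depth": 2, "n_mfcc": 13, "n_filt": 20, "n_fft": 512}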

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# This script trains the network, selectively choosing
# segments from data/random that cause an activation. These
# segments are moved into data/not-keyword and the network is retrained

import sys
sys.path += ['.']  # noqa

import argparse
from glob import glob
from os import makedirs
from os.path import join, basename, splitext

import wavio

from precise.common import *


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', type=int, default=1, help='Number of epochs to train before continuing evaluation')
    parser.add_argument('-j', '--jump', type=int, default=2, help='Number of features to skip while evaluating')
    parser.add_argument('-s', '--skip-trained', type=bool, default=True, help='Whether to skip random files that have already been trained on')
    parser.add_argument('-l', '--load', type=bool, default=True)
    parser.add_argument('-b', '--save-best', type=bool, default=False)
    parser.add_argument('-d', '--out-dir', default='data/not-keyword/generated')

    args = parser.parse_args()

    trained_fns = []
    if isfile('keyword.trained.txt'):
        with open('keyword.trained.txt', 'r') as f:
            trained_fns = f.read().split('\n')

    random_inputs = ((f, load_audio(f)) for f in glob('data/random/*.wav') if not args.skip_trained or f not in trained_fns)
    validation_data = load_data('data/test')

    model = create_model('keyword.net', args.load)

    from keras.callbacks import ModelCheckpoint
    checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')

    makedirs(args.out_dir, exist_ok=True)

    try:
        for fn, full_audio in random_inputs:
            features = vectorize_raw(full_audio[:pr.buffer_samples])
            counter = 0

            print('Starting file ' + fn + '...')
            for i in range(pr.buffer_samples - pr.buffer_samples % pr.hop_samples, len(full_audio), pr.hop_samples):
                print('\r' + str(i * 100. / len(full_audio)) + '%', end='', flush=True)
                window = full_audio[i - pr.window_samples:i]

                vec = vectorize_raw(window)
                assert len(vec) == 1
                features = np.concatenate([features[1:], vec])

                counter += 1
                if counter % args.jump == 0:
                    out = model.predict(np.array([features]))[0]
                else:
                    out = 0.0
                if out > 0.5:
                    name = join(args.out_dir, splitext(basename(fn))[0] + '-' + str(i)) + '.wav'
                    audio = full_audio[i - pr.buffer_samples:i]
                    audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
                    wavio.write(name, audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')

                    inputs, outputs = load_data('data')
                    model.fit(inputs, outputs, 5000, 1, validation_data=validation_data)
                    model.save('keyword.net')
                    model.fit(inputs, outputs, 5000, args.epochs - 1, validation_data=validation_data, callbacks=[checkpoint])
            print()

            with open('keyword.trained.txt', 'a') as f:
                f.write('\n' + fn)

    except KeyboardInterrupt:
        print()


if __name__ == '__main__':
    main()
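
The core of the loop above is a rolling feature window: instead of re-vectorizing the whole 1.5 s buffer at every hop, it computes the single MFCC frame of the newest window and shifts it into the feature matrix. In isolation the idea looks like this (a sketch; the wav file name is hypothetical):

import numpy as np
from precise.common import pr, vectorize_raw, load_audio

audio = load_audio('data/random/example.wav')  # hypothetical recording
features = vectorize_raw(audio[:pr.buffer_samples])[:pr.n_features]

for i in range(pr.buffer_samples, len(audio), pr.hop_samples):
    new_row = vectorize_raw(audio[i - pr.window_samples:i])  # exactly one MFCC frame
    features = np.concatenate([features[1:], new_row])
    # 'features' now describes the last buffer_t seconds ending at sample i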

@@ -0,0 +1,49 @@
#!/usr/bin/env python3

from setuptools import setup, find_packages

setup(
    name='Mycroft Precise',
    version='0.1.0',
    packages=find_packages(),
    entry_points={
        'console_scripts': [
            'precise-train=precise.train:main',
            'precise-train-feedback=precise.train_feedback:main',
            'precise-stream=precise.stream:main',
            'precise-test=precise.test:main',
            'precise-convert=precise.convert:main'
        ]
    },
    install_requires=[
        'numpy',
        'tensorflow',
        'speechpy',
        'pyaudio',
        'keras',
        'wavio'
    ],

    author='Matthew Scholefield',
    author_email='matthew.scholefield@mycroft.ai',
    description='Mycroft Precise Wake Word Listener',
    keywords='wakeword keyword wake word listener sound',
    url='http://github.com/MycroftAI/mycroft-precise',

    zip_safe=True,
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Topic :: Text Processing :: Linguistic',
        'License :: OSI Approved :: Apache Software License',

        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.0',
        'Programming Language :: Python :: 3.1',
        'Programming Language :: Python :: 3.2',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ],
)