Add precise-eval script

pull/1/head
Matthew D. Scholefield 2018-02-21 14:27:21 -06:00
parent 484a90f9bc
commit 13c3d81993
4 changed files with 95 additions and 7 deletions

1
.gitignore vendored
View File

@ -11,6 +11,7 @@ __pycache__/
*.txt
other/
.venv/
stats.json
!requirements.txt

View File

@ -14,6 +14,10 @@ from precise.params import inject_params
class Runner(metaclass=ABCMeta):
@abstractmethod
def predict(self, inputs: np.ndarray) -> np.ndarray:
pass
@abstractmethod
def run(self, inp: np.ndarray) -> float:
pass
@ -42,8 +46,12 @@ class TensorflowRunner(Runner):
return graph
def predict(self, inputs: np.ndarray) -> np.ndarray:
"""Run on multiple inputs"""
return self.sess.run(self.out_var, {self.inp_var: inputs})
def run(self, inp: np.ndarray) -> float:
return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]
return self.predict(inp[np.newaxis])[0][0]
class KerasRunner(Runner):
@ -52,9 +60,12 @@ class KerasRunner(Runner):
self.model = load_precise_model(model_name)
self.graph = tf.get_default_graph()
def run(self, inp: np.ndarray) -> float:
def predict(self, inputs: np.ndarray):
with self.graph.as_default():
return self.model.predict(np.array([inp]))[0][0]
return self.model.predict(inputs)
def run(self, inp: np.ndarray) -> float:
return self.predict(inp[np.newaxis])[0][0]
class Listener:

72
precise/scripts/eval.py Executable file
View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) 2017 Mycroft AI Inc.
import sys
sys.path += ['.'] # noqa
import json
from precise.network_runner import Listener
from precise.scripts.test import show_stats
from prettyparse import create_parser
from precise.params import inject_params
from precise.train_data import TrainData
usage = '''
Evaluate a list of models on a dataset
:-t --use-train
Evaluate training data instead of test data
:-o --output str stats.json
Output json file
...
'''
def main():
parser = create_parser(usage)
parser.add_argument('models', nargs='*', help='List of model filenames')
args = TrainData.parse_args(parser)
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
filenames = sum(data.train_files if args.use_train else data.test_files, [])
print('Data:', data)
stats = {}
for model_name in args.models:
inject_params(model_name)
train, test = data.load()
inputs, targets = train if args.use_train else test
predictions = Listener.find_runner(model_name)(model_name).predict(inputs)
true_pos, true_neg = [], []
false_pos, false_neg = [], []
for name, target, prediction in zip(filenames, targets, predictions):
{
(True, False): false_pos,
(True, True): true_pos,
(False, True): false_neg,
(False, False): true_neg
}[prediction[0] > 0.5, target[0] > 0.5].append(name)
print('----', model_name, '----')
show_stats(false_pos, false_neg, true_pos, true_neg, False)
stats[model_name] = {
'true_pos': len(true_pos),
'true_neg': len(true_neg),
'false_pos': len(false_pos),
'false_neg': len(false_neg),
}
print('Writing to:', args.output)
with open(args.output, 'w') as f:
json.dump(stats, f)
if __name__ == '__main__':
main()

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
from setuptools import setup, find_packages
from precise import __version__
setup(
@ -9,11 +10,14 @@ setup(
packages=find_packages(),
entry_points={
'console_scripts': [
'precise-train=precise.scripts.train:main',
'precise-train-feedback=precise.scripts.train_feedback:main',
'precise-convert=precise.scripts.convert:main',
'precise-eval=precise.scripts.eval:main',
'precise-record=precise.scripts.record:main'
'precise-stream=precise.scripts.stream:main',
'precise-test=precise.scripts.test:main',
'precise-convert=precise.scripts.convert:main'
'precise-test-pocketsphinx=precise.scripts.test_pocketsphinx:main'
'precise-train=precise.scripts.train:main',
'precise-train-incremental=precise.scripts.train_incremental:main',
]
},
install_requires=[
@ -35,7 +39,7 @@ setup(
description='Mycroft Precise Wake Word Listener',
keywords='wakeword keyword wake word listener sound',
url='http://github.com/MycroftAI/mycroft-precise',
zip_safe=True,
classifiers=[
'Development Status :: 3 - Alpha',