Add precise-eval script
parent
484a90f9bc
commit
13c3d81993
|
@ -11,6 +11,7 @@ __pycache__/
|
|||
*.txt
|
||||
other/
|
||||
.venv/
|
||||
stats.json
|
||||
|
||||
!requirements.txt
|
||||
|
||||
|
|
|
@ -14,6 +14,10 @@ from precise.params import inject_params
|
|||
|
||||
|
||||
class Runner(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def predict(self, inputs: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
pass
|
||||
|
@ -42,8 +46,12 @@ class TensorflowRunner(Runner):
|
|||
|
||||
return graph
|
||||
|
||||
def predict(self, inputs: np.ndarray) -> np.ndarray:
|
||||
"""Run on multiple inputs"""
|
||||
return self.sess.run(self.out_var, {self.inp_var: inputs})
|
||||
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]
|
||||
return self.predict(inp[np.newaxis])[0][0]
|
||||
|
||||
|
||||
class KerasRunner(Runner):
|
||||
|
@ -52,9 +60,12 @@ class KerasRunner(Runner):
|
|||
self.model = load_precise_model(model_name)
|
||||
self.graph = tf.get_default_graph()
|
||||
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
def predict(self, inputs: np.ndarray):
|
||||
with self.graph.as_default():
|
||||
return self.model.predict(np.array([inp]))[0][0]
|
||||
return self.model.predict(inputs)
|
||||
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
return self.predict(inp[np.newaxis])[0][0]
|
||||
|
||||
|
||||
class Listener:
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
import sys
|
||||
|
||||
sys.path += ['.'] # noqa
|
||||
|
||||
import json
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.scripts.test import show_stats
|
||||
from prettyparse import create_parser
|
||||
from precise.params import inject_params
|
||||
from precise.train_data import TrainData
|
||||
|
||||
usage = '''
|
||||
Evaluate a list of models on a dataset
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-o --output str stats.json
|
||||
Output json file
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='*', help='List of model filenames')
|
||||
args = TrainData.parse_args(parser)
|
||||
|
||||
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
print('Data:', data)
|
||||
|
||||
stats = {}
|
||||
|
||||
for model_name in args.models:
|
||||
inject_params(model_name)
|
||||
|
||||
train, test = data.load()
|
||||
inputs, targets = train if args.use_train else test
|
||||
predictions = Listener.find_runner(model_name)(model_name).predict(inputs)
|
||||
|
||||
true_pos, true_neg = [], []
|
||||
false_pos, false_neg = [], []
|
||||
|
||||
for name, target, prediction in zip(filenames, targets, predictions):
|
||||
{
|
||||
(True, False): false_pos,
|
||||
(True, True): true_pos,
|
||||
(False, True): false_neg,
|
||||
(False, False): true_neg
|
||||
}[prediction[0] > 0.5, target[0] > 0.5].append(name)
|
||||
|
||||
print('----', model_name, '----')
|
||||
show_stats(false_pos, false_neg, true_pos, true_neg, False)
|
||||
stats[model_name] = {
|
||||
'true_pos': len(true_pos),
|
||||
'true_neg': len(true_neg),
|
||||
'false_pos': len(false_pos),
|
||||
'false_neg': len(false_neg),
|
||||
}
|
||||
|
||||
print('Writing to:', args.output)
|
||||
with open(args.output, 'w') as f:
|
||||
json.dump(stats, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
12
setup.py
12
setup.py
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
from precise import __version__
|
||||
|
||||
setup(
|
||||
|
@ -9,11 +10,14 @@ setup(
|
|||
packages=find_packages(),
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'precise-train=precise.scripts.train:main',
|
||||
'precise-train-feedback=precise.scripts.train_feedback:main',
|
||||
'precise-convert=precise.scripts.convert:main',
|
||||
'precise-eval=precise.scripts.eval:main',
|
||||
'precise-record=precise.scripts.record:main'
|
||||
'precise-stream=precise.scripts.stream:main',
|
||||
'precise-test=precise.scripts.test:main',
|
||||
'precise-convert=precise.scripts.convert:main'
|
||||
'precise-test-pocketsphinx=precise.scripts.test_pocketsphinx:main'
|
||||
'precise-train=precise.scripts.train:main',
|
||||
'precise-train-incremental=precise.scripts.train_incremental:main',
|
||||
]
|
||||
},
|
||||
install_requires=[
|
||||
|
@ -35,7 +39,7 @@ setup(
|
|||
description='Mycroft Precise Wake Word Listener',
|
||||
keywords='wakeword keyword wake word listener sound',
|
||||
url='http://github.com/MycroftAI/mycroft-precise',
|
||||
|
||||
|
||||
zip_safe=True,
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
|
|
Loading…
Reference in New Issue