Add precise-eval script
							parent
							
								
									484a90f9bc
								
							
						
					
					
						commit
						13c3d81993
					
				| 
						 | 
				
			
			@ -11,6 +11,7 @@ __pycache__/
 | 
			
		|||
*.txt
 | 
			
		||||
other/
 | 
			
		||||
.venv/
 | 
			
		||||
stats.json
 | 
			
		||||
 | 
			
		||||
!requirements.txt
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,6 +14,10 @@ from precise.params import inject_params
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class Runner(metaclass=ABCMeta):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def predict(self, inputs: np.ndarray) -> np.ndarray:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def run(self, inp: np.ndarray) -> float:
 | 
			
		||||
        pass
 | 
			
		||||
| 
						 | 
				
			
			@ -42,8 +46,12 @@ class TensorflowRunner(Runner):
 | 
			
		|||
 | 
			
		||||
        return graph
 | 
			
		||||
 | 
			
		||||
    def predict(self, inputs: np.ndarray) -> np.ndarray:
 | 
			
		||||
        """Run on multiple inputs"""
 | 
			
		||||
        return self.sess.run(self.out_var, {self.inp_var: inputs})
 | 
			
		||||
 | 
			
		||||
    def run(self, inp: np.ndarray) -> float:
 | 
			
		||||
        return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]
 | 
			
		||||
        return self.predict(inp[np.newaxis])[0][0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KerasRunner(Runner):
 | 
			
		||||
| 
						 | 
				
			
			@ -52,9 +60,12 @@ class KerasRunner(Runner):
 | 
			
		|||
        self.model = load_precise_model(model_name)
 | 
			
		||||
        self.graph = tf.get_default_graph()
 | 
			
		||||
 | 
			
		||||
    def run(self, inp: np.ndarray) -> float:
 | 
			
		||||
    def predict(self, inputs: np.ndarray):
 | 
			
		||||
        with self.graph.as_default():
 | 
			
		||||
            return self.model.predict(np.array([inp]))[0][0]
 | 
			
		||||
            return self.model.predict(inputs)
 | 
			
		||||
 | 
			
		||||
    def run(self, inp: np.ndarray) -> float:
 | 
			
		||||
        return self.predict(inp[np.newaxis])[0][0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Listener:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,72 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
# Copyright (c) 2017 Mycroft AI Inc.
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
sys.path += ['.']  # noqa
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from precise.network_runner import Listener
 | 
			
		||||
from precise.scripts.test import show_stats
 | 
			
		||||
from prettyparse import create_parser
 | 
			
		||||
from precise.params import inject_params
 | 
			
		||||
from precise.train_data import TrainData
 | 
			
		||||
 | 
			
		||||
usage = '''
 | 
			
		||||
    Evaluate a list of models on a dataset
 | 
			
		||||
    
 | 
			
		||||
    :-t --use-train
 | 
			
		||||
        Evaluate training data instead of test data
 | 
			
		||||
    
 | 
			
		||||
    :-o --output str stats.json
 | 
			
		||||
        Output json file
 | 
			
		||||
    
 | 
			
		||||
    ...
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = create_parser(usage)
 | 
			
		||||
    parser.add_argument('models', nargs='*', help='List of model filenames')
 | 
			
		||||
    args = TrainData.parse_args(parser)
 | 
			
		||||
 | 
			
		||||
    data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
 | 
			
		||||
    filenames = sum(data.train_files if args.use_train else data.test_files, [])
 | 
			
		||||
    print('Data:', data)
 | 
			
		||||
 | 
			
		||||
    stats = {}
 | 
			
		||||
 | 
			
		||||
    for model_name in args.models:
 | 
			
		||||
        inject_params(model_name)
 | 
			
		||||
 | 
			
		||||
        train, test = data.load()
 | 
			
		||||
        inputs, targets = train if args.use_train else test
 | 
			
		||||
        predictions = Listener.find_runner(model_name)(model_name).predict(inputs)
 | 
			
		||||
 | 
			
		||||
        true_pos, true_neg = [], []
 | 
			
		||||
        false_pos, false_neg = [], []
 | 
			
		||||
 | 
			
		||||
        for name, target, prediction in zip(filenames, targets, predictions):
 | 
			
		||||
            {
 | 
			
		||||
                (True, False): false_pos,
 | 
			
		||||
                (True, True): true_pos,
 | 
			
		||||
                (False, True): false_neg,
 | 
			
		||||
                (False, False): true_neg
 | 
			
		||||
            }[prediction[0] > 0.5, target[0] > 0.5].append(name)
 | 
			
		||||
 | 
			
		||||
        print('----', model_name, '----')
 | 
			
		||||
        show_stats(false_pos, false_neg, true_pos, true_neg, False)
 | 
			
		||||
        stats[model_name] = {
 | 
			
		||||
            'true_pos': len(true_pos),
 | 
			
		||||
            'true_neg': len(true_neg),
 | 
			
		||||
            'false_pos': len(false_pos),
 | 
			
		||||
            'false_neg': len(false_neg),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    print('Writing to:', args.output)
 | 
			
		||||
    with open(args.output, 'w') as f:
 | 
			
		||||
        json.dump(stats, f)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										12
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										12
									
								
								setup.py
								
								
								
								
							| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
 | 
			
		||||
from setuptools import setup, find_packages
 | 
			
		||||
 | 
			
		||||
from precise import __version__
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
| 
						 | 
				
			
			@ -9,11 +10,14 @@ setup(
 | 
			
		|||
    packages=find_packages(),
 | 
			
		||||
    entry_points={
 | 
			
		||||
        'console_scripts': [
 | 
			
		||||
            'precise-train=precise.scripts.train:main',
 | 
			
		||||
            'precise-train-feedback=precise.scripts.train_feedback:main',
 | 
			
		||||
            'precise-convert=precise.scripts.convert:main',
 | 
			
		||||
            'precise-eval=precise.scripts.eval:main',
 | 
			
		||||
            'precise-record=precise.scripts.record:main'
 | 
			
		||||
            'precise-stream=precise.scripts.stream:main',
 | 
			
		||||
            'precise-test=precise.scripts.test:main',
 | 
			
		||||
            'precise-convert=precise.scripts.convert:main'
 | 
			
		||||
            'precise-test-pocketsphinx=precise.scripts.test_pocketsphinx:main'
 | 
			
		||||
            'precise-train=precise.scripts.train:main',
 | 
			
		||||
            'precise-train-incremental=precise.scripts.train_incremental:main',
 | 
			
		||||
        ]
 | 
			
		||||
    },
 | 
			
		||||
    install_requires=[
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +39,7 @@ setup(
 | 
			
		|||
    description='Mycroft Precise Wake Word Listener',
 | 
			
		||||
    keywords='wakeword keyword wake word listener sound',
 | 
			
		||||
    url='http://github.com/MycroftAI/mycroft-precise',
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    zip_safe=True,
 | 
			
		||||
    classifiers=[
 | 
			
		||||
        'Development Status :: 3 - Alpha',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue