Add precise-graph script to show ROC curves
							parent
							
								
									c1b6677f48
								
							
						
					
					
						commit
						02da99e21d
					
				| 
						 | 
				
			
			@ -0,0 +1,108 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
# Copyright 2019 Mycroft AI Inc.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License
 | 
			
		||||
from functools import partial
 | 
			
		||||
from os.path import basename, splitext
 | 
			
		||||
 | 
			
		||||
from prettyparse import create_parser
 | 
			
		||||
from typing import Callable, Tuple
 | 
			
		||||
 | 
			
		||||
from precise.network_runner import Listener
 | 
			
		||||
from precise.params import inject_params, pr
 | 
			
		||||
from precise.stats import Stats
 | 
			
		||||
from precise.train_data import TrainData
 | 
			
		||||
from precise.vectorization import get_cache_folder
 | 
			
		||||
 | 
			
		||||
usage = '''
 | 
			
		||||
    Show ROC curves for a series of models
 | 
			
		||||
    
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
    :-t --use-train
 | 
			
		||||
        Evaluate training data instead of test data
 | 
			
		||||
    
 | 
			
		||||
    :-nf --no-filenames
 | 
			
		||||
        Don't print out the names of files that failed
 | 
			
		||||
    
 | 
			
		||||
    ...
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_thresholds(func, delta=0.01, power=3) -> list:
 | 
			
		||||
    """Run a function with a series of thresholds between 0 and 1"""
 | 
			
		||||
    return [func((th * delta) ** power) for th in range(1, int(1.0 / delta))]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CachedDataLoader:
 | 
			
		||||
    """
 | 
			
		||||
    Class for reloading train data every time the params change
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        loader: Function that loads the train data (something that calls TrainData.load)
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, loader: Callable):
 | 
			
		||||
        self.prev_cache = None
 | 
			
		||||
        self.data = None
 | 
			
		||||
        self.loader = loader
 | 
			
		||||
 | 
			
		||||
    def load_for(self, model: str) -> Tuple[list, list]:
 | 
			
		||||
        """Injects the model parameters, reloading if they changed, and returning the data"""
 | 
			
		||||
        inject_params(model)
 | 
			
		||||
        if get_cache_folder() != self.prev_cache:
 | 
			
		||||
            self.prev_cache = get_cache_folder()
 | 
			
		||||
            self.data = self.loader()
 | 
			
		||||
        return self.data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    try:
 | 
			
		||||
        import matplotlib.pyplot as plt
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        print('Please install matplotlib first')
 | 
			
		||||
        raise SystemExit(2)
 | 
			
		||||
 | 
			
		||||
    parser = create_parser(usage)
 | 
			
		||||
    parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
 | 
			
		||||
    args = TrainData.parse_args(parser)
 | 
			
		||||
 | 
			
		||||
    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
 | 
			
		||||
    filenames = sum(data.train_files if args.use_train else data.test_files, [])
 | 
			
		||||
    loader = CachedDataLoader(partial(
 | 
			
		||||
        data.load, args.use_train, not args.use_train, shuffle=False
 | 
			
		||||
    ))
 | 
			
		||||
 | 
			
		||||
    for model in args.models:
 | 
			
		||||
        train, test = loader.load_for(model)
 | 
			
		||||
        inputs, targets = train if args.use_train else test
 | 
			
		||||
        predictions = Listener.find_runner(model)(model).predict(inputs)
 | 
			
		||||
 | 
			
		||||
        print('Generating statistics...')
 | 
			
		||||
        stats = Stats(predictions, targets, filenames)
 | 
			
		||||
        print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
 | 
			
		||||
        print('Generating x values...')
 | 
			
		||||
        x = test_thresholds(stats.false_positives)
 | 
			
		||||
        print('Generating y values...')
 | 
			
		||||
        y = test_thresholds(stats.false_negatives)
 | 
			
		||||
        plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
 | 
			
		||||
 | 
			
		||||
    print('Data:', data)
 | 
			
		||||
    plt.legend()
 | 
			
		||||
    plt.xlabel('False Positives')
 | 
			
		||||
    plt.ylabel('False Negatives')
 | 
			
		||||
    plt.show()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,94 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
# Copyright 2018 Mycroft AI Inc.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License
 | 
			
		||||
 | 
			
		||||
counts_str = '''
 | 
			
		||||
=== Counts ===
 | 
			
		||||
False Positives: {false_pos}
 | 
			
		||||
True Negatives: {true_neg}
 | 
			
		||||
False Negatives: {false_neg}
 | 
			
		||||
True Positives: {true_pos}
 | 
			
		||||
'''.strip()
 | 
			
		||||
 | 
			
		||||
summary_str = '''
 | 
			
		||||
=== Summary ===
 | 
			
		||||
{num_correct} out of {total}
 | 
			
		||||
{accuracy_ratio:.2f}
 | 
			
		||||
 | 
			
		||||
{false_pos_ratio:.2%} false positives
 | 
			
		||||
{false_neg_ratio:.2%} false negatives
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Stats:
 | 
			
		||||
    """Represents a set of statistics from a model run on a dataset"""
 | 
			
		||||
    def __init__(self, outputs, targets, filenames):
 | 
			
		||||
        self.outputs = outputs
 | 
			
		||||
        self.targets = targets
 | 
			
		||||
        self.filenames = filenames
 | 
			
		||||
        self.num_positives = sum(int(i > 0.5) for i in self.targets)
 | 
			
		||||
        self.num_negatives = sum(int(i < 0.5) for i in self.targets)
 | 
			
		||||
 | 
			
		||||
        # Methods
 | 
			
		||||
        self.__len__ = lambda: len(self.outputs)
 | 
			
		||||
        self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / self.num_negatives
 | 
			
		||||
        self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / self.num_positives
 | 
			
		||||
        self.num_correct = lambda threshold=0.5: sum(
 | 
			
		||||
            int(output >= threshold) == int(target)
 | 
			
		||||
            for output, target in zip(self.outputs, self.targets)
 | 
			
		||||
        )
 | 
			
		||||
        self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold)
 | 
			
		||||
        self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / len(self)
 | 
			
		||||
 | 
			
		||||
    def to_dict(self):
 | 
			
		||||
        return {
 | 
			
		||||
            'true_pos': self.calc_metric(True, True),
 | 
			
		||||
            'true_neg': self.calc_metric(True, False),
 | 
			
		||||
            'false_pos': self.calc_metric(False, True),
 | 
			
		||||
            'false_neg': self.calc_metric(False, False),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def counts_str(self):
 | 
			
		||||
        return counts_str.format(**self.to_dict())
 | 
			
		||||
 | 
			
		||||
    def summary_str(self):
 | 
			
		||||
        return summary_str.format(
 | 
			
		||||
            num_correct=self.num_correct(), total=len(self),
 | 
			
		||||
            accuracy_ratio=self.accuracy(),
 | 
			
		||||
            false_pos_ratio=self.false_positives(),
 | 
			
		||||
            false_neg_ratio=self.false_negatives()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def calc_filenames(self, is_correct: bool, actual_output: bool, threshold=0.5) -> list:
 | 
			
		||||
        """Find a list of files with the given classification"""
 | 
			
		||||
        return [
 | 
			
		||||
            filename
 | 
			
		||||
            for output, target, filename in zip(self.outputs, self.targets, self.filenames)
 | 
			
		||||
            if self.matches_sample(output, target, threshold, is_correct, actual_output)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
 | 
			
		||||
        """For example, calc_metric(False, True) calculates false positives"""
 | 
			
		||||
        return sum(
 | 
			
		||||
            self.matches_sample(output, target, threshold, is_correct, actual_output)
 | 
			
		||||
            for output, target, filename in zip(self.outputs, self.targets, self.filenames)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def matches_sample(output, target, threshold, is_correct, actual_output):
 | 
			
		||||
        """
 | 
			
		||||
        Check if a sample with the given network output, target output, and threshold
 | 
			
		||||
        is the classification (is_correct, actual_output) like true positive or false negative
 | 
			
		||||
        """
 | 
			
		||||
        return (bool(output > threshold) == bool(target)) == is_correct and actual_output == bool(output > threshold)
 | 
			
		||||
							
								
								
									
										1
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										1
									
								
								setup.py
								
								
								
								
							| 
						 | 
				
			
			@ -59,6 +59,7 @@ setup(
 | 
			
		|||
            'precise-engine=precise.scripts.engine:main',
 | 
			
		||||
            'precise-simulate=precise.scripts.simulate:main',
 | 
			
		||||
            'precise-test=precise.scripts.test:main',
 | 
			
		||||
            'precise-graph=precise.scripts.graph:main',
 | 
			
		||||
            'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
 | 
			
		||||
            'precise-train=precise.scripts.train:main',
 | 
			
		||||
            'precise-train-optimize=precise.scripts.train_optimize:main',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue