diff --git a/precise/scripts/graph.py b/precise/scripts/graph.py new file mode 100755 index 0000000..aeea4b8 --- /dev/null +++ b/precise/scripts/graph.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2019 Mycroft AI Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +from functools import partial +from os.path import basename, splitext + +from prettyparse import create_parser +from typing import Callable, Tuple + +from precise.network_runner import Listener +from precise.params import inject_params, pr +from precise.stats import Stats +from precise.train_data import TrainData +from precise.vectorization import get_cache_folder + +usage = ''' + Show ROC curves for a series of models + + ... + + :-t --use-train + Evaluate training data instead of test data + + :-nf --no-filenames + Don't print out the names of files that failed + + ... 
# ---- precise/scripts/graph.py (definitions) ----


def test_thresholds(func, delta=0.01, power=3) -> list:
    """Run a function with a series of thresholds between 0 and 1

    The thresholds are ``(k * delta) ** power`` for ``k = 1 .. int(1 / delta) - 1``,
    so with the default arguments ``func`` is called 99 times.

    Args:
        func: Callable that takes a single threshold
        delta: Step between the base values before they are raised to ``power``
        power: Exponent applied to every base value

    Returns:
        The results of ``func``, in the order the thresholds were generated
    """
    return [func((th * delta) ** power) for th in range(1, int(1.0 / delta))]


# Despite its name this is a helper, not a test case. Without this flag pytest
# collects it (and fails on the unknown ``func`` fixture) whenever the name is
# imported into a test module.
test_thresholds.__test__ = False


class CachedDataLoader:
    """
    Class for reloading train data every time the params change

    A change is detected by comparing the value of get_cache_folder(), read
    after injecting a model's params, with the value seen on the previous load.

    Args:
        loader: Function that loads the train data (something that calls TrainData.load)
    """

    def __init__(self, loader: Callable):
        self.prev_cache = None  # Cache folder the current data was loaded for
        self.data = None  # Last value returned by loader()
        self.loader = loader

    def load_for(self, model: str) -> Tuple[list, list]:
        """
        Inject the parameters of a model and return the data loaded for them

        The loader is only called again when the cache folder differs from the
        one recorded on the previous call; otherwise the stored data is reused.

        Args:
            model: Model whose parameters are passed to inject_params

        Returns:
            Whatever the loader returned for the current parameters
        """
        inject_params(model)
        cache_folder = get_cache_folder()
        if cache_folder != self.prev_cache:
            self.prev_cache = cache_folder
            self.data = self.loader()
        return self.data


def main():
    """
    Plot false negatives against false positives for each model on the command line

    Every model is run on the selected data, its statistics are printed and one
    curve per model is added to a single matplotlib figure that is shown at the end.

    Raises:
        SystemExit: With code 2 if matplotlib cannot be imported
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('Please install matplotlib first')
        raise SystemExit(2)

    parser = create_parser(usage)
    parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
    args = TrainData.parse_args(parser)

    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
    # Concatenate the per-group filename lists into one flat list
    filenames = sum(data.train_files if args.use_train else data.test_files, [])
    # Presumably the two positional flags select (train, test) - confirm against TrainData.load
    loader = CachedDataLoader(partial(
        data.load, args.use_train, not args.use_train, shuffle=False
    ))

    for model in args.models:
        train, test = loader.load_for(model)
        inputs, targets = train if args.use_train else test
        predictions = Listener.find_runner(model)(model).predict(inputs)

        print('Generating statistics...')
        stats = Stats(predictions, targets, filenames)
        print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
        print('Generating x values...')
        x = test_thresholds(stats.false_positives)
        print('Generating y values...')
        y = test_thresholds(stats.false_negatives)
        # One curve per model, labelled with the model's file name minus its extension
        plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))

    print('Data:', data)
    plt.legend()
    plt.xlabel('False Positives')
    plt.ylabel('False Negatives')
    plt.show()


if __name__ == '__main__':
    main()


# ---- precise/stats.py ----
# Copyright 2018 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

# Template for Stats.counts_str(); filled with the keys of Stats.to_dict()
counts_str = '''
=== Counts ===
False Positives: {false_pos}
True Negatives: {true_neg}
False Negatives: {false_neg}
True Positives: {true_pos}
'''.strip()

# Template for Stats.summary_str()
summary_str = '''
=== Summary ===
{num_correct} out of {total}
{accuracy_ratio:.2f}

{false_pos_ratio:.2%} false positives
{false_neg_ratio:.2%} false negatives
'''


class Stats:
    """
    Represents a set of statistics from a model run on a dataset

    A sample counts as a positive prediction when its output is strictly
    greater than the threshold; every metric of this class uses that rule.

    Args:
        outputs: Network outputs, one per sample
        targets: Target outputs, parallel to ``outputs``
        filenames: Sample filenames, parallel to ``outputs`` (only used by calc_filenames)
    """

    def __init__(self, outputs, targets, filenames):
        self.outputs = outputs
        self.targets = targets
        self.filenames = filenames
        self.num_positives = sum(int(i > 0.5) for i in self.targets)
        self.num_negatives = sum(int(i < 0.5) for i in self.targets)

    def __len__(self) -> int:
        """Number of samples"""
        # This has to be a method of the class: len() looks __len__ up on the
        # type, so a callable stored on the instance is never consulted.
        return len(self.outputs)

    @staticmethod
    def _ratio(count, total) -> float:
        """Return count / total, or 0.0 when there is nothing to divide by"""
        return count / total if total else 0.0

    def false_positives(self, threshold=0.5) -> float:
        """Fraction of the negative samples that are classified as positive"""
        return self._ratio(self.calc_metric(False, True, threshold), self.num_negatives)

    def false_negatives(self, threshold=0.5) -> float:
        """Fraction of the positive samples that are classified as negative"""
        return self._ratio(self.calc_metric(False, False, threshold), self.num_positives)

    def num_correct(self, threshold=0.5) -> int:
        """Number of samples whose classification agrees with the target"""
        # Same comparison as matches_sample so that the summary always adds up
        # with the counts, including outputs that sit exactly on the threshold
        return sum(
            bool(output > threshold) == bool(target)
            for output, target in zip(self.outputs, self.targets)
        )

    def num_incorrect(self, threshold=0.5) -> int:
        """Number of samples whose classification disagrees with the target"""
        return len(self) - self.num_correct(threshold)

    def accuracy(self, threshold=0.5) -> float:
        """Fraction of all samples that are classified correctly"""
        return self._ratio(self.num_correct(threshold), len(self))

    def to_dict(self):
        """Return the four confusion counts at the default threshold, keyed for counts_str"""
        return {
            'true_pos': self.calc_metric(True, True),
            'true_neg': self.calc_metric(True, False),
            'false_pos': self.calc_metric(False, True),
            'false_neg': self.calc_metric(False, False),
        }

    def counts_str(self):
        """Return the formatted table of confusion counts"""
        return counts_str.format(**self.to_dict())

    def summary_str(self):
        """Return the formatted summary of accuracy and error rates"""
        return summary_str.format(
            num_correct=self.num_correct(), total=len(self),
            accuracy_ratio=self.accuracy(),
            false_pos_ratio=self.false_positives(),
            false_neg_ratio=self.false_negatives()
        )

    def calc_filenames(self, is_correct: bool, actual_output: bool, threshold=0.5) -> list:
        """Find a list of files with the given classification"""
        return [
            filename
            for output, target, filename in zip(self.outputs, self.targets, self.filenames)
            if self.matches_sample(output, target, threshold, is_correct, actual_output)
        ]

    def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
        """For example, calc_metric(False, True) calculates false positives"""
        # Filenames are not needed for counting, so they are left out of the
        # zip and cannot shorten the result when fewer names than samples exist
        return sum(
            self.matches_sample(output, target, threshold, is_correct, actual_output)
            for output, target in zip(self.outputs, self.targets)
        )

    @staticmethod
    def matches_sample(output, target, threshold, is_correct, actual_output):
        """
        Check if a sample with the given network output, target output, and threshold
        is the classification (is_correct, actual_output) like true positive or false negative
        """
        predicted = bool(output > threshold)
        return (predicted == bool(target)) == is_correct and actual_output == predicted


# ---- setup.py (hunk context; the list inside setup(...) continues on the next line) ----
#     'precise-engine=precise.scripts.engine:main',
#     'precise-simulate=precise.scripts.simulate:main',
'precise-test=precise.scripts.test:main', + 'precise-graph=precise.scripts.graph:main', 'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main', 'precise-train=precise.scripts.train:main', 'precise-train-optimize=precise.scripts.train_optimize:main',