Add precise-graph script to show ROC curves

feature/cyclic
Matthew D. Scholefield 2019-03-19 23:22:23 -05:00
parent c1b6677f48
commit 02da99e21d
3 changed files with 203 additions and 0 deletions

108
precise/scripts/graph.py Executable file
View File

@ -0,0 +1,108 @@
#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from functools import partial
from os.path import basename, splitext
from prettyparse import create_parser
from typing import Callable, Tuple
from precise.network_runner import Listener
from precise.params import inject_params, pr
from precise.stats import Stats
from precise.train_data import TrainData
from precise.vectorization import get_cache_folder
usage = '''
Show ROC curves for a series of models
...
:-t --use-train
Evaluate training data instead of test data
:-nf --no-filenames
Don't print out the names of files that failed
...
'''
def test_thresholds(func, delta=0.01, power=3) -> list:
"""Run a function with a series of thresholds between 0 and 1"""
return [func((th * delta) ** power) for th in range(1, int(1.0 / delta))]
class CachedDataLoader:
"""
Class for reloading train data every time the params change
Args:
loader: Function that loads the train data (something that calls TrainData.load)
"""
def __init__(self, loader: Callable):
self.prev_cache = None
self.data = None
self.loader = loader
def load_for(self, model: str) -> Tuple[list, list]:
"""Injects the model parameters, reloading if they changed, and returning the data"""
inject_params(model)
if get_cache_folder() != self.prev_cache:
self.prev_cache = get_cache_folder()
self.data = self.loader()
return self.data
def main():
try:
import matplotlib.pyplot as plt
except ImportError:
print('Please install matplotlib first')
raise SystemExit(2)
parser = create_parser(usage)
parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
args = TrainData.parse_args(parser)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
filenames = sum(data.train_files if args.use_train else data.test_files, [])
loader = CachedDataLoader(partial(
data.load, args.use_train, not args.use_train, shuffle=False
))
for model in args.models:
train, test = loader.load_for(model)
inputs, targets = train if args.use_train else test
predictions = Listener.find_runner(model)(model).predict(inputs)
print('Generating statistics...')
stats = Stats(predictions, targets, filenames)
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
print('Generating x values...')
x = test_thresholds(stats.false_positives)
print('Generating y values...')
y = test_thresholds(stats.false_negatives)
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
print('Data:', data)
plt.legend()
plt.xlabel('False Positives')
plt.ylabel('False Negatives')
plt.show()
if __name__ == '__main__':
main()

94
precise/stats.py Normal file
View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2018 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
counts_str = '''
=== Counts ===
False Positives: {false_pos}
True Negatives: {true_neg}
False Negatives: {false_neg}
True Positives: {true_pos}
'''.strip()
summary_str = '''
=== Summary ===
{num_correct} out of {total}
{accuracy_ratio:.2f}
{false_pos_ratio:.2%} false positives
{false_neg_ratio:.2%} false negatives
'''
class Stats:
"""Represents a set of statistics from a model run on a dataset"""
def __init__(self, outputs, targets, filenames):
self.outputs = outputs
self.targets = targets
self.filenames = filenames
self.num_positives = sum(int(i > 0.5) for i in self.targets)
self.num_negatives = sum(int(i < 0.5) for i in self.targets)
# Methods
self.__len__ = lambda: len(self.outputs)
self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / self.num_negatives
self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / self.num_positives
self.num_correct = lambda threshold=0.5: sum(
int(output >= threshold) == int(target)
for output, target in zip(self.outputs, self.targets)
)
self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold)
self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / len(self)
def to_dict(self):
return {
'true_pos': self.calc_metric(True, True),
'true_neg': self.calc_metric(True, False),
'false_pos': self.calc_metric(False, True),
'false_neg': self.calc_metric(False, False),
}
def counts_str(self):
return counts_str.format(**self.to_dict())
def summary_str(self):
return summary_str.format(
num_correct=self.num_correct(), total=len(self),
accuracy_ratio=self.accuracy(),
false_pos_ratio=self.false_positives(),
false_neg_ratio=self.false_negatives()
)
def calc_filenames(self, is_correct: bool, actual_output: bool, threshold=0.5) -> list:
"""Find a list of files with the given classification"""
return [
filename
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
if self.matches_sample(output, target, threshold, is_correct, actual_output)
]
def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
"""For example, calc_metric(False, True) calculates false positives"""
return sum(
self.matches_sample(output, target, threshold, is_correct, actual_output)
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
)
@staticmethod
def matches_sample(output, target, threshold, is_correct, actual_output):
"""
Check if a sample with the given network output, target output, and threshold
is the classification (is_correct, actual_output) like true positive or false negative
"""
return (bool(output > threshold) == bool(target)) == is_correct and actual_output == bool(output > threshold)

View File

@ -59,6 +59,7 @@ setup(
'precise-engine=precise.scripts.engine:main',
'precise-simulate=precise.scripts.simulate:main',
'precise-test=precise.scripts.test:main',
'precise-graph=precise.scripts.graph:main',
'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
'precise-train=precise.scripts.train:main',
'precise-train-optimize=precise.scripts.train_optimize:main',