Add precise-graph script to show ROC curves
parent
c1b6677f48
commit
02da99e21d
|
@ -0,0 +1,108 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
from functools import partial
|
||||
from os.path import basename, splitext
|
||||
|
||||
from prettyparse import create_parser
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params, pr
|
||||
from precise.stats import Stats
|
||||
from precise.train_data import TrainData
|
||||
from precise.vectorization import get_cache_folder
|
||||
|
||||
usage = '''
|
||||
Show ROC curves for a series of models
|
||||
|
||||
...
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't print out the names of files that failed
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
||||
def test_thresholds(func, delta=0.01, power=3) -> list:
|
||||
"""Run a function with a series of thresholds between 0 and 1"""
|
||||
return [func((th * delta) ** power) for th in range(1, int(1.0 / delta))]
|
||||
|
||||
|
||||
class CachedDataLoader:
|
||||
"""
|
||||
Class for reloading train data every time the params change
|
||||
|
||||
Args:
|
||||
loader: Function that loads the train data (something that calls TrainData.load)
|
||||
"""
|
||||
|
||||
def __init__(self, loader: Callable):
|
||||
self.prev_cache = None
|
||||
self.data = None
|
||||
self.loader = loader
|
||||
|
||||
def load_for(self, model: str) -> Tuple[list, list]:
|
||||
"""Injects the model parameters, reloading if they changed, and returning the data"""
|
||||
inject_params(model)
|
||||
if get_cache_folder() != self.prev_cache:
|
||||
self.prev_cache = get_cache_folder()
|
||||
self.data = self.loader()
|
||||
return self.data
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
print('Please install matplotlib first')
|
||||
raise SystemExit(2)
|
||||
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
args = TrainData.parse_args(parser)
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
loader = CachedDataLoader(partial(
|
||||
data.load, args.use_train, not args.use_train, shuffle=False
|
||||
))
|
||||
|
||||
for model in args.models:
|
||||
train, test = loader.load_for(model)
|
||||
inputs, targets = train if args.use_train else test
|
||||
predictions = Listener.find_runner(model)(model).predict(inputs)
|
||||
|
||||
print('Generating statistics...')
|
||||
stats = Stats(predictions, targets, filenames)
|
||||
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
|
||||
print('Generating x values...')
|
||||
x = test_thresholds(stats.false_positives)
|
||||
print('Generating y values...')
|
||||
y = test_thresholds(stats.false_negatives)
|
||||
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
|
||||
|
||||
print('Data:', data)
|
||||
plt.legend()
|
||||
plt.xlabel('False Positives')
|
||||
plt.ylabel('False Negatives')
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,94 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2018 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
counts_str = '''
|
||||
=== Counts ===
|
||||
False Positives: {false_pos}
|
||||
True Negatives: {true_neg}
|
||||
False Negatives: {false_neg}
|
||||
True Positives: {true_pos}
|
||||
'''.strip()
|
||||
|
||||
summary_str = '''
|
||||
=== Summary ===
|
||||
{num_correct} out of {total}
|
||||
{accuracy_ratio:.2f}
|
||||
|
||||
{false_pos_ratio:.2%} false positives
|
||||
{false_neg_ratio:.2%} false negatives
|
||||
'''
|
||||
|
||||
|
||||
class Stats:
|
||||
"""Represents a set of statistics from a model run on a dataset"""
|
||||
def __init__(self, outputs, targets, filenames):
|
||||
self.outputs = outputs
|
||||
self.targets = targets
|
||||
self.filenames = filenames
|
||||
self.num_positives = sum(int(i > 0.5) for i in self.targets)
|
||||
self.num_negatives = sum(int(i < 0.5) for i in self.targets)
|
||||
|
||||
# Methods
|
||||
self.__len__ = lambda: len(self.outputs)
|
||||
self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / self.num_negatives
|
||||
self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / self.num_positives
|
||||
self.num_correct = lambda threshold=0.5: sum(
|
||||
int(output >= threshold) == int(target)
|
||||
for output, target in zip(self.outputs, self.targets)
|
||||
)
|
||||
self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold)
|
||||
self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / len(self)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'true_pos': self.calc_metric(True, True),
|
||||
'true_neg': self.calc_metric(True, False),
|
||||
'false_pos': self.calc_metric(False, True),
|
||||
'false_neg': self.calc_metric(False, False),
|
||||
}
|
||||
|
||||
def counts_str(self):
|
||||
return counts_str.format(**self.to_dict())
|
||||
|
||||
def summary_str(self):
|
||||
return summary_str.format(
|
||||
num_correct=self.num_correct(), total=len(self),
|
||||
accuracy_ratio=self.accuracy(),
|
||||
false_pos_ratio=self.false_positives(),
|
||||
false_neg_ratio=self.false_negatives()
|
||||
)
|
||||
|
||||
def calc_filenames(self, is_correct: bool, actual_output: bool, threshold=0.5) -> list:
|
||||
"""Find a list of files with the given classification"""
|
||||
return [
|
||||
filename
|
||||
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
|
||||
if self.matches_sample(output, target, threshold, is_correct, actual_output)
|
||||
]
|
||||
|
||||
def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
|
||||
"""For example, calc_metric(False, True) calculates false positives"""
|
||||
return sum(
|
||||
self.matches_sample(output, target, threshold, is_correct, actual_output)
|
||||
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def matches_sample(output, target, threshold, is_correct, actual_output):
|
||||
"""
|
||||
Check if a sample with the given network output, target output, and threshold
|
||||
is the classification (is_correct, actual_output) like true positive or false negative
|
||||
"""
|
||||
return (bool(output > threshold) == bool(target)) == is_correct and actual_output == bool(output > threshold)
|
1
setup.py
1
setup.py
|
@ -59,6 +59,7 @@ setup(
|
|||
'precise-engine=precise.scripts.engine:main',
|
||||
'precise-simulate=precise.scripts.simulate:main',
|
||||
'precise-test=precise.scripts.test:main',
|
||||
'precise-graph=precise.scripts.graph:main',
|
||||
'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
|
||||
'precise-train=precise.scripts.train:main',
|
||||
'precise-train-optimize=precise.scripts.train_optimize:main',
|
||||
|
|
Loading…
Reference in New Issue