Add precise-graph script to show ROC curves
parent
c1b6677f48
commit
02da99e21d
|
@ -0,0 +1,108 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2019 Mycroft AI Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License
|
||||||
|
from functools import partial
|
||||||
|
from os.path import basename, splitext
|
||||||
|
|
||||||
|
from prettyparse import create_parser
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
from precise.network_runner import Listener
|
||||||
|
from precise.params import inject_params, pr
|
||||||
|
from precise.stats import Stats
|
||||||
|
from precise.train_data import TrainData
|
||||||
|
from precise.vectorization import get_cache_folder
|
||||||
|
|
||||||
|
usage = '''
|
||||||
|
Show ROC curves for a series of models
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
:-t --use-train
|
||||||
|
Evaluate training data instead of test data
|
||||||
|
|
||||||
|
:-nf --no-filenames
|
||||||
|
Don't print out the names of files that failed
|
||||||
|
|
||||||
|
...
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def test_thresholds(func, delta=0.01, power=3) -> list:
|
||||||
|
"""Run a function with a series of thresholds between 0 and 1"""
|
||||||
|
return [func((th * delta) ** power) for th in range(1, int(1.0 / delta))]
|
||||||
|
|
||||||
|
|
||||||
|
class CachedDataLoader:
    """
    Holds on to loaded train data and only loads it again when the
    cache folder reported by get_cache_folder() differs from the one
    seen at the previous load

    Args:
        loader: Zero-argument function that loads the train data (something that calls TrainData.load)
    """

    def __init__(self, loader: Callable):
        self.loader = loader
        self.data = None  # Result of the most recent loader() call
        self.prev_cache = None  # Cache folder that self.data was loaded under

    def load_for(self, model: str) -> Tuple[list, list]:
        """Inject the parameters of the given model and return data loaded under them"""
        inject_params(model)
        if get_cache_folder() == self.prev_cache:
            # Same cache folder as last time, so the held data is reused
            return self.data

        self.prev_cache = get_cache_folder()
        self.data = self.loader()
        return self.data
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """
    Command line entry point: run every given model over the dataset,
    print its statistics and plot its false negative ratio against its
    false positive ratio, one curve per model
    """
    try:
        # Imported here so a missing matplotlib gives a readable message
        import matplotlib.pyplot as plt
    except ImportError:
        print('Please install matplotlib first')
        raise SystemExit(2)

    parser = create_parser(usage)
    parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
    args = TrainData.parse_args(parser)

    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
    # sum(..., []) concatenates the per-group filename lists into one flat list
    filenames = sum(data.train_files if args.use_train else data.test_files, [])
    # presumably the two positional flags select which of the train/test
    # halves gets loaded - TODO confirm against TrainData.load
    loader = CachedDataLoader(partial(
        data.load, args.use_train, not args.use_train, shuffle=False
    ))

    for model in args.models:
        train, test = loader.load_for(model)
        inputs, targets = train if args.use_train else test
        predictions = Listener.find_runner(model)(model).predict(inputs)

        print('Generating statistics...')
        stats = Stats(predictions, targets, filenames)
        print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
        print('Generating x values...')
        x = test_thresholds(stats.false_positives)
        print('Generating y values...')
        y = test_thresholds(stats.false_negatives)
        # The curve is labelled with the model file name minus its extension
        plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))

    # NOTE(review): assumed to run once after the loop since `data` does not
    # change per model - confirm against the original indentation
    print('Data:', data)
    plt.legend()
    plt.xlabel('False Positives')
    plt.ylabel('False Negatives')
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command line tool only when this file is executed as a script
if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,94 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2018 Mycroft AI Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License
|
||||||
|
|
||||||
|
# Template rendered by Stats.counts_str(); its placeholders are the keys
# returned by Stats.to_dict()
counts_str = '''
=== Counts ===
False Positives: {false_pos}
True Negatives: {true_neg}
False Negatives: {false_neg}
True Positives: {true_pos}
'''.strip()

# Template rendered by Stats.summary_str(). It is not stripped, so the
# rendered text keeps the leading and trailing newline. The accuracy is
# formatted as a plain fraction (.2f) while the two error ratios are
# formatted as percentages (.2%)
summary_str = '''
=== Summary ===
{num_correct} out of {total}
{accuracy_ratio:.2f}

{false_pos_ratio:.2%} false positives
{false_neg_ratio:.2%} false negatives
'''
|
||||||
|
|
||||||
|
|
||||||
|
class Stats:
    """
    Represents a set of statistics from a model run on a dataset

    A sample counts as activated when its output is strictly greater than
    the threshold; every count and ratio in this class uses that same rule

    Args:
        outputs: Network outputs, one per sample
        targets: Expected outputs, one per sample. Values above 0.5 are
            counted as positive samples and values below 0.5 as negative
        filenames: Names of the samples, in the same order as outputs.
            Only calc_filenames reads them
    """

    def __init__(self, outputs, targets, filenames):
        self.outputs = outputs
        self.targets = targets
        self.filenames = filenames
        self.num_positives = sum(int(i > 0.5) for i in self.targets)
        self.num_negatives = sum(int(i < 0.5) for i in self.targets)

    # These are real methods rather than lambdas stored on the instance:
    # len() looks __len__ up on the type, so an instance attribute is ignored

    def __len__(self):
        """Number of samples"""
        return len(self.outputs)

    def false_positives(self, threshold=0.5) -> float:
        """Ratio of negative samples that activated, 0.0 if there are no negative samples"""
        return self.calc_metric(False, True, threshold) / max(1, self.num_negatives)

    def false_negatives(self, threshold=0.5) -> float:
        """Ratio of positive samples that did not activate, 0.0 if there are no positive samples"""
        return self.calc_metric(False, False, threshold) / max(1, self.num_positives)

    def num_correct(self, threshold=0.5) -> int:
        """Number of samples whose activation agrees with the target"""
        # Same comparison as matches_sample so this always equals
        # true positives + true negatives
        return sum(
            bool(output > threshold) == bool(target)
            for output, target in zip(self.outputs, self.targets)
        )

    def num_incorrect(self, threshold=0.5) -> int:
        """Number of samples whose activation disagrees with the target"""
        return len(self) - self.num_correct(threshold)

    def accuracy(self, threshold=0.5) -> float:
        """Ratio of correctly classified samples, 0.0 for an empty dataset"""
        return self.num_correct(threshold) / max(1, len(self))

    def to_dict(self):
        """Counts of the four classifications at the default threshold"""
        return {
            'true_pos': self.calc_metric(True, True),
            'true_neg': self.calc_metric(True, False),
            'false_pos': self.calc_metric(False, True),
            'false_neg': self.calc_metric(False, False),
        }

    def counts_str(self):
        """Render the module level counts_str template with the counts from to_dict"""
        return counts_str.format(**self.to_dict())

    def summary_str(self):
        """Render the module level summary_str template at the default threshold"""
        return summary_str.format(
            num_correct=self.num_correct(), total=len(self),
            accuracy_ratio=self.accuracy(),
            false_pos_ratio=self.false_positives(),
            false_neg_ratio=self.false_negatives()
        )

    def calc_filenames(self, is_correct: bool, actual_output: bool, threshold=0.5) -> list:
        """Find a list of files with the given classification"""
        return [
            filename
            for output, target, filename in zip(self.outputs, self.targets, self.filenames)
            if self.matches_sample(output, target, threshold, is_correct, actual_output)
        ]

    def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
        """For example, calc_metric(False, True) calculates false positives"""
        # Filenames are deliberately left out of the zip: a short or empty
        # filename list must not truncate the count
        return sum(
            self.matches_sample(output, target, threshold, is_correct, actual_output)
            for output, target in zip(self.outputs, self.targets)
        )

    @staticmethod
    def matches_sample(output, target, threshold, is_correct, actual_output):
        """
        Check if a sample with the given network output, target output, and threshold
        is the classification (is_correct, actual_output) like true positive or false negative
        """
        activated = bool(output > threshold)
        return (activated == bool(target)) == is_correct and actual_output == activated
|
1
setup.py
1
setup.py
|
@ -59,6 +59,7 @@ setup(
|
||||||
'precise-engine=precise.scripts.engine:main',
|
'precise-engine=precise.scripts.engine:main',
|
||||||
'precise-simulate=precise.scripts.simulate:main',
|
'precise-simulate=precise.scripts.simulate:main',
|
||||||
'precise-test=precise.scripts.test:main',
|
'precise-test=precise.scripts.test:main',
|
||||||
|
'precise-graph=precise.scripts.graph:main',
|
||||||
'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
|
'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
|
||||||
'precise-train=precise.scripts.train:main',
|
'precise-train=precise.scripts.train:main',
|
||||||
'precise-train-optimize=precise.scripts.train_optimize:main',
|
'precise-train-optimize=precise.scripts.train_optimize:main',
|
||||||
|
|
Loading…
Reference in New Issue