mycroft-precise/precise/scripts/graph.py

167 lines
5.5 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""
Show ROC curves for a series of models
...
:-t --use-train
Evaluate training data instead of test data
:-nf --no-filenames
Don't print out the names of files that failed
:-r --resolution int 100
Number of points to generate
:-p --power float 3.0
Power of point distribution
:-l --labels
Print labels attached to each point
:-o --output-file str -
File to write data instead of displaying it
:-i --input-file str -
File to read data from and visualize
...
"""
import numpy as np
from functools import partial
from os.path import basename, splitext
from prettyparse import Usage
from typing import Callable, Tuple
from precise.network_runner import Listener
from precise.params import inject_params, pr
from precise.scripts.base_script import BaseScript
from precise.stats import Stats
from precise.threshold_decoder import ThresholdDecoder
from precise.train_data import TrainData
def get_thresholds(points=100, power=3) -> list:
    """
    Generate a list of thresholds, biased towards 0 by ``power``

    The i-th threshold is ``(i / (points + 1)) ** power`` for ``i`` in
    ``1..points``; the endpoints 0 and 1 themselves are never included.
    For a positive ``power`` the values are increasing and lie strictly
    between 0 and 1, and a larger ``power`` packs more of them near 0.

    Args:
        points: Number of thresholds to generate
        power: Exponent applied to the evenly spaced fractions

    Returns:
        list: ``points`` floats (empty when ``points`` is 0)
    """
    denominator = points + 1
    return [(i / denominator) ** power for i in range(1, points + 1)]
class CachedDataLoader:
    """
    Class for reloading train data every time the params change

    Before each load the model's params are injected; the loader is only
    called again when ``pr.vectorization_md5_hash()`` differs from the hash
    the cached data was loaded with.

    Args:
        loader: Function that loads the train data (something that calls TrainData.load)
    """

    def __init__(self, loader: Callable):
        self.prev_cache = None  # vectorization hash that self.data was loaded with
        self.data = None  # value returned by the last successful loader() call
        self.loader = loader

    def load_for(self, model: str) -> Tuple[list, list]:
        """
        Injects the model parameters, reloading if they changed, and returning the data

        Args:
            model: Model path passed to ``inject_params``

        Returns:
            The (possibly cached) value returned by ``loader()``
        """
        inject_params(model)
        current_hash = pr.vectorization_md5_hash()
        if self.prev_cache != current_hash:
            data = self.loader()
            # Commit the cache key only after a successful load so a failed
            # load is retried instead of returning data for the old params.
            self.prev_cache = current_hash
            self.data = data
        return self.data
def load_plt():
    """
    Import matplotlib's pyplot on demand

    Returns:
        module: ``matplotlib.pyplot``

    Raises:
        SystemExit: With exit code 2, after printing a hint, when
            matplotlib cannot be imported
    """
    try:
        import matplotlib.pyplot as pyplot
    except ImportError:
        print('Please install matplotlib first')
        raise SystemExit(2)
    return pyplot
def calc_stats(model_files, loader, use_train, filenames):
    """
    Run every model on its dataset and compute statistics for each

    Args:
        model_files: Paths of the models to evaluate
        loader: Object with a ``load_for(model)`` method (e.g. CachedDataLoader)
            returning ``(train, test)``, each an ``(inputs, targets)`` pair
        use_train: Evaluate on the training data instead of the test data
        filenames: File names passed to ``Stats`` along with the predictions

    Returns:
        dict: Maps a display name to the model's ``Stats``. The name is the
        model's file name without extension; if that name is already taken
        by an earlier model, the full model path is used instead so that
        no model's stats are overwritten.
    """
    model_data = {}
    for model in model_files:
        train, test = loader.load_for(model)
        inputs, targets = train if use_train else test
        print('Running network...')
        predictions = Listener.find_runner(model)(model).predict(inputs)
        print(inputs.shape, targets.shape)

        print('Generating statistics...')
        stats = Stats(predictions, targets, filenames)
        print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')

        model_name = basename(splitext(model)[0])
        if model_name in model_data:
            # e.g. a/model.net and b/model.net share a stem; keep both entries
            model_name = model
        model_data[model_name] = stats
    return model_data
class GraphScript(BaseScript):
    # Plots false positives vs false negatives for a series of models, or
    # saves / reloads the underlying stats.
    # NOTE: intentionally no class docstring -- inside this class body
    # ``__doc__`` must resolve to the module docstring (the prettyparse
    # usage text); a class docstring would shadow it.
    usage = Usage(__doc__)
    usage.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
    usage |= TrainData.usage

    def __init__(self, args):
        super().__init__(args)
        # If neither models nor an input file were given, treat `folder` as the input file
        if not args.models and not args.input_file and args.folder:
            args.input_file = args.folder
        if bool(args.models) == bool(args.input_file):
            raise ValueError('Please specify either a list of models or an input file')

        if not args.output_file:
            load_plt()  # Error early if matplotlib not installed

    def run(self):
        """
        Compute or load per-model stats, then save them or plot them

        With ``models``, each model is evaluated on the data from the train
        data arguments; otherwise stats are read from ``input_file``. The
        stats are written to ``output_file`` when one is given, else plotted
        as false positives vs false negatives over a range of thresholds.
        """
        args = self.args
        if args.models:
            data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
            print('Data:', data)
            filenames = sum(data.train_files if args.use_train else data.test_files, [])
            loader = CachedDataLoader(partial(
                data.load, args.use_train, not args.use_train, shuffle=False
            ))
            model_data = calc_stats(args.models, loader, args.use_train, filenames)
        else:
            # The 'data' entry written by np.savez below is a pickled dict
            # (object array); np.load refuses it unless allow_pickle=True.
            # Unpickling can execute code, so only load files you trust.
            with np.load(args.input_file, allow_pickle=True) as saved:
                saved_data = saved['data'].item()
            model_data = {
                name: Stats.from_np_dict(data) for name, data in saved_data.items()
            }
            # calc_stats already prints these for freshly evaluated models
            for name, stats in model_data.items():
                print('=== {} ===\n{}\n\n{}\n'.format(name, stats.counts_str(), stats.summary_str()))

        if args.output_file:
            np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
        else:
            plt = load_plt()
            # NOTE(review): when stats come from --input-file no model params were
            # injected here, so `pr` presumably holds defaults -- confirm intended.
            decoder = ThresholdDecoder(pr.threshold_config, pr.threshold_center)
            # Evenly spaced values with the 0.0 / 1.0 endpoints dropped, passed through decoder.encode
            thresholds = [decoder.encode(i) for i in np.linspace(0.0, 1.0, args.resolution)[1:-1]]
            for model_name, stats in model_data.items():
                fps = [stats.false_positives(i) for i in thresholds]
                fns = [stats.false_negatives(i) for i in thresholds]
                plt.plot(fps, fns, marker='x', linestyle='-', label=model_name)
                if args.labels:
                    for fp, fn, threshold in zip(fps, fns, thresholds):
                        plt.annotate('{:.4f}'.format(threshold), (fp, fn))

            plt.legend()
            plt.xlabel('False Positives')
            plt.ylabel('False Negatives')
            plt.show()
# Script entry point. GraphScript does not define run_main itself, so this is
# inherited from BaseScript; presumably it parses CLI args with `usage` and
# calls run() -- TODO confirm in precise.scripts.base_script.
main = GraphScript.run_main

if __name__ == '__main__':
    main()