diff --git a/precise/scripts/graph.py b/precise/scripts/graph.py
index f430414..6ebb18a 100755
--- a/precise/scripts/graph.py
+++ b/precise/scripts/graph.py
@@ -43,6 +43,12 @@ usage = '''
     :-l --labels
         Print labels attached to each point
 
+    :-o --output-file str -
+        File to write data instead of displaying it
+
+    :-i --input-file str -
+        File to read data from and visualize
+
     ...
 '''
 
@@ -74,26 +80,20 @@ class CachedDataLoader:
         return self.data
 
 
-def main():
+def load_plt():
     try:
         import matplotlib.pyplot as plt
+        return plt
     except ImportError:
         print('Please install matplotlib first')
        raise SystemExit(2)
 
-    parser = create_parser(usage)
-    parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
-    args = TrainData.parse_args(parser)
-
-    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
-    filenames = sum(data.train_files if args.use_train else data.test_files, [])
-    loader = CachedDataLoader(partial(
-        data.load, args.use_train, not args.use_train, shuffle=False
-    ))
-
-    for model in args.models:
+
+def calc_stats(model_files, loader, use_train, filenames):
+    model_data = {}
+    for model in model_files:
         train, test = loader.load_for(model)
-        inputs, targets = train if args.use_train else test
+        inputs, targets = train if use_train else test
         print('Running network...')
         predictions = Listener.find_runner(model)(model).predict(inputs)
         print(inputs.shape, targets.shape)
@@ -102,21 +102,55 @@ def main():
         stats = Stats(predictions, targets, filenames)
 
         print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
+        model_name = basename(splitext(model)[0])
+        model_data[model_name] = stats
+    return model_data
+
+
+def main():
+    parser = create_parser(usage)
+    parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
+    args = TrainData.parse_args(parser)
+    if bool(args.models) == bool(args.input_file):
+        parser.error('Please specify either a list of models or an input file')
+
+    if not args.output_file:
+        load_plt()  # Error early if matplotlib not installed
+    import numpy as np
+
+    data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
+    filenames = sum(data.train_files if args.use_train else data.test_files, [])
+    loader = CachedDataLoader(partial(
+        data.load, args.use_train, not args.use_train, shuffle=False
+    ))
+
+    if args.models:
+        model_data = calc_stats(args.models, loader, args.use_train, filenames)
+    else:
+        model_data = {
+            name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
+        }
+
+    if args.output_file:
+        np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
+    else:
+        plt = load_plt()
         thresholds = get_thresholds(args.resolution, args.power)
-        x = [stats.false_positives(i) for i in thresholds]
-        y = [stats.false_negatives(i) for i in thresholds]
+        for model_name, stats in model_data.items():
+            x = [stats.false_positives(i) for i in thresholds]
+            y = [stats.false_negatives(i) for i in thresholds]
 
-        plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
+            plt.plot(x, y, marker='x', linestyle='-', label=model_name)
 
-        if args.labels:
-            for x, y, threshold in zip(x, y, thresholds):
-                plt.annotate('{:.4f}'.format(threshold), (x, y))
+            if args.labels:
+                for x, y, threshold in zip(x, y, thresholds):
+                    plt.annotate('{:.4f}'.format(threshold), (x, y))
 
-    print('Data:', data)
-    plt.legend()
-    plt.xlabel('False Positives')
-    plt.ylabel('False Negatives')
-    plt.show()
+        print('Data:', data)
+        plt.legend()
+        plt.xlabel('False Positives')
+        plt.ylabel('False Negatives')
+        plt.show()
 
 
 if __name__ == '__main__':
diff --git a/precise/stats.py b/precise/stats.py
index 383f3b0..fc956c1 100644
--- a/precise/stats.py
+++ b/precise/stats.py
@@ -55,6 +55,18 @@ class Stats:
     def __len__(self):
         return len(self.outputs)
 
+    def to_np_dict(self):
+        import numpy as np
+        return {
+            'outputs': self.outputs,
+            'targets': self.targets,
+            'filenames': np.array(self.filenames)
+        }
+
+    @staticmethod
+    def from_np_dict(data) -> 'Stats':
+        return Stats(data['outputs'], data['targets'], data['filenames'])
+
     def to_dict(self, threshold=0.5):
         return {
             'true_pos': self.calc_metric(True, True, threshold),
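
Not part of the patch: a minimal sketch of the save/load round-trip that the new Stats.to_np_dict()/from_np_dict() helpers and the -o/-i options rely on. The file name 'stats.npz', the model name 'my-model', and the outputs/targets/filenames values below are made up for illustration.

import numpy as np
from precise.stats import Stats

# Made-up predictions, labels, and file names, just to exercise the round-trip
stats = Stats(np.array([0.2, 0.9]), np.array([0.0, 1.0]), ['a.wav', 'b.wav'])

# Roughly what graph.py -o does: store one np-dict per model inside an .npz archive
np.savez('stats.npz', data={'my-model': stats.to_np_dict()})

# Roughly what graph.py -i does: unpack the archive and rebuild Stats objects
# (newer NumPy needs allow_pickle=True because the dict is stored as an object array)
loaded = np.load('stats.npz', allow_pickle=True)['data'].item()
model_data = {name: Stats.from_np_dict(d) for name, d in loaded.items()}
print(model_data['my-model'].summary_str())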