Add option to print stats to file

pull/81/head
Matthew D. Scholefield 2019-04-03 04:59:34 -05:00
parent b4c28e1771
commit b3f2ff8a2d
2 changed files with 69 additions and 23 deletions

View File

@ -43,6 +43,12 @@ usage = '''
:-l --labels
Print labels attached to each point
:-o --output-file str -
File to write data instead of displaying it
:-i --input-file str -
File to read data from and visualize
...
'''
@ -74,26 +80,20 @@ class CachedDataLoader:
return self.data
def main():
def load_plt():
try:
import matplotlib.pyplot as plt
return plt
except ImportError:
print('Please install matplotlib first')
raise SystemExit(2)
parser = create_parser(usage)
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
args = TrainData.parse_args(parser)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
filenames = sum(data.train_files if args.use_train else data.test_files, [])
loader = CachedDataLoader(partial(
data.load, args.use_train, not args.use_train, shuffle=False
))
for model in args.models:
def calc_stats(model_files, loader, use_train, filenames):
model_data = {}
for model in model_files:
train, test = loader.load_for(model)
inputs, targets = train if args.use_train else test
inputs, targets = train if use_train else test
print('Running network...')
predictions = Listener.find_runner(model)(model).predict(inputs)
print(inputs.shape, targets.shape)
@ -102,11 +102,45 @@ def main():
stats = Stats(predictions, targets, filenames)
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
model_name = basename(splitext(model)[0])
model_data[model_name] = stats
return model_data
def main():
parser = create_parser(usage)
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
args = TrainData.parse_args(parser)
if bool(args.models) == bool(args.input_file):
parser.error('Please specify either a list of models or an input file')
if not args.output_file:
load_plt() # Error early if matplotlib not installed
import numpy as np
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
filenames = sum(data.train_files if args.use_train else data.test_files, [])
loader = CachedDataLoader(partial(
data.load, args.use_train, not args.use_train, shuffle=False
))
if args.models:
model_data = calc_stats(args.models, loader, args.use_train, filenames)
else:
model_data = {
name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
}
if args.output_file:
np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
else:
plt = load_plt()
thresholds = get_thresholds(args.resolution, args.power)
for model_name, stats in model_data.items():
x = [stats.false_positives(i) for i in thresholds]
y = [stats.false_negatives(i) for i in thresholds]
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
plt.plot(x, y, marker='x', linestyle='-', label=model_name)
if args.labels:
for x, y, threshold in zip(x, y, thresholds):

View File

@ -55,6 +55,18 @@ class Stats:
def __len__(self):
return len(self.outputs)
def to_np_dict(self):
import numpy as np
return {
'outputs': self.outputs,
'targets': self.targets,
'filenames': np.array(self.filenames)
}
@staticmethod
def from_np_dict(data) -> 'Stats':
return Stats(data['outputs'], data['targets'], data['filenames'])
def to_dict(self, threshold=0.5):
return {
'true_pos': self.calc_metric(True, True, threshold),