Drastically improve precise-graph speed

It now uses numpy arrays instead of Python loops
pull/81/head
Matthew D. Scholefield 2019-04-03 04:42:42 -05:00
parent d609d783ad
commit b4c28e1771
2 changed files with 17 additions and 16 deletions

View File

@ -82,7 +82,7 @@ def main():
raise SystemExit(2) raise SystemExit(2)
parser = create_parser(usage) parser = create_parser(usage)
parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test') parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
args = TrainData.parse_args(parser) args = TrainData.parse_args(parser)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder) data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
@ -94,6 +94,7 @@ def main():
for model in args.models: for model in args.models:
train, test = loader.load_for(model) train, test = loader.load_for(model)
inputs, targets = train if args.use_train else test inputs, targets = train if args.use_train else test
print('Running network...')
predictions = Listener.find_runner(model)(model).predict(inputs) predictions = Listener.find_runner(model)(model).predict(inputs)
print(inputs.shape, targets.shape) print(inputs.shape, targets.shape)
@ -102,9 +103,7 @@ def main():
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n') print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
thresholds = get_thresholds(args.resolution, args.power) thresholds = get_thresholds(args.resolution, args.power)
print('Generating x values...')
x = [stats.false_positives(i) for i in thresholds] x = [stats.false_positives(i) for i in thresholds]
print('Generating y values...')
y = [stats.false_negatives(i) for i in thresholds] y = [stats.false_negatives(i) for i in thresholds]
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0])) plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))

View File

@ -24,7 +24,7 @@ True Positives: {true_pos}
summary_str = ''' summary_str = '''
=== Summary === === Summary ===
{num_correct} out of {total} {num_correct} out of {total}
{accuracy_ratio:.2f} {accuracy_ratio:.2%}
{false_pos_ratio:.2%} false positives {false_pos_ratio:.2%} false positives
{false_neg_ratio:.2%} false negatives {false_neg_ratio:.2%} false negatives
@ -33,20 +33,22 @@ summary_str = '''
class Stats: class Stats:
"""Represents a set of statistics from a model run on a dataset""" """Represents a set of statistics from a model run on a dataset"""
def __init__(self, outputs, targets, filenames): def __init__(self, outputs, targets, filenames):
self.outputs = outputs self.outputs = outputs
self.targets = targets self.targets = targets
self.filenames = filenames self.filenames = filenames
self.num_positives = sum(int(i > 0.5) for i in self.targets) self.num_positives = int((self.targets > 0.5).sum())
self.num_negatives = sum(int(i < 0.5) for i in self.targets) self.num_negatives = int((self.targets < 0.5).sum())
# Methods # Methods
self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / max(1, self.num_negatives) self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / max(1,
self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / max(1, self.num_positives) self.num_negatives)
self.num_correct = lambda threshold=0.5: sum( self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / max(1,
int(output >= threshold) == int(target) self.num_positives)
for output, target in zip(self.outputs, self.targets) self.num_correct = lambda threshold=0.5: (
) (self.outputs >= threshold) == self.targets.astype(bool)
).sum()
self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold) self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold)
self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / max(1, len(self)) self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / max(1, len(self))
@ -77,14 +79,14 @@ class Stats:
return [ return [
filename filename
for output, target, filename in zip(self.outputs, self.targets, self.filenames) for output, target, filename in zip(self.outputs, self.targets, self.filenames)
if self.matches_sample(output, target, threshold, is_correct, actual_output) if ((output > threshold) == bool(target)) == is_correct and actual_output == bool(output > threshold)
] ]
def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int: def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
"""For example, calc_metric(False, True) calculates false positives""" """For example, calc_metric(False, True) calculates false positives"""
return sum( return int(
self.matches_sample(output, target, threshold, is_correct, actual_output) ((((self.outputs > threshold) == self.targets.astype(bool)) == is_correct) &
for output, target, filename in zip(self.outputs, self.targets, self.filenames) ((self.outputs > threshold) == actual_output)).sum()
) )
@staticmethod @staticmethod