Drastically improve precise-graph speed
It now uses numpy arrays instead of Python loopspull/81/head
parent
d609d783ad
commit
b4c28e1771
|
@ -82,7 +82,7 @@ def main():
|
|||
raise SystemExit(2)
|
||||
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='+', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
args = TrainData.parse_args(parser)
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
|
@ -94,6 +94,7 @@ def main():
|
|||
for model in args.models:
|
||||
train, test = loader.load_for(model)
|
||||
inputs, targets = train if args.use_train else test
|
||||
print('Running network...')
|
||||
predictions = Listener.find_runner(model)(model).predict(inputs)
|
||||
print(inputs.shape, targets.shape)
|
||||
|
||||
|
@ -102,9 +103,7 @@ def main():
|
|||
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
|
||||
|
||||
thresholds = get_thresholds(args.resolution, args.power)
|
||||
print('Generating x values...')
|
||||
x = [stats.false_positives(i) for i in thresholds]
|
||||
print('Generating y values...')
|
||||
y = [stats.false_negatives(i) for i in thresholds]
|
||||
|
||||
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
|
||||
|
|
|
@ -24,7 +24,7 @@ True Positives: {true_pos}
|
|||
summary_str = '''
|
||||
=== Summary ===
|
||||
{num_correct} out of {total}
|
||||
{accuracy_ratio:.2f}
|
||||
{accuracy_ratio:.2%}
|
||||
|
||||
{false_pos_ratio:.2%} false positives
|
||||
{false_neg_ratio:.2%} false negatives
|
||||
|
@ -33,20 +33,22 @@ summary_str = '''
|
|||
|
||||
class Stats:
|
||||
"""Represents a set of statistics from a model run on a dataset"""
|
||||
|
||||
def __init__(self, outputs, targets, filenames):
|
||||
self.outputs = outputs
|
||||
self.targets = targets
|
||||
self.filenames = filenames
|
||||
self.num_positives = sum(int(i > 0.5) for i in self.targets)
|
||||
self.num_negatives = sum(int(i < 0.5) for i in self.targets)
|
||||
self.num_positives = int((self.targets > 0.5).sum())
|
||||
self.num_negatives = int((self.targets < 0.5).sum())
|
||||
|
||||
# Methods
|
||||
self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / max(1, self.num_negatives)
|
||||
self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / max(1, self.num_positives)
|
||||
self.num_correct = lambda threshold=0.5: sum(
|
||||
int(output >= threshold) == int(target)
|
||||
for output, target in zip(self.outputs, self.targets)
|
||||
)
|
||||
self.false_positives = lambda threshold=0.5: self.calc_metric(False, True, threshold) / max(1,
|
||||
self.num_negatives)
|
||||
self.false_negatives = lambda threshold=0.5: self.calc_metric(False, False, threshold) / max(1,
|
||||
self.num_positives)
|
||||
self.num_correct = lambda threshold=0.5: (
|
||||
(self.outputs >= threshold) == self.targets.astype(bool)
|
||||
).sum()
|
||||
self.num_incorrect = lambda threshold=0.5: len(self) - self.num_correct(threshold)
|
||||
self.accuracy = lambda threshold=0.5: self.num_correct(threshold) / max(1, len(self))
|
||||
|
||||
|
@ -77,14 +79,14 @@ class Stats:
|
|||
return [
|
||||
filename
|
||||
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
|
||||
if self.matches_sample(output, target, threshold, is_correct, actual_output)
|
||||
if ((output > threshold) == bool(target)) == is_correct and actual_output == bool(output > threshold)
|
||||
]
|
||||
|
||||
def calc_metric(self, is_correct: bool, actual_output: bool, threshold=0.5) -> int:
|
||||
"""For example, calc_metric(False, True) calculates false positives"""
|
||||
return sum(
|
||||
self.matches_sample(output, target, threshold, is_correct, actual_output)
|
||||
for output, target, filename in zip(self.outputs, self.targets, self.filenames)
|
||||
return int(
|
||||
((((self.outputs > threshold) == self.targets.astype(bool)) == is_correct) &
|
||||
((self.outputs > threshold) == actual_output)).sum()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue