Add option to print stats to file
parent
b4c28e1771
commit
b3f2ff8a2d
|
@ -43,6 +43,12 @@ usage = '''
|
|||
:-l --labels
|
||||
Print labels attached to each point
|
||||
|
||||
:-o --output-file str -
|
||||
File to write data instead of displaying it
|
||||
|
||||
:-i --input-file str -
|
||||
File to read data from and visualize
|
||||
|
||||
...
|
||||
'''
|
||||
|
||||
|
@ -74,26 +80,20 @@ class CachedDataLoader:
|
|||
return self.data
|
||||
|
||||
|
||||
def main():
|
||||
def load_plt():
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
return plt
|
||||
except ImportError:
|
||||
print('Please install matplotlib first')
|
||||
raise SystemExit(2)
|
||||
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
args = TrainData.parse_args(parser)
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
loader = CachedDataLoader(partial(
|
||||
data.load, args.use_train, not args.use_train, shuffle=False
|
||||
))
|
||||
|
||||
for model in args.models:
|
||||
def calc_stats(model_files, loader, use_train, filenames):
|
||||
model_data = {}
|
||||
for model in model_files:
|
||||
train, test = loader.load_for(model)
|
||||
inputs, targets = train if args.use_train else test
|
||||
inputs, targets = train if use_train else test
|
||||
print('Running network...')
|
||||
predictions = Listener.find_runner(model)(model).predict(inputs)
|
||||
print(inputs.shape, targets.shape)
|
||||
|
@ -102,11 +102,45 @@ def main():
|
|||
stats = Stats(predictions, targets, filenames)
|
||||
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
|
||||
|
||||
model_name = basename(splitext(model)[0])
|
||||
model_data[model_name] = stats
|
||||
return model_data
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser(usage)
|
||||
parser.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
|
||||
args = TrainData.parse_args(parser)
|
||||
if bool(args.models) == bool(args.input_file):
|
||||
parser.error('Please specify either a list of models or an input file')
|
||||
|
||||
if not args.output_file:
|
||||
load_plt() # Error early if matplotlib not installed
|
||||
import numpy as np
|
||||
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
filenames = sum(data.train_files if args.use_train else data.test_files, [])
|
||||
loader = CachedDataLoader(partial(
|
||||
data.load, args.use_train, not args.use_train, shuffle=False
|
||||
))
|
||||
|
||||
if args.models:
|
||||
model_data = calc_stats(args.models, loader, args.use_train, filenames)
|
||||
else:
|
||||
model_data = {
|
||||
name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
|
||||
}
|
||||
|
||||
if args.output_file:
|
||||
np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
|
||||
else:
|
||||
plt = load_plt()
|
||||
thresholds = get_thresholds(args.resolution, args.power)
|
||||
for model_name, stats in model_data.items():
|
||||
x = [stats.false_positives(i) for i in thresholds]
|
||||
y = [stats.false_negatives(i) for i in thresholds]
|
||||
|
||||
plt.plot(x, y, marker='x', linestyle='-', label=basename(splitext(model)[0]))
|
||||
plt.plot(x, y, marker='x', linestyle='-', label=model_name)
|
||||
|
||||
if args.labels:
|
||||
for x, y, threshold in zip(x, y, thresholds):
|
||||
|
|
|
@ -55,6 +55,18 @@ class Stats:
|
|||
def __len__(self):
|
||||
return len(self.outputs)
|
||||
|
||||
def to_np_dict(self):
|
||||
import numpy as np
|
||||
return {
|
||||
'outputs': self.outputs,
|
||||
'targets': self.targets,
|
||||
'filenames': np.array(self.filenames)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_np_dict(data) -> 'Stats':
|
||||
return Stats(data['outputs'], data['targets'], data['filenames'])
|
||||
|
||||
def to_dict(self, threshold=0.5):
|
||||
return {
|
||||
'true_pos': self.calc_metric(True, True, threshold),
|
||||
|
|
Loading…
Reference in New Issue