45 lines
1.4 KiB
Python
Executable File
45 lines
1.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import sys
|
|
sys.path += ['.'] # noqa
|
|
|
|
from argparse import ArgumentParser
|
|
import json
|
|
|
|
from precise.common import *
|
|
|
|
|
|
def main():
|
|
parser = ArgumentParser()
|
|
parser.add_argument('-m', '--model', default='keyword.net')
|
|
parser.add_argument('-d', '--data-dir', default='data')
|
|
parser.add_argument('-e', '--epochs', type=int, default=10)
|
|
parser.add_argument('-l', '--load', dest='load', action='store_true')
|
|
parser.add_argument('-nl', '--no-load', dest='load', action='store_false')
|
|
parser.add_argument('-b', '--save-best', dest='save_best', action='store_true')
|
|
parser.add_argument('-nb', '--no-save-best', dest='save_best', action='store_false')
|
|
parser.set_defaults(load=True, save_best=True)
|
|
args = parser.parse_args()
|
|
|
|
inputs, outputs = load_data(args.data_dir)
|
|
validation_data = load_data(args.data_dir + '/test')
|
|
|
|
print('Inputs shape:', inputs.shape)
|
|
print('Outputs shape:', outputs.shape)
|
|
|
|
model = create_model(args.model, args.load)
|
|
|
|
with open(args.model + '.params', 'w') as f:
|
|
json.dump(pr._asdict(), f)
|
|
|
|
from keras.callbacks import ModelCheckpoint
|
|
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
|
|
|
|
try:
|
|
model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])
|
|
except KeyboardInterrupt:
|
|
print()
|
|
|
|
if __name__ == '__main__':
|
|
main()
|