Standardize arguments

pull/1/head
Matthew D. Scholefield 2017-11-08 21:46:54 -06:00
parent f31717bca9
commit 1b01d16386
3 changed files with 18 additions and 19 deletions

View File

@ -7,7 +7,7 @@
import sys
sys.path += ['.']
import argparse
from argparse import ArgumentParser
import os
from os.path import split, isfile
@ -61,9 +61,9 @@ def convert(model_path, out_file):
del sess
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert keyword model from Keras to TensorFlow')
parser.add_argument('--model','-m', default='keyword.net', help='Input Keras model', type=argparse.FileType())
parser.add_argument('--out', '-o', default='keyword.pb', help='Output TensorFlow protobuf')
parser = ArgumentParser(description='Convert keyword model from Keras to TensorFlow')
parser.add_argument('-m', '--model', default='keyword.net', help='Input Keras model')
parser.add_argument('-o', '--out', default='keyword.pb', help='Output TensorFlow protobuf')
args = parser.parse_args()
convert(args.model.name, args.out)
convert(args.model, args.out)

View File

@ -11,18 +11,15 @@ from precise.common import *
def main():
parser = ArgumentParser()
parser.add_argument('-m', '--model', help='Keras model to load (default: keyword.net)',
default='keyword.net')
parser.add_argument('-d', '--data-dir',
help='Directory to load test data from (default: data/test)',
default='data/test')
parser.add_argument('-m', '--model', default='keyword.net')
parser.add_argument('-t', '--test-dir', default='data/test')
parser.set_defaults(load=True, save_best=True)
args = parser.parse_args()
from keras.models import load_model
filenames = sum(find_wavs(args.data_dir), [])
inputs, outputs = load_data(args.data_dir)
filenames = sum(find_wavs(args.test_dir), [])
inputs, outputs = load_data(args.test_dir)
predictions = load_model(args.model).predict(inputs)
num_correct = 0

View File

@ -3,14 +3,16 @@
import sys
sys.path += ['.'] # noqa
import argparse
from argparse import ArgumentParser
import json
from precise.common import *
def main():
parser = argparse.ArgumentParser()
parser = ArgumentParser()
parser.add_argument('-m', '--model', default='keyword.net')
parser.add_argument('-d', '--data-dir', default='data')
parser.add_argument('-e', '--epochs', type=int, default=10)
parser.add_argument('-l', '--load', dest='load', action='store_true')
parser.add_argument('-nl', '--no-load', dest='load', action='store_false')
@ -19,19 +21,19 @@ def main():
parser.set_defaults(load=True, save_best=True)
args = parser.parse_args()
inputs, outputs = load_data('data')
validation_data = load_data('data/test')
inputs, outputs = load_data(args.data_dir)
validation_data = load_data(args.data_dir + '/test')
print('Inputs shape:', inputs.shape)
print('Outputs shape:', outputs.shape)
model = create_model('keyword.net', args.load)
model = create_model(args.model, args.load)
with open('keyword.net.params', 'w') as f:
with open(args.model + '.params', 'w') as f:
json.dump(pr._asdict(), f)
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
try:
model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])