Standardize arguments
parent
f31717bca9
commit
1b01d16386
|
@ -7,7 +7,7 @@
|
|||
import sys
|
||||
sys.path += ['.']
|
||||
|
||||
import argparse
|
||||
from argparse import ArgumentParser
|
||||
import os
|
||||
from os.path import split, isfile
|
||||
|
||||
|
@ -61,9 +61,9 @@ def convert(model_path, out_file):
|
|||
del sess
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Convert keyword model from Keras to TensorFlow')
|
||||
parser.add_argument('--model','-m', default='keyword.net', help='Input Keras model', type=argparse.FileType())
|
||||
parser.add_argument('--out', '-o', default='keyword.pb', help='Output TensorFlow protobuf')
|
||||
parser = ArgumentParser(description='Convert keyword model from Keras to TensorFlow')
|
||||
parser.add_argument('-m', '--model', default='keyword.net', help='Input Keras model')
|
||||
parser.add_argument('-o', '--out', default='keyword.pb', help='Output TensorFlow protobuf')
|
||||
args = parser.parse_args()
|
||||
|
||||
convert(args.model.name, args.out)
|
||||
convert(args.model, args.out)
|
||||
|
|
|
@ -11,18 +11,15 @@ from precise.common import *
|
|||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-m', '--model', help='Keras model to load (default: keyword.net)',
|
||||
default='keyword.net')
|
||||
parser.add_argument('-d', '--data-dir',
|
||||
help='Directory to load test data from (default: data/test)',
|
||||
default='data/test')
|
||||
parser.add_argument('-m', '--model', default='keyword.net')
|
||||
parser.add_argument('-t', '--test-dir', default='data/test')
|
||||
parser.set_defaults(load=True, save_best=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
from keras.models import load_model
|
||||
|
||||
filenames = sum(find_wavs(args.data_dir), [])
|
||||
inputs, outputs = load_data(args.data_dir)
|
||||
filenames = sum(find_wavs(args.test_dir), [])
|
||||
inputs, outputs = load_data(args.test_dir)
|
||||
predictions = load_model(args.model).predict(inputs)
|
||||
|
||||
num_correct = 0
|
||||
|
|
|
@ -3,14 +3,16 @@
|
|||
import sys
|
||||
sys.path += ['.'] # noqa
|
||||
|
||||
import argparse
|
||||
from argparse import ArgumentParser
|
||||
import json
|
||||
|
||||
from precise.common import *
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-m', '--model', default='keyword.net')
|
||||
parser.add_argument('-d', '--data-dir', default='data')
|
||||
parser.add_argument('-e', '--epochs', type=int, default=10)
|
||||
parser.add_argument('-l', '--load', dest='load', action='store_true')
|
||||
parser.add_argument('-nl', '--no-load', dest='load', action='store_false')
|
||||
|
@ -19,19 +21,19 @@ def main():
|
|||
parser.set_defaults(load=True, save_best=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs, outputs = load_data('data')
|
||||
validation_data = load_data('data/test')
|
||||
inputs, outputs = load_data(args.data_dir)
|
||||
validation_data = load_data(args.data_dir + '/test')
|
||||
|
||||
print('Inputs shape:', inputs.shape)
|
||||
print('Outputs shape:', outputs.shape)
|
||||
|
||||
model = create_model('keyword.net', args.load)
|
||||
model = create_model(args.model, args.load)
|
||||
|
||||
with open('keyword.net.params', 'w') as f:
|
||||
with open(args.model + '.params', 'w') as f:
|
||||
json.dump(pr._asdict(), f)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint('keyword.net', monitor='val_acc', save_best_only=args.save_best, mode='max')
|
||||
checkpoint = ModelCheckpoint(args.model, monitor='val_acc', save_best_only=args.save_best, mode='max')
|
||||
|
||||
try:
|
||||
model.fit(inputs, outputs, 5000, args.epochs, validation_data=validation_data, callbacks=[checkpoint])
|
||||
|
|
Loading…
Reference in New Issue