2017-10-30 15:01:26 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
# Attribution: This script was adapted from https://github.com/amir-abdi/keras_to_tensorflow
|
|
|
|
# Copyright (c) 2017 Mycroft AI Inc.
|
|
|
|
import os
|
|
|
|
from os.path import split, isfile
|
|
|
|
from shutil import copyfile
|
2018-02-22 07:10:41 +00:00
|
|
|
|
2018-02-21 05:42:04 +00:00
|
|
|
from prettyparse import create_parser
|
2018-02-09 00:43:03 +00:00
|
|
|
|
2018-02-21 05:42:04 +00:00
|
|
|
# Command-line specification parsed by prettyparse.create_parser: a
# description line, then ':'-prefixed argument lines, each followed by an
# indented help line. The --out default contains a {model} placeholder that
# the __main__ block fills in.
usage = '''
    Convert keyword model from Keras to TensorFlow

    :model str
        Input Keras model (.net)

    :-o --out str {model}.pb
        Custom output TensorFlow protobuf filename
'''
|
2017-10-30 15:01:26 +00:00
|
|
|
|
|
|
|
|
2017-11-21 22:30:22 +00:00
|
|
|
def convert(model_path: str, out_file: str):
    """
    Freeze a Keras model (HDF5 .net file) into a TensorFlow protobuf graph

    The model is loaded through precise.model.load_precise_model with the
    Keras learning phase set to 0, and its output tensor is aliased to a
    graph node named 'net_output'. Files written:

    - <filename>txt next to out_file: the session graph as text
    - out_file: the graph with variables converted to constants (binary)
    - out_file + '.params': copy of model_path + '.params', only if it exists

    Args:
        model_path: location of Keras model
        out_file: location to write protobuf; its parent directory is
            created if missing
    """
    print('Converting', model_path, 'to', out_file, '...')

    # Imported here rather than at module level so that importing this
    # script does not require TensorFlow / Keras to be loaded
    import tensorflow as tf
    from precise.model import load_precise_model
    from keras import backend as K
    from tensorflow.python.framework import graph_io, graph_util

    parent_dir, filename = split(out_file)
    out_dir = parent_dir if parent_dir else '.'
    os.makedirs(out_dir, exist_ok=True)

    # Learning phase 0 selects Keras' test/inference behaviour for the graph
    # built by the model load below, so it must be set first
    K.set_learning_phase(0)
    model = load_precise_model(model_path)

    # Alias the model output to a fixed name so it can be listed as the
    # output node when freezing
    output_node = 'net_output'
    tf.identity(model.output, name=output_node)
    print('Output node name:', output_node)
    print('Output folder:', out_dir)

    session = K.get_session()

    # Text dump of the (unfrozen) session graph
    readable_name = filename + 'txt'
    tf.train.write_graph(session.graph.as_graph_def(), out_dir, readable_name, as_text=True)
    print('Saved readable graph to:', readable_name)

    # Binary .pb with variable values baked in as constants
    frozen_graph = graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(), [output_node]
    )
    graph_io.write_graph(frozen_graph, out_dir, filename, as_text=False)

    # Carry the model's side-car .params file along with the converted graph
    params_file = model_path + '.params'
    if isfile(params_file):
        copyfile(params_file, out_file + '.params')

    print('Saved graph to:', filename)
|
|
|
|
|
2018-02-09 00:43:03 +00:00
|
|
|
|
2017-10-30 15:01:26 +00:00
|
|
|
def model_base_name(model_path: str) -> str:
    """
    Strip a trailing '.net' extension from a model path

    Only the suffix is removed: '.net' appearing elsewhere in the path (for
    example a directory named 'my.network') is left intact, and paths that do
    not end in '.net' are returned unchanged.

    Args:
        model_path: path to the Keras model file

    Returns:
        model_path without its trailing '.net', if it had one
    """
    suffix = '.net'
    if model_path.endswith(suffix):
        return model_path[:-len(suffix)]
    return model_path


if __name__ == '__main__':
    args = create_parser(usage).parse_args()

    # The default --out value contains a {model} placeholder; fill it with the
    # model path minus its '.net' extension (e.g. foo.net -> foo.pb)
    convert(args.model, args.out.format(model=model_base_name(args.model)))
|