#!/usr/bin/env python3
# Attribution: This script was adapted from https://github.com/amir-abdi/keras_to_tensorflow
# Copyright (c) 2017 Mycroft AI Inc.
import sys

# NOTE(review): presumably lets the local `precise` package be imported when
# the script is run from a source checkout — confirm.
sys.path += ['.']  # noqa

import os
from os.path import split, isfile
from shutil import copyfile

from prettyparse import create_parser

usage = '''
    Convert keyword model from Keras to TensorFlow

    :model str
        Input Keras model (.net)

    :-o --out str {model}.pb
        Custom output TensorFlow protobuf filename
'''


def convert(model_path: str, out_file: str):
    """
    Convert a Keras (HDF5) model into a frozen TensorFlow graph (.pb).

    Besides the binary graph at ``out_file``, this also writes a text
    version of the (unfrozen) graph next to it, named ``<filename>txt``
    (e.g. ``model.pb`` -> ``model.pbtxt``). If ``<model_path>.params``
    exists, it is copied to ``<out_file>.params``. The output directory is
    created if needed. The model output is exposed as the node
    ``net_output``.

    Args:
        model_path: location of Keras model
        out_file: location to write protobuf
    """
    print('Converting', model_path, 'to', out_file, '...')

    # Heavy dependencies are imported lazily so the module itself stays cheap
    # to import.
    import tensorflow as tf
    from precise.model import load_precise_model
    from keras import backend as K

    out_dir, filename = split(out_file)
    out_dir = out_dir or '.'
    os.makedirs(out_dir, exist_ok=True)

    # Learning phase 0 selects Keras's inference (test) behaviour for the
    # graph that is built when the model is loaded below.
    K.set_learning_phase(0)
    model = load_precise_model(model_path)

    # Give the model output a fixed, known name so it can be used as the
    # output node when freezing and fetched by name from the .pb later.
    out_name = 'net_output'
    tf.identity(model.output, name=out_name)
    print('Output node name:', out_name)
    print('Output folder:', out_dir)

    sess = K.get_session()

    # Write the graph in human-readable text form (e.g. model.pb -> model.pbtxt)
    tf.train.write_graph(sess.graph.as_graph_def(), out_dir, filename + 'txt', as_text=True)
    print('Saved readable graph to:', filename + 'txt')

    # Write the graph in binary .pb file, with variables frozen into constants
    # and only the nodes needed to compute `out_name` kept.
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io

    cgraph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [out_name])
    graph_io.write_graph(cgraph, out_dir, filename, as_text=False)

    # Carry the model's companion parameter file over to the converted model.
    if isfile(model_path + '.params'):
        copyfile(model_path + '.params', out_file + '.params')

    print('Saved graph to:', filename)

    del sess


def _default_model_name(model_path: str, ext: str = '.net') -> str:
    """Return ``model_path`` with a trailing ``ext`` removed, if present.

    Only the suffix is stripped. Paths that merely contain ``ext`` elsewhere
    (e.g. ``my.network/model.net``) keep those parts intact, unlike a blanket
    ``str.replace``.
    """
    # `ext and` guards against an empty ext, where [:-0] would yield ''.
    if ext and model_path.endswith(ext):
        return model_path[:-len(ext)]
    return model_path


def main():
    """Parse command-line arguments and convert the given model."""
    args = create_parser(usage).parse_args()

    # Fill the `{model}` placeholder of the output name (default
    # '{model}.pb') with the input path minus its '.net' extension.
    model_name = _default_model_name(args.model)
    convert(args.model, args.out.format(model=model_name))


if __name__ == '__main__':
    main()