Issue 88 and 120 - Fix
parent
ea7efb1f6c
commit
eeda42ef12
|
@ -68,14 +68,18 @@ class TensorFlowRunner(Runner):
|
|||
|
||||
class KerasRunner(Runner):
|
||||
def __init__(self, model_name: str):
|
||||
import os
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import tensorflow as tf
|
||||
# ISSUE 88 - Following 3 lines added to resolve issue 88 - JM 2020-02-04 per liny90626
|
||||
from tensorflow.python.keras.backend import set_session # ISSUE 88
|
||||
self.sess = tf.Session() # ISSUE 88
|
||||
set_session(self.sess) # ISSUE 88
|
||||
self.model = load_precise_model(model_name)
|
||||
self.graph = tf.get_default_graph()
|
||||
|
||||
def predict(self, inputs: np.ndarray):
|
||||
from tensorflow.python.keras.backend import set_session # ISSUE 88
|
||||
with self.graph.as_default():
|
||||
set_session(self.sess) # ISSUE 88
|
||||
return self.model.predict(inputs)
|
||||
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
|
|
Loading…
Reference in New Issue