Partially revert tensorflow import move (#28184)

* Revert "Refactor imports for tensorflow (#27617)"

This reverts commit 5a83a92390.

* move only some imports to top

* fix lint

* add comments
pull/28192/head
SukramJ 2019-10-25 01:41:13 +02:00 committed by Paulus Schoutsen
parent 67cf7c26da
commit 32a024c641
1 changed files with 25 additions and 22 deletions

View File

@ -1,23 +1,12 @@
"""Support for performing TensorFlow classification on images.""" """Support for performing TensorFlow classification on images."""
import io
import logging import logging
import os import os
import sys import sys
import io
import voluptuous as vol
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import numpy as np import numpy as np
import voluptuous as vol
try:
import cv2
except ImportError:
cv2 = None
try:
# Verify that the TensorFlow Object Detection API is pre-installed
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
label_map_util = None
from homeassistant.components.image_processing import ( from homeassistant.components.image_processing import (
CONF_CONFIDENCE, CONF_CONFIDENCE,
@ -98,8 +87,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
# append custom model path to sys.path # append custom model path to sys.path
sys.path.append(model_dir) sys.path.append(model_dir)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" try:
if label_map_util is None: # Verify that the TensorFlow Object Detection API is pre-installed
# pylint: disable=unused-import,unused-variable
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# These imports shouldn't be moved to the top, because they depend on code from the model_dir.
# (The model_dir is created during the manual setup process. See integration docs.)
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
# pylint: disable=line-too-long
_LOGGER.error( _LOGGER.error(
"No TensorFlow Object Detection library found! Install or compile " "No TensorFlow Object Detection library found! Install or compile "
"for your system following instructions here: " "for your system following instructions here: "
@ -107,7 +104,11 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
) # noqa ) # noqa
return return
if cv2 is None: try:
# Display warning that PIL will be used if no OpenCV is found.
# pylint: disable=unused-import,unused-variable
import cv2 # noqa
except ImportError:
_LOGGER.warning( _LOGGER.warning(
"No OpenCV library found. TensorFlow will process image with " "No OpenCV library found. TensorFlow will process image with "
"PIL at reduced resolution" "PIL at reduced resolution"
@ -282,7 +283,13 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
def process_image(self, image): def process_image(self, image):
"""Process the image.""" """Process the image."""
if cv2 is None: try:
import cv2 # pylint: disable=import-error
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
except ImportError:
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB") img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
img.thumbnail((460, 460), Image.ANTIALIAS) img.thumbnail((460, 460), Image.ANTIALIAS)
img_width, img_height = img.size img_width, img_height = img.size
@ -292,10 +299,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
.astype(np.uint8) .astype(np.uint8)
) )
inp_expanded = np.expand_dims(inp, axis=0) inp_expanded = np.expand_dims(inp, axis=0)
else:
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
image_tensor = self._graph.get_tensor_by_name("image_tensor:0") image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
boxes = self._graph.get_tensor_by_name("detection_boxes:0") boxes = self._graph.get_tensor_by_name("detection_boxes:0")