From 32a024c6412df1f45facd5a8afd8cd8dd2f1d596 Mon Sep 17 00:00:00 2001 From: SukramJ Date: Fri, 25 Oct 2019 01:41:13 +0200 Subject: [PATCH] Partially revert tensorflow import move (#28184) * Revert "Refactor imports for tensorflow (#27617)" This reverts commit 5a83a92390e8a3255885198c80622556f886b9b3. * move only some imports to top * fix lint * add comments --- .../components/tensorflow/image_processing.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/tensorflow/image_processing.py b/homeassistant/components/tensorflow/image_processing.py index 1f49888cb95..ea73d52fe4a 100644 --- a/homeassistant/components/tensorflow/image_processing.py +++ b/homeassistant/components/tensorflow/image_processing.py @@ -1,23 +1,12 @@ """Support for performing TensorFlow classification on images.""" +import io import logging import os import sys -import io -import voluptuous as vol + from PIL import Image, ImageDraw import numpy as np - -try: - import cv2 -except ImportError: - cv2 = None - -try: - # Verify that the TensorFlow Object Detection API is pre-installed - import tensorflow as tf # noqa - from object_detection.utils import label_map_util # noqa -except ImportError: - label_map_util = None +import voluptuous as vol from homeassistant.components.image_processing import ( CONF_CONFIDENCE, @@ -98,8 +87,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None): # append custom model path to sys.path sys.path.append(model_dir) - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" - if label_map_util is None: + try: + # Verify that the TensorFlow Object Detection API is pre-installed + # pylint: disable=unused-import,unused-variable + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # These imports shouldn't be moved to the top, because they depend on code from the model_dir. + # (The model_dir is created during the manual setup process. See integration docs.) + import tensorflow as tf # noqa + from object_detection.utils import label_map_util # noqa + except ImportError: + # pylint: disable=line-too-long _LOGGER.error( "No TensorFlow Object Detection library found! Install or compile " "for your system following instructions here: " @@ -107,7 +104,11 @@ def setup_platform(hass, config, add_entities, discovery_info=None): ) # noqa return - if cv2 is None: + try: + # Display warning that PIL will be used if no OpenCV is found. + # pylint: disable=unused-import,unused-variable + import cv2 # noqa + except ImportError: _LOGGER.warning( "No OpenCV library found. TensorFlow will process image with " "PIL at reduced resolution" @@ -282,7 +283,13 @@ class TensorFlowImageProcessor(ImageProcessingEntity): def process_image(self, image): """Process the image.""" - if cv2 is None: + try: + import cv2 # pylint: disable=import-error + + img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED) + inp = img[:, :, [2, 1, 0]] # BGR->RGB + inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3) + except ImportError: img = Image.open(io.BytesIO(bytearray(image))).convert("RGB") img.thumbnail((460, 460), Image.ANTIALIAS) img_width, img_height = img.size @@ -292,10 +299,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity): .astype(np.uint8) ) inp_expanded = np.expand_dims(inp, axis=0) - else: - img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED) - inp = img[:, :, [2, 1, 0]] # BGR->RGB - inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3) image_tensor = self._graph.get_tensor_by_name("image_tensor:0") boxes = self._graph.get_tensor_by_name("detection_boxes:0")