Partially revert tensorflow import move (#28184)
* Revert "Refactor imports for tensorflow (#27617)"
This reverts commit 5a83a92390
.
* move only some imports to top
* fix lint
* add comments
pull/28192/head
parent
67cf7c26da
commit
32a024c641
|
@ -1,23 +1,12 @@
|
|||
"""Support for performing TensorFlow classification on images."""
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
import voluptuous as vol
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import cv2
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
|
||||
try:
|
||||
# Verify that the TensorFlow Object Detection API is pre-installed
|
||||
import tensorflow as tf # noqa
|
||||
from object_detection.utils import label_map_util # noqa
|
||||
except ImportError:
|
||||
label_map_util = None
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.image_processing import (
|
||||
CONF_CONFIDENCE,
|
||||
|
@ -98,8 +87,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
# append custom model path to sys.path
|
||||
sys.path.append(model_dir)
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
if label_map_util is None:
|
||||
try:
|
||||
# Verify that the TensorFlow Object Detection API is pre-installed
|
||||
# pylint: disable=unused-import,unused-variable
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
# These imports shouldn't be moved to the top, because they depend on code from the model_dir.
|
||||
# (The model_dir is created during the manual setup process. See integration docs.)
|
||||
import tensorflow as tf # noqa
|
||||
from object_detection.utils import label_map_util # noqa
|
||||
except ImportError:
|
||||
# pylint: disable=line-too-long
|
||||
_LOGGER.error(
|
||||
"No TensorFlow Object Detection library found! Install or compile "
|
||||
"for your system following instructions here: "
|
||||
|
@ -107,7 +104,11 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
) # noqa
|
||||
return
|
||||
|
||||
if cv2 is None:
|
||||
try:
|
||||
# Display warning that PIL will be used if no OpenCV is found.
|
||||
# pylint: disable=unused-import,unused-variable
|
||||
import cv2 # noqa
|
||||
except ImportError:
|
||||
_LOGGER.warning(
|
||||
"No OpenCV library found. TensorFlow will process image with "
|
||||
"PIL at reduced resolution"
|
||||
|
@ -282,7 +283,13 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||
def process_image(self, image):
|
||||
"""Process the image."""
|
||||
|
||||
if cv2 is None:
|
||||
try:
|
||||
import cv2 # pylint: disable=import-error
|
||||
|
||||
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
||||
inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
||||
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
||||
except ImportError:
|
||||
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
||||
img.thumbnail((460, 460), Image.ANTIALIAS)
|
||||
img_width, img_height = img.size
|
||||
|
@ -292,10 +299,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||
.astype(np.uint8)
|
||||
)
|
||||
inp_expanded = np.expand_dims(inp, axis=0)
|
||||
else:
|
||||
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
||||
inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
||||
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
||||
|
||||
image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
|
||||
boxes = self._graph.get_tensor_by_name("detection_boxes:0")
|
||||
|
|
Loading…
Reference in New Issue