370 lines
13 KiB
Python
370 lines
13 KiB
Python
"""Support for performing TensorFlow classification on images."""
|
|
import io
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
from PIL import Image, ImageDraw, UnidentifiedImageError
|
|
import numpy as np
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components.image_processing import (
|
|
CONF_CONFIDENCE,
|
|
CONF_ENTITY_ID,
|
|
CONF_NAME,
|
|
CONF_SOURCE,
|
|
PLATFORM_SCHEMA,
|
|
ImageProcessingEntity,
|
|
)
|
|
from homeassistant.core import split_entity_id
|
|
from homeassistant.helpers import template
|
|
import homeassistant.helpers.config_validation as cv
|
|
from homeassistant.util.pil import draw_box
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
ATTR_MATCHES = "matches"
|
|
ATTR_SUMMARY = "summary"
|
|
ATTR_TOTAL_MATCHES = "total_matches"
|
|
|
|
CONF_AREA = "area"
|
|
CONF_BOTTOM = "bottom"
|
|
CONF_CATEGORIES = "categories"
|
|
CONF_CATEGORY = "category"
|
|
CONF_FILE_OUT = "file_out"
|
|
CONF_GRAPH = "graph"
|
|
CONF_LABELS = "labels"
|
|
CONF_LEFT = "left"
|
|
CONF_MODEL = "model"
|
|
CONF_MODEL_DIR = "model_dir"
|
|
CONF_RIGHT = "right"
|
|
CONF_TOP = "top"
|
|
|
|
AREA_SCHEMA = vol.Schema(
|
|
{
|
|
vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
|
|
vol.Optional(CONF_LEFT, default=0): cv.small_float,
|
|
vol.Optional(CONF_RIGHT, default=1): cv.small_float,
|
|
vol.Optional(CONF_TOP, default=0): cv.small_float,
|
|
}
|
|
)
|
|
|
|
CATEGORY_SCHEMA = vol.Schema(
|
|
{vol.Required(CONF_CATEGORY): cv.string, vol.Optional(CONF_AREA): AREA_SCHEMA}
|
|
)
|
|
|
|
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|
{
|
|
vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
|
|
vol.Required(CONF_MODEL): vol.Schema(
|
|
{
|
|
vol.Required(CONF_GRAPH): cv.isfile,
|
|
vol.Optional(CONF_AREA): AREA_SCHEMA,
|
|
vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
|
|
cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
|
|
),
|
|
vol.Optional(CONF_LABELS): cv.isfile,
|
|
vol.Optional(CONF_MODEL_DIR): cv.isdir,
|
|
}
|
|
),
|
|
}
|
|
)
|
|
|
|
|
|
def setup_platform(hass, config, add_entities, discovery_info=None):
|
|
"""Set up the TensorFlow image processing platform."""
|
|
model_config = config.get(CONF_MODEL)
|
|
model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow")
|
|
labels = model_config.get(CONF_LABELS) or hass.config.path(
|
|
"tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt"
|
|
)
|
|
|
|
# Make sure locations exist
|
|
if not os.path.isdir(model_dir) or not os.path.exists(labels):
|
|
_LOGGER.error("Unable to locate tensorflow models or label map")
|
|
return
|
|
|
|
# append custom model path to sys.path
|
|
sys.path.append(model_dir)
|
|
|
|
try:
|
|
# Verify that the TensorFlow Object Detection API is pre-installed
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
|
# These imports shouldn't be moved to the top, because they depend on code from the model_dir.
|
|
# (The model_dir is created during the manual setup process. See integration docs.)
|
|
import tensorflow as tf # pylint: disable=import-outside-toplevel
|
|
|
|
# pylint: disable=import-outside-toplevel
|
|
from object_detection.utils import label_map_util
|
|
except ImportError:
|
|
_LOGGER.error(
|
|
"No TensorFlow Object Detection library found! Install or compile "
|
|
"for your system following instructions here: "
|
|
"https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md"
|
|
)
|
|
return
|
|
|
|
try:
|
|
# Display warning that PIL will be used if no OpenCV is found.
|
|
import cv2 # noqa: F401 pylint: disable=unused-import, import-outside-toplevel
|
|
except ImportError:
|
|
_LOGGER.warning(
|
|
"No OpenCV library found. TensorFlow will process image with "
|
|
"PIL at reduced resolution"
|
|
)
|
|
|
|
# Set up Tensorflow graph, session, and label map to pass to processor
|
|
# pylint: disable=no-member
|
|
detection_graph = tf.Graph()
|
|
with detection_graph.as_default():
|
|
od_graph_def = tf.GraphDef()
|
|
with tf.gfile.GFile(model_config.get(CONF_GRAPH), "rb") as fid:
|
|
serialized_graph = fid.read()
|
|
od_graph_def.ParseFromString(serialized_graph)
|
|
tf.import_graph_def(od_graph_def, name="")
|
|
|
|
session = tf.Session(graph=detection_graph)
|
|
label_map = label_map_util.load_labelmap(labels)
|
|
categories = label_map_util.convert_label_map_to_categories(
|
|
label_map, max_num_classes=90, use_display_name=True
|
|
)
|
|
category_index = label_map_util.create_category_index(categories)
|
|
|
|
entities = []
|
|
|
|
for camera in config[CONF_SOURCE]:
|
|
entities.append(
|
|
TensorFlowImageProcessor(
|
|
hass,
|
|
camera[CONF_ENTITY_ID],
|
|
camera.get(CONF_NAME),
|
|
session,
|
|
detection_graph,
|
|
category_index,
|
|
config,
|
|
)
|
|
)
|
|
|
|
add_entities(entities)
|
|
|
|
|
|
class TensorFlowImageProcessor(ImageProcessingEntity):
|
|
"""Representation of an TensorFlow image processor."""
|
|
|
|
def __init__(
|
|
self,
|
|
hass,
|
|
camera_entity,
|
|
name,
|
|
session,
|
|
detection_graph,
|
|
category_index,
|
|
config,
|
|
):
|
|
"""Initialize the TensorFlow entity."""
|
|
model_config = config.get(CONF_MODEL)
|
|
self.hass = hass
|
|
self._camera_entity = camera_entity
|
|
if name:
|
|
self._name = name
|
|
else:
|
|
self._name = "TensorFlow {}".format(split_entity_id(camera_entity)[1])
|
|
self._session = session
|
|
self._graph = detection_graph
|
|
self._category_index = category_index
|
|
self._min_confidence = config.get(CONF_CONFIDENCE)
|
|
self._file_out = config.get(CONF_FILE_OUT)
|
|
|
|
# handle categories and specific detection areas
|
|
categories = model_config.get(CONF_CATEGORIES)
|
|
self._include_categories = []
|
|
self._category_areas = {}
|
|
for category in categories:
|
|
if isinstance(category, dict):
|
|
category_name = category.get(CONF_CATEGORY)
|
|
category_area = category.get(CONF_AREA)
|
|
self._include_categories.append(category_name)
|
|
self._category_areas[category_name] = [0, 0, 1, 1]
|
|
if category_area:
|
|
self._category_areas[category_name] = [
|
|
category_area.get(CONF_TOP),
|
|
category_area.get(CONF_LEFT),
|
|
category_area.get(CONF_BOTTOM),
|
|
category_area.get(CONF_RIGHT),
|
|
]
|
|
else:
|
|
self._include_categories.append(category)
|
|
self._category_areas[category] = [0, 0, 1, 1]
|
|
|
|
# Handle global detection area
|
|
self._area = [0, 0, 1, 1]
|
|
area_config = model_config.get(CONF_AREA)
|
|
if area_config:
|
|
self._area = [
|
|
area_config.get(CONF_TOP),
|
|
area_config.get(CONF_LEFT),
|
|
area_config.get(CONF_BOTTOM),
|
|
area_config.get(CONF_RIGHT),
|
|
]
|
|
|
|
template.attach(hass, self._file_out)
|
|
|
|
self._matches = {}
|
|
self._total_matches = 0
|
|
self._last_image = None
|
|
|
|
@property
|
|
def camera_entity(self):
|
|
"""Return camera entity id from process pictures."""
|
|
return self._camera_entity
|
|
|
|
@property
|
|
def name(self):
|
|
"""Return the name of the image processor."""
|
|
return self._name
|
|
|
|
@property
|
|
def state(self):
|
|
"""Return the state of the entity."""
|
|
return self._total_matches
|
|
|
|
@property
|
|
def device_state_attributes(self):
|
|
"""Return device specific state attributes."""
|
|
return {
|
|
ATTR_MATCHES: self._matches,
|
|
ATTR_SUMMARY: {
|
|
category: len(values) for category, values in self._matches.items()
|
|
},
|
|
ATTR_TOTAL_MATCHES: self._total_matches,
|
|
}
|
|
|
|
def _save_image(self, image, matches, paths):
|
|
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
|
img_width, img_height = img.size
|
|
draw = ImageDraw.Draw(img)
|
|
|
|
# Draw custom global region/area
|
|
if self._area != [0, 0, 1, 1]:
|
|
draw_box(
|
|
draw, self._area, img_width, img_height, "Detection Area", (0, 255, 255)
|
|
)
|
|
|
|
for category, values in matches.items():
|
|
# Draw custom category regions/areas
|
|
if category in self._category_areas and self._category_areas[category] != [
|
|
0,
|
|
0,
|
|
1,
|
|
1,
|
|
]:
|
|
label = f"{category.capitalize()} Detection Area"
|
|
draw_box(
|
|
draw,
|
|
self._category_areas[category],
|
|
img_width,
|
|
img_height,
|
|
label,
|
|
(0, 255, 0),
|
|
)
|
|
|
|
# Draw detected objects
|
|
for instance in values:
|
|
label = "{} {:.1f}%".format(category, instance["score"])
|
|
draw_box(
|
|
draw, instance["box"], img_width, img_height, label, (255, 255, 0)
|
|
)
|
|
|
|
for path in paths:
|
|
_LOGGER.info("Saving results image to %s", path)
|
|
img.save(path)
|
|
|
|
def process_image(self, image):
|
|
"""Process the image."""
|
|
|
|
try:
|
|
import cv2 # pylint: disable=import-error, import-outside-toplevel
|
|
|
|
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
|
inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
|
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
|
except ImportError:
|
|
try:
|
|
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
|
except UnidentifiedImageError:
|
|
_LOGGER.warning("Unable to process image, bad data")
|
|
return
|
|
img.thumbnail((460, 460), Image.ANTIALIAS)
|
|
img_width, img_height = img.size
|
|
inp = (
|
|
np.array(img.getdata())
|
|
.reshape((img_height, img_width, 3))
|
|
.astype(np.uint8)
|
|
)
|
|
inp_expanded = np.expand_dims(inp, axis=0)
|
|
|
|
image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
|
|
boxes = self._graph.get_tensor_by_name("detection_boxes:0")
|
|
scores = self._graph.get_tensor_by_name("detection_scores:0")
|
|
classes = self._graph.get_tensor_by_name("detection_classes:0")
|
|
boxes, scores, classes = self._session.run(
|
|
[boxes, scores, classes], feed_dict={image_tensor: inp_expanded}
|
|
)
|
|
boxes, scores, classes = map(np.squeeze, [boxes, scores, classes])
|
|
classes = classes.astype(int)
|
|
|
|
matches = {}
|
|
total_matches = 0
|
|
for box, score, obj_class in zip(boxes, scores, classes):
|
|
score = score * 100
|
|
boxes = box.tolist()
|
|
|
|
# Exclude matches below min confidence value
|
|
if score < self._min_confidence:
|
|
continue
|
|
|
|
# Exclude matches outside global area definition
|
|
if (
|
|
boxes[0] < self._area[0]
|
|
or boxes[1] < self._area[1]
|
|
or boxes[2] > self._area[2]
|
|
or boxes[3] > self._area[3]
|
|
):
|
|
continue
|
|
|
|
category = self._category_index[obj_class]["name"]
|
|
|
|
# Exclude unlisted categories
|
|
if self._include_categories and category not in self._include_categories:
|
|
continue
|
|
|
|
# Exclude matches outside category specific area definition
|
|
if self._category_areas and (
|
|
boxes[0] < self._category_areas[category][0]
|
|
or boxes[1] < self._category_areas[category][1]
|
|
or boxes[2] > self._category_areas[category][2]
|
|
or boxes[3] > self._category_areas[category][3]
|
|
):
|
|
continue
|
|
|
|
# If we got here, we should include it
|
|
if category not in matches.keys():
|
|
matches[category] = []
|
|
matches[category].append({"score": float(score), "box": boxes})
|
|
total_matches += 1
|
|
|
|
# Save Images
|
|
if total_matches and self._file_out:
|
|
paths = []
|
|
for path_template in self._file_out:
|
|
if isinstance(path_template, template.Template):
|
|
paths.append(
|
|
path_template.render(camera_entity=self._camera_entity)
|
|
)
|
|
else:
|
|
paths.append(path_template)
|
|
self._save_image(image, matches, paths)
|
|
|
|
self._matches = matches
|
|
self._total_matches = total_matches
|