Add Facebox teach service (#14998)

* Adds service

* Address pylint

* Update facebox.py

* patch tests

* Update facebox.py

* Update test_facebox.py

* Update facebox.py

* Update facebox.py

* Update facebox.py

* Update test_facebox.py

* Update test_facebox.py

* Update facebox.py

* Update facebox.py

* Update facebox.py

* Update facebox.py

* Adds total_matched_faces

* Update test_facebox.py

* Update facebox.py

* Update test_facebox.py

* Update test_facebox.py

* Remove fixtures

Removes the fixtures which were causing `setup` to fail, replace with `@patch`

* Fix teach service test and lint issues
pull/15396/merge
Robin 2018-07-10 02:11:39 +01:00 committed by Martin Hjelmare
parent c5a2ffbcb9
commit df8c59406b
3 changed files with 225 additions and 40 deletions

View File

@ -10,20 +10,26 @@ import logging
import requests
import voluptuous as vol
from homeassistant.const import ATTR_NAME
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_NAME)
from homeassistant.core import split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.components.image_processing import (
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
CONF_ENTITY_ID, CONF_NAME)
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)
_LOGGER = logging.getLogger(__name__)
ATTR_BOUNDING_BOX = 'bounding_box'
ATTR_CLASSIFIER = 'classifier'
ATTR_IMAGE_ID = 'image_id'
ATTR_MATCHED = 'matched'
CLASSIFIER = 'facebox'
DATA_FACEBOX = 'facebox_classifiers'
EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier'
FILE_PATH = 'file_path'
SERVICE_TEACH_FACE = 'facebox_teach_face'
TIMEOUT = 9
@ -32,6 +38,12 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_PORT): cv.port,
})
SERVICE_TEACH_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Required(ATTR_NAME): cv.string,
vol.Required(FILE_PATH): cv.string,
})
def encode_image(image):
"""base64 encode an image stream."""
@ -63,18 +75,65 @@ def parse_faces(api_faces):
return known_faces
def post_image(url, image):
"""Post an image to the classifier."""
try:
response = requests.post(
url,
json={"base64": encode_image(image)},
timeout=TIMEOUT
)
return response
except requests.exceptions.ConnectionError:
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
def valid_file_path(file_path):
"""Check that a file_path points to a valid file."""
try:
cv.isfile(file_path)
return True
except vol.Invalid:
_LOGGER.error(
"%s error: Invalid file path: %s", CLASSIFIER, file_path)
return False
def setup_platform(hass, config, add_devices, discovery_info=None):
"""Set up the classifier."""
if DATA_FACEBOX not in hass.data:
hass.data[DATA_FACEBOX] = []
entities = []
for camera in config[CONF_SOURCE]:
entities.append(FaceClassifyEntity(
facebox = FaceClassifyEntity(
config[CONF_IP_ADDRESS],
config[CONF_PORT],
camera[CONF_ENTITY_ID],
camera.get(CONF_NAME)
))
camera.get(CONF_NAME))
entities.append(facebox)
hass.data[DATA_FACEBOX].append(facebox)
add_devices(entities)
def service_handle(service):
"""Handle for services."""
entity_ids = service.data.get('entity_id')
classifiers = hass.data[DATA_FACEBOX]
if entity_ids:
classifiers = [c for c in classifiers if c.entity_id in entity_ids]
for classifier in classifiers:
name = service.data.get(ATTR_NAME)
file_path = service.data.get(FILE_PATH)
classifier.teach(name, file_path)
hass.services.register(
DOMAIN,
SERVICE_TEACH_FACE,
service_handle,
schema=SERVICE_TEACH_SCHEMA)
class FaceClassifyEntity(ImageProcessingFaceEntity):
"""Perform a face classification."""
@ -82,7 +141,8 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
def __init__(self, ip, port, camera_entity, name=None):
"""Init with the API key and model id."""
super().__init__()
self._url = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER)
self._camera = camera_entity
if name:
self._name = name
@ -94,28 +154,54 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
def process_image(self, image):
"""Process an image."""
response = {}
try:
response = requests.post(
self._url,
json={"base64": encode_image(image)},
timeout=TIMEOUT
).json()
except requests.exceptions.ConnectionError:
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
response['success'] = False
if response['success']:
total_faces = response['facesCount']
faces = parse_faces(response['faces'])
self._matched = get_matched_faces(faces)
self.process_faces(faces, total_faces)
response = post_image(self._url_check, image)
if response is not None:
response_json = response.json()
if response_json['success']:
total_faces = response_json['facesCount']
faces = parse_faces(response_json['faces'])
self._matched = get_matched_faces(faces)
self.process_faces(faces, total_faces)
else:
self.total_faces = None
self.faces = []
self._matched = {}
def teach(self, name, file_path):
"""Teach classifier a face name."""
if (not self.hass.config.is_allowed_path(file_path)
or not valid_file_path(file_path)):
return
with open(file_path, 'rb') as open_file:
response = requests.post(
self._url_teach,
data={ATTR_NAME: name, 'id': file_path},
files={'file': open_file})
if response.status_code == 200:
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': True,
'message': None
})
elif response.status_code == 400:
_LOGGER.warning(
"%s teaching of file %s failed with message:%s",
CLASSIFIER, file_path, response.text)
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': False,
'message': response.text
})
@property
def camera_entity(self):
"""Return camera entity id from process pictures."""
@ -131,4 +217,5 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
"""Return the classifier attributes."""
return {
'matched_faces': self._matched,
'total_matched_faces': len(self._matched),
}

View File

@ -6,3 +6,16 @@ scan:
entity_id:
description: Name(s) of entities to scan immediately.
example: 'image_processing.alpr_garage'
facebox_teach_face:
description: Teach facebox a face using a file.
fields:
entity_id:
description: The facebox entity to teach.
example: 'image_processing.facebox'
name:
description: The name of the face to teach.
example: 'my_name'
file_path:
description: The path to the image file.
example: '/images/my_image.jpg'

View File

@ -1,5 +1,5 @@
"""The tests for the facebox component."""
from unittest.mock import patch
from unittest.mock import Mock, mock_open, patch
import pytest
import requests
@ -13,21 +13,26 @@ from homeassistant.setup import async_setup_component
import homeassistant.components.image_processing as ip
import homeassistant.components.image_processing.facebox as fb
# pylint: disable=redefined-outer-name
MOCK_IP = '192.168.0.1'
MOCK_PORT = '8080'
# Mock data returned by the facebox API.
MOCK_ERROR = "No face found"
MOCK_FACE = {'confidence': 0.5812028911604818,
'id': 'john.jpg',
'matched': True,
'name': 'John Lennon',
'rect': {'height': 75, 'left': 63, 'top': 262, 'width': 74}
}
'rect': {'height': 75, 'left': 63, 'top': 262, 'width': 74}}
MOCK_FILE_PATH = '/images/mock.jpg'
MOCK_JSON = {"facesCount": 1,
"success": True,
"faces": [MOCK_FACE]
}
"faces": [MOCK_FACE]}
MOCK_NAME = 'mock_name'
# Faces data after parsing.
PARSED_FACES = [{ATTR_NAME: 'John Lennon',
@ -38,8 +43,7 @@ PARSED_FACES = [{ATTR_NAME: 'John Lennon',
'height': 75,
'left': 63,
'top': 262,
'width': 74},
}]
'width': 74}}]
MATCHED_FACES = {'John Lennon': 58.12}
@ -58,16 +62,42 @@ VALID_CONFIG = {
}
@pytest.fixture
def mock_isfile():
"""Mock os.path.isfile."""
with patch('homeassistant.components.image_processing.facebox.cv.isfile',
return_value=True) as _mock_isfile:
yield _mock_isfile
@pytest.fixture
def mock_open_file():
"""Mock open."""
mopen = mock_open()
with patch('homeassistant.components.image_processing.facebox.open',
mopen, create=True) as _mock_open:
yield _mock_open
def test_encode_image():
"""Test that binary data is encoded correctly."""
assert fb.encode_image(b'test') == 'dGVzdA=='
def test_get_matched_faces():
"""Test that matched_faces are parsed correctly."""
assert fb.get_matched_faces(PARSED_FACES) == MATCHED_FACES
def test_parse_faces():
"""Test parsing of raw face data, and generation of matched_faces."""
parsed_faces = fb.parse_faces(MOCK_JSON['faces'])
assert parsed_faces == PARSED_FACES
assert fb.get_matched_faces(parsed_faces) == MATCHED_FACES
assert fb.parse_faces(MOCK_JSON['faces']) == PARSED_FACES
@patch('os.access', Mock(return_value=False))
def test_valid_file_path():
"""Test that an invalid file_path is caught."""
assert not fb.valid_file_path('test_path')
@pytest.fixture
@ -110,6 +140,7 @@ async def test_process_image(hass, mock_image):
state = hass.states.get(VALID_ENTITY_ID)
assert state.state == '1'
assert state.attributes.get('matched_faces') == MATCHED_FACES
assert state.attributes.get('total_matched_faces') == 1
PARSED_FACES[0][ATTR_ENTITY_ID] = VALID_ENTITY_ID # Update.
assert state.attributes.get('faces') == PARSED_FACES
@ -134,7 +165,7 @@ async def test_connection_error(hass, mock_image):
with requests_mock.Mocker() as mock_req:
url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT)
mock_req.register_uri(
'POST', url, exc=requests.exceptions.ConnectTimeout)
'POST', url, exc=requests.exceptions.ConnectTimeout)
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID}
await hass.services.async_call(ip.DOMAIN,
ip.SERVICE_SCAN,
@ -147,15 +178,69 @@ async def test_connection_error(hass, mock_image):
assert state.attributes.get('matched_faces') == {}
async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file):
"""Test teaching of facebox."""
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
assert hass.states.get(VALID_ENTITY_ID)
teach_events = []
@callback
def mock_teach_event(event):
"""Mock event."""
teach_events.append(event)
hass.bus.async_listen(
'image_processing.teach_classifier', mock_teach_event)
# Patch out 'is_allowed_path' as the mock files aren't allowed
hass.config.is_allowed_path = Mock(return_value=True)
with requests_mock.Mocker() as mock_req:
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
mock_req.post(url, status_code=200)
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
ATTR_NAME: MOCK_NAME,
fb.FILE_PATH: MOCK_FILE_PATH}
await hass.services.async_call(
ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data)
await hass.async_block_till_done()
assert len(teach_events) == 1
assert teach_events[0].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
assert teach_events[0].data[ATTR_NAME] == MOCK_NAME
assert teach_events[0].data[fb.FILE_PATH] == MOCK_FILE_PATH
assert teach_events[0].data['success']
assert not teach_events[0].data['message']
# Now test the failed teaching.
with requests_mock.Mocker() as mock_req:
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
mock_req.post(url, status_code=400, text=MOCK_ERROR)
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
ATTR_NAME: MOCK_NAME,
fb.FILE_PATH: MOCK_FILE_PATH}
await hass.services.async_call(ip.DOMAIN,
fb.SERVICE_TEACH_FACE,
service_data=data)
await hass.async_block_till_done()
assert len(teach_events) == 2
assert teach_events[1].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
assert teach_events[1].data[ATTR_NAME] == MOCK_NAME
assert teach_events[1].data[fb.FILE_PATH] == MOCK_FILE_PATH
assert not teach_events[1].data['success']
assert teach_events[1].data['message'] == MOCK_ERROR
async def test_setup_platform_with_name(hass):
"""Setup platform with one entity and a name."""
MOCK_NAME = 'mock_name'
NAMED_ENTITY_ID = 'image_processing.{}'.format(MOCK_NAME)
named_entity_id = 'image_processing.{}'.format(MOCK_NAME)
VALID_CONFIG_NAMED = VALID_CONFIG.copy()
VALID_CONFIG_NAMED[ip.DOMAIN][ip.CONF_SOURCE][ip.CONF_NAME] = MOCK_NAME
valid_config_named = VALID_CONFIG.copy()
valid_config_named[ip.DOMAIN][ip.CONF_SOURCE][ip.CONF_NAME] = MOCK_NAME
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG_NAMED)
assert hass.states.get(NAMED_ENTITY_ID)
state = hass.states.get(NAMED_ENTITY_ID)
await async_setup_component(hass, ip.DOMAIN, valid_config_named)
assert hass.states.get(named_entity_id)
state = hass.states.get(named_entity_id)
assert state.attributes.get(CONF_FRIENDLY_NAME) == MOCK_NAME