diff --git a/homeassistant/components/image_processing/facebox.py b/homeassistant/components/image_processing/facebox.py index f556b62e935..c863f804513 100644 --- a/homeassistant/components/image_processing/facebox.py +++ b/homeassistant/components/image_processing/facebox.py @@ -10,20 +10,26 @@ import logging import requests import voluptuous as vol -from homeassistant.const import ATTR_NAME +from homeassistant.const import ( + ATTR_ENTITY_ID, ATTR_NAME) from homeassistant.core import split_entity_id import homeassistant.helpers.config_validation as cv from homeassistant.components.image_processing import ( PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE, - CONF_ENTITY_ID, CONF_NAME) + CONF_ENTITY_ID, CONF_NAME, DOMAIN) from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT) _LOGGER = logging.getLogger(__name__) ATTR_BOUNDING_BOX = 'bounding_box' +ATTR_CLASSIFIER = 'classifier' ATTR_IMAGE_ID = 'image_id' ATTR_MATCHED = 'matched' CLASSIFIER = 'facebox' +DATA_FACEBOX = 'facebox_classifiers' +EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier' +FILE_PATH = 'file_path' +SERVICE_TEACH_FACE = 'facebox_teach_face' TIMEOUT = 9 @@ -32,6 +38,12 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_PORT): cv.port, }) +SERVICE_TEACH_SCHEMA = vol.Schema({ + vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, + vol.Required(ATTR_NAME): cv.string, + vol.Required(FILE_PATH): cv.string, +}) + def encode_image(image): """base64 encode an image stream.""" @@ -63,18 +75,65 @@ def parse_faces(api_faces): return known_faces +def post_image(url, image): + """Post an image to the classifier.""" + try: + response = requests.post( + url, + json={"base64": encode_image(image)}, + timeout=TIMEOUT + ) + return response + except requests.exceptions.ConnectionError: + _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) + + +def valid_file_path(file_path): + """Check that a file_path points to a valid file.""" + try: + cv.isfile(file_path) + return True + except vol.Invalid: + _LOGGER.error( + "%s error: Invalid file path: %s", CLASSIFIER, file_path) + return False + + def setup_platform(hass, config, add_devices, discovery_info=None): """Set up the classifier.""" + if DATA_FACEBOX not in hass.data: + hass.data[DATA_FACEBOX] = [] + entities = [] for camera in config[CONF_SOURCE]: - entities.append(FaceClassifyEntity( + facebox = FaceClassifyEntity( config[CONF_IP_ADDRESS], config[CONF_PORT], camera[CONF_ENTITY_ID], - camera.get(CONF_NAME) - )) + camera.get(CONF_NAME)) + entities.append(facebox) + hass.data[DATA_FACEBOX].append(facebox) add_devices(entities) + def service_handle(service): + """Handle for services.""" + entity_ids = service.data.get('entity_id') + + classifiers = hass.data[DATA_FACEBOX] + if entity_ids: + classifiers = [c for c in classifiers if c.entity_id in entity_ids] + + for classifier in classifiers: + name = service.data.get(ATTR_NAME) + file_path = service.data.get(FILE_PATH) + classifier.teach(name, file_path) + + hass.services.register( + DOMAIN, + SERVICE_TEACH_FACE, + service_handle, + schema=SERVICE_TEACH_SCHEMA) + class FaceClassifyEntity(ImageProcessingFaceEntity): """Perform a face classification.""" @@ -82,7 +141,8 @@ class FaceClassifyEntity(ImageProcessingFaceEntity): def __init__(self, ip, port, camera_entity, name=None): """Init with the API key and model id.""" super().__init__() - self._url = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER) + self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER) + self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER) self._camera = camera_entity if name: self._name = name @@ -94,28 +154,54 @@ class FaceClassifyEntity(ImageProcessingFaceEntity): def process_image(self, image): """Process an image.""" - response = {} - try: - response = requests.post( - self._url, - json={"base64": encode_image(image)}, - timeout=TIMEOUT - ).json() - except requests.exceptions.ConnectionError: - _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) - response['success'] = False - - if response['success']: - total_faces = response['facesCount'] - faces = parse_faces(response['faces']) - self._matched = get_matched_faces(faces) - self.process_faces(faces, total_faces) + response = post_image(self._url_check, image) + if response is not None: + response_json = response.json() + if response_json['success']: + total_faces = response_json['facesCount'] + faces = parse_faces(response_json['faces']) + self._matched = get_matched_faces(faces) + self.process_faces(faces, total_faces) else: self.total_faces = None self.faces = [] self._matched = {} + def teach(self, name, file_path): + """Teach classifier a face name.""" + if (not self.hass.config.is_allowed_path(file_path) + or not valid_file_path(file_path)): + return + with open(file_path, 'rb') as open_file: + response = requests.post( + self._url_teach, + data={ATTR_NAME: name, 'id': file_path}, + files={'file': open_file}) + + if response.status_code == 200: + self.hass.bus.fire( + EVENT_CLASSIFIER_TEACH, { + ATTR_CLASSIFIER: CLASSIFIER, + ATTR_NAME: name, + FILE_PATH: file_path, + 'success': True, + 'message': None + }) + + elif response.status_code == 400: + _LOGGER.warning( + "%s teaching of file %s failed with message:%s", + CLASSIFIER, file_path, response.text) + self.hass.bus.fire( + EVENT_CLASSIFIER_TEACH, { + ATTR_CLASSIFIER: CLASSIFIER, + ATTR_NAME: name, + FILE_PATH: file_path, + 'success': False, + 'message': response.text + }) + @property def camera_entity(self): """Return camera entity id from process pictures.""" @@ -131,4 +217,5 @@ class FaceClassifyEntity(ImageProcessingFaceEntity): """Return the classifier attributes.""" return { 'matched_faces': self._matched, + 'total_matched_faces': len(self._matched), } diff --git a/homeassistant/components/image_processing/services.yaml b/homeassistant/components/image_processing/services.yaml index 1f1fa347dc9..0689c34c1a3 100644 --- a/homeassistant/components/image_processing/services.yaml +++ b/homeassistant/components/image_processing/services.yaml @@ -6,3 +6,16 @@ scan: entity_id: description: Name(s) of entities to scan immediately. example: 'image_processing.alpr_garage' + +facebox_teach_face: + description: Teach facebox a face using a file. + fields: + entity_id: + description: The facebox entity to teach. + example: 'image_processing.facebox' + name: + description: The name of the face to teach. + example: 'my_name' + file_path: + description: The path to the image file. + example: '/images/my_image.jpg' diff --git a/tests/components/image_processing/test_facebox.py b/tests/components/image_processing/test_facebox.py index 9449ebf5f71..86811f94db3 100644 --- a/tests/components/image_processing/test_facebox.py +++ b/tests/components/image_processing/test_facebox.py @@ -1,5 +1,5 @@ """The tests for the facebox component.""" -from unittest.mock import patch +from unittest.mock import Mock, mock_open, patch import pytest import requests @@ -13,21 +13,26 @@ from homeassistant.setup import async_setup_component import homeassistant.components.image_processing as ip import homeassistant.components.image_processing.facebox as fb +# pylint: disable=redefined-outer-name + MOCK_IP = '192.168.0.1' MOCK_PORT = '8080' # Mock data returned by the facebox API. +MOCK_ERROR = "No face found" MOCK_FACE = {'confidence': 0.5812028911604818, 'id': 'john.jpg', 'matched': True, 'name': 'John Lennon', - 'rect': {'height': 75, 'left': 63, 'top': 262, 'width': 74} - } + 'rect': {'height': 75, 'left': 63, 'top': 262, 'width': 74}} + +MOCK_FILE_PATH = '/images/mock.jpg' MOCK_JSON = {"facesCount": 1, "success": True, - "faces": [MOCK_FACE] - } + "faces": [MOCK_FACE]} + +MOCK_NAME = 'mock_name' # Faces data after parsing. PARSED_FACES = [{ATTR_NAME: 'John Lennon', @@ -38,8 +43,7 @@ PARSED_FACES = [{ATTR_NAME: 'John Lennon', 'height': 75, 'left': 63, 'top': 262, - 'width': 74}, - }] + 'width': 74}}] MATCHED_FACES = {'John Lennon': 58.12} @@ -58,16 +62,42 @@ VALID_CONFIG = { } +@pytest.fixture +def mock_isfile(): + """Mock os.path.isfile.""" + with patch('homeassistant.components.image_processing.facebox.cv.isfile', + return_value=True) as _mock_isfile: + yield _mock_isfile + + +@pytest.fixture +def mock_open_file(): + """Mock open.""" + mopen = mock_open() + with patch('homeassistant.components.image_processing.facebox.open', + mopen, create=True) as _mock_open: + yield _mock_open + + def test_encode_image(): """Test that binary data is encoded correctly.""" assert fb.encode_image(b'test') == 'dGVzdA==' +def test_get_matched_faces(): + """Test that matched_faces are parsed correctly.""" + assert fb.get_matched_faces(PARSED_FACES) == MATCHED_FACES + + def test_parse_faces(): """Test parsing of raw face data, and generation of matched_faces.""" - parsed_faces = fb.parse_faces(MOCK_JSON['faces']) - assert parsed_faces == PARSED_FACES - assert fb.get_matched_faces(parsed_faces) == MATCHED_FACES + assert fb.parse_faces(MOCK_JSON['faces']) == PARSED_FACES + + +@patch('os.access', Mock(return_value=False)) +def test_valid_file_path(): + """Test that an invalid file_path is caught.""" + assert not fb.valid_file_path('test_path') @pytest.fixture @@ -110,6 +140,7 @@ async def test_process_image(hass, mock_image): state = hass.states.get(VALID_ENTITY_ID) assert state.state == '1' assert state.attributes.get('matched_faces') == MATCHED_FACES + assert state.attributes.get('total_matched_faces') == 1 PARSED_FACES[0][ATTR_ENTITY_ID] = VALID_ENTITY_ID # Update. assert state.attributes.get('faces') == PARSED_FACES @@ -134,7 +165,7 @@ async def test_connection_error(hass, mock_image): with requests_mock.Mocker() as mock_req: url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT) mock_req.register_uri( - 'POST', url, exc=requests.exceptions.ConnectTimeout) + 'POST', url, exc=requests.exceptions.ConnectTimeout) data = {ATTR_ENTITY_ID: VALID_ENTITY_ID} await hass.services.async_call(ip.DOMAIN, ip.SERVICE_SCAN, @@ -147,15 +178,69 @@ async def test_connection_error(hass, mock_image): assert state.attributes.get('matched_faces') == {} +async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file): + """Test teaching of facebox.""" + await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG) + assert hass.states.get(VALID_ENTITY_ID) + + teach_events = [] + + @callback + def mock_teach_event(event): + """Mock event.""" + teach_events.append(event) + + hass.bus.async_listen( + 'image_processing.teach_classifier', mock_teach_event) + + # Patch out 'is_allowed_path' as the mock files aren't allowed + hass.config.is_allowed_path = Mock(return_value=True) + + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) + mock_req.post(url, status_code=200) + data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, + ATTR_NAME: MOCK_NAME, + fb.FILE_PATH: MOCK_FILE_PATH} + await hass.services.async_call( + ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data) + await hass.async_block_till_done() + + assert len(teach_events) == 1 + assert teach_events[0].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER + assert teach_events[0].data[ATTR_NAME] == MOCK_NAME + assert teach_events[0].data[fb.FILE_PATH] == MOCK_FILE_PATH + assert teach_events[0].data['success'] + assert not teach_events[0].data['message'] + + # Now test the failed teaching. + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) + mock_req.post(url, status_code=400, text=MOCK_ERROR) + data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, + ATTR_NAME: MOCK_NAME, + fb.FILE_PATH: MOCK_FILE_PATH} + await hass.services.async_call(ip.DOMAIN, + fb.SERVICE_TEACH_FACE, + service_data=data) + await hass.async_block_till_done() + + assert len(teach_events) == 2 + assert teach_events[1].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER + assert teach_events[1].data[ATTR_NAME] == MOCK_NAME + assert teach_events[1].data[fb.FILE_PATH] == MOCK_FILE_PATH + assert not teach_events[1].data['success'] + assert teach_events[1].data['message'] == MOCK_ERROR + + async def test_setup_platform_with_name(hass): """Setup platform with one entity and a name.""" - MOCK_NAME = 'mock_name' - NAMED_ENTITY_ID = 'image_processing.{}'.format(MOCK_NAME) + named_entity_id = 'image_processing.{}'.format(MOCK_NAME) - VALID_CONFIG_NAMED = VALID_CONFIG.copy() - VALID_CONFIG_NAMED[ip.DOMAIN][ip.CONF_SOURCE][ip.CONF_NAME] = MOCK_NAME + valid_config_named = VALID_CONFIG.copy() + valid_config_named[ip.DOMAIN][ip.CONF_SOURCE][ip.CONF_NAME] = MOCK_NAME - await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG_NAMED) - assert hass.states.get(NAMED_ENTITY_ID) - state = hass.states.get(NAMED_ENTITY_ID) + await async_setup_component(hass, ip.DOMAIN, valid_config_named) + assert hass.states.get(named_entity_id) + state = hass.states.get(named_entity_id) assert state.attributes.get(CONF_FRIENDLY_NAME) == MOCK_NAME