239 lines
7.9 KiB
Python
239 lines
7.9 KiB
Python
"""Support for Snips on-device ASR and NLU."""
|
|
from datetime import timedelta
|
|
import json
|
|
import logging
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import mqtt
|
|
from homeassistant.core import callback
|
|
from homeassistant.helpers import config_validation as cv, intent
|
|
|
|
# Integration domain and configuration keys.
DOMAIN = "snips"
CONF_INTENTS = "intents"
CONF_ACTION = "action"
CONF_FEEDBACK = "feedback_sounds"
CONF_PROBABILITY = "probability_threshold"
CONF_SITE_IDS = "site_ids"

# Names of the services registered under the ``snips`` domain.
SERVICE_SAY = "say"
SERVICE_SAY_ACTION = "say_action"
SERVICE_FEEDBACK_ON = "feedback_on"
SERVICE_FEEDBACK_OFF = "feedback_off"

# MQTT topics: one wildcard subscription for intents, two for feedback sounds.
INTENT_TOPIC = "hermes/intent/#"
FEEDBACK_ON_TOPIC = "hermes/feedback/sound/toggleOn"
FEEDBACK_OFF_TOPIC = "hermes/feedback/sound/toggleOff"

# Service-call data attribute names.
ATTR_TEXT = "text"
ATTR_SITE_ID = "site_id"
ATTR_CUSTOM_DATA = "custom_data"
ATTR_CAN_BE_ENQUEUED = "can_be_enqueued"
ATTR_INTENT_FILTER = "intent_filter"

_LOGGER = logging.getLogger(__name__)
|
|
|
|
# Options accepted under the ``snips:`` key. This dict is wrapped in its own
# vol.Schema (default extra handling) below, so unknown options are rejected.
_SNIPS_CONFIG_OPTIONS = {
    vol.Optional(CONF_FEEDBACK): cv.boolean,
    vol.Optional(CONF_PROBABILITY, default=0): vol.Coerce(float),
    vol.Optional(CONF_SITE_IDS, default=["default"]): vol.All(
        cv.ensure_list, [cv.string]
    ),
}

# Top-level configuration schema; keys other than ``snips`` pass through
# untouched (ALLOW_EXTRA).
CONFIG_SCHEMA = vol.Schema(
    {DOMAIN: vol.Schema(_SNIPS_CONFIG_OPTIONS)}, extra=vol.ALLOW_EXTRA
)
|
|
|
|
# One entry of an intent message's "slots" list. Kept as a plain dict (not a
# vol.Schema) so it is compiled as part of INTENT_SCHEMA below.
_INTENT_SLOT = {
    vol.Required("slotName"): str,
    vol.Required("value"): {
        vol.Required("kind"): str,
        vol.Optional("value"): cv.match_all,
        vol.Optional("rawValue"): cv.match_all,
    },
}

# Validation applied to every decoded intent message before it is dispatched.
# Keys not listed here are allowed through (ALLOW_EXTRA).
INTENT_SCHEMA = vol.Schema(
    {
        vol.Required("input"): str,
        vol.Required("intent"): {vol.Required("intentName"): str},
        vol.Optional("slots"): [_INTENT_SLOT],
    },
    extra=vol.ALLOW_EXTRA,
)
|
|
|
|
# Fields common to the ``say`` and ``say_action`` services.
_SAY_FIELDS = {
    vol.Required(ATTR_TEXT): str,
    vol.Optional(ATTR_SITE_ID, default="default"): str,
    vol.Optional(ATTR_CUSTOM_DATA, default=""): str,
}

SERVICE_SCHEMA_SAY = vol.Schema(_SAY_FIELDS)

# ``say_action`` extends the common fields with a queueing flag (default
# True) and an optional intent filter that is coerced to a list.
SERVICE_SCHEMA_SAY_ACTION = vol.Schema(
    {
        **_SAY_FIELDS,
        vol.Optional(ATTR_CAN_BE_ENQUEUED, default=True): cv.boolean,
        vol.Optional(ATTR_INTENT_FILTER): vol.All(cv.ensure_list),
    }
)

# ``feedback_on`` / ``feedback_off`` take only an optional site id.
SERVICE_SCHEMA_FEEDBACK = vol.Schema(
    {vol.Optional(ATTR_SITE_ID, default="default"): str}
)
|
|
|
|
|
|
async def async_setup(hass, config):
    """Activate the Snips component.

    * If ``feedback_sounds`` is configured, publish that state for every
      configured site id.
    * Subscribe to ``hermes/intent/#``, hand each valid intent to Home
      Assistant's intent machinery and publish the response on
      ``hermes/dialogueManager/endSession``.
    * Register the ``say``, ``say_action``, ``feedback_on`` and
      ``feedback_off`` services.

    Returns ``True``.
    """

    @callback
    def async_set_feedback(site_ids, state):
        """Publish the feedback-sound state for each site.

        ``site_ids`` falls back to the configured ``site_ids`` when it is
        empty or ``None``. A truthy ``state`` publishes to the toggleOn topic
        with QoS 1 and retain set; a falsy one publishes to toggleOff with
        QoS 0 and no retain.
        """
        site_ids = site_ids if site_ids else config[DOMAIN].get(CONF_SITE_IDS)
        topic = FEEDBACK_ON_TOPIC if state else FEEDBACK_OFF_TOPIC
        for site_id in site_ids:
            payload = json.dumps({"siteId": site_id})
            # NOTE(review): an empty, non-retained message on the toggleOn
            # topic precedes every real publish; its purpose is not evident
            # from this file, so it is kept unchanged.
            mqtt.async_publish(hass, FEEDBACK_ON_TOPIC, "", qos=0, retain=False)
            mqtt.async_publish(hass, topic, payload, qos=int(state), retain=state)

    if CONF_FEEDBACK in config[DOMAIN]:
        async_set_feedback(None, config[DOMAIN][CONF_FEEDBACK])

    async def message_received(msg):
        """Dispatch one intent message received on MQTT."""
        _LOGGER.debug("New intent: %s", msg.payload)

        try:
            request = json.loads(msg.payload)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError (malformed JSON text);
            # TypeError covers payload types json.loads does not accept.
            _LOGGER.error("Received invalid JSON: %s", msg.payload)
            return

        # Validate before reading nested keys, so that malformed payloads
        # (missing "intent", non-object JSON, ...) are logged instead of
        # raising KeyError/TypeError inside the MQTT callback.
        try:
            request = INTENT_SCHEMA(request)
        except vol.Invalid as err:
            _LOGGER.error("Intent has invalid schema: %s. %s", err, request)
            return

        threshold = config[DOMAIN].get(CONF_PROBABILITY)
        # The schema does not require a confidence score, so check explicitly.
        confidence = request["intent"].get("confidenceScore")
        if confidence is None:
            _LOGGER.error("Intent has no confidence score: %s", request)
            return
        if confidence < threshold:
            _LOGGER.warning(
                "Intent below probaility threshold %s < %s", confidence, threshold
            )
            return

        intent_name = request["intent"]["intentName"]
        # Names starting with "user_" keep the text after the last "__";
        # all other names keep the text after the last ":".
        if intent_name.startswith("user_"):
            intent_type = intent_name.split("__")[-1]
        else:
            intent_type = intent_name.split(":")[-1]

        slots = {}
        for slot in request.get("slots", []):
            slot_name = slot["slotName"]
            slots[slot_name] = {"value": resolve_slot_values(slot)}
            slots["{}_raw".format(slot_name)] = {"value": slot["rawValue"]}
        slots["site_id"] = {"value": request.get("siteId")}
        slots["session_id"] = {"value": request.get("sessionId")}
        slots["confidenceScore"] = {"value": confidence}

        try:
            intent_response = await intent.async_handle(
                hass, DOMAIN, intent_type, slots, request["input"]
            )
        except intent.UnknownIntent:
            _LOGGER.warning("Received unknown intent %s", intent_name)
        except intent.IntentError:
            _LOGGER.exception("Error while handling intent: %s", intent_type)
        else:
            notification = {"sessionId": request.get("sessionId", "default")}
            if "plain" in intent_response.speech:
                notification["text"] = intent_response.speech["plain"]["speech"]

            _LOGGER.debug("send_response %s", json.dumps(notification))
            mqtt.async_publish(
                hass, "hermes/dialogueManager/endSession", json.dumps(notification)
            )

    await mqtt.async_subscribe(hass, INTENT_TOPIC, message_received)

    async def snips_say(call):
        """Start a Snips notification session that speaks ``text``."""
        notification = {
            "siteId": call.data.get(ATTR_SITE_ID, "default"),
            "customData": call.data.get(ATTR_CUSTOM_DATA, ""),
            "init": {"type": "notification", "text": call.data.get(ATTR_TEXT)},
        }
        mqtt.async_publish(
            hass, "hermes/dialogueManager/startSession", json.dumps(notification)
        )

    async def snips_say_action(call):
        """Start a Snips action session with ``text`` and an intent filter."""
        notification = {
            "siteId": call.data.get(ATTR_SITE_ID, "default"),
            "customData": call.data.get(ATTR_CUSTOM_DATA, ""),
            "init": {
                "type": "action",
                "text": call.data.get(ATTR_TEXT),
                "canBeEnqueued": call.data.get(ATTR_CAN_BE_ENQUEUED, True),
                "intentFilter": call.data.get(ATTR_INTENT_FILTER, []),
            },
        }
        mqtt.async_publish(
            hass, "hermes/dialogueManager/startSession", json.dumps(notification)
        )

    async def feedback_on(call):
        """Turn feedback sounds on for the requested (or configured) sites."""
        async_set_feedback(call.data.get(ATTR_SITE_ID), True)

    async def feedback_off(call):
        """Turn feedback sounds off for the requested (or configured) sites."""
        async_set_feedback(call.data.get(ATTR_SITE_ID), False)

    hass.services.async_register(
        DOMAIN, SERVICE_SAY, snips_say, schema=SERVICE_SCHEMA_SAY
    )
    hass.services.async_register(
        DOMAIN, SERVICE_SAY_ACTION, snips_say_action, schema=SERVICE_SCHEMA_SAY_ACTION
    )
    hass.services.async_register(
        DOMAIN, SERVICE_FEEDBACK_ON, feedback_on, schema=SERVICE_SCHEMA_FEEDBACK
    )
    hass.services.async_register(
        DOMAIN, SERVICE_FEEDBACK_OFF, feedback_off, schema=SERVICE_SCHEMA_FEEDBACK
    )

    return True
|
|
|
|
|
|
def resolve_slot_values(slot):
    """Reduce a Snips slot dict to a single usable value.

    The slot's ``value["value"]`` is used when present; otherwise the slot's
    ``rawValue`` is used. For slots whose ``entity`` is ``"snips/duration"``,
    the result is replaced by the total number of seconds (a float) computed
    from the value's ``weeks``, ``days``, ``hours``, ``minutes`` and
    ``seconds`` fields.
    """
    payload = slot["value"]
    result = payload["value"] if "value" in payload else slot["rawValue"]

    if slot.get("entity") == "snips/duration":
        # Only these units are passed on; timedelta has no months/years.
        units = ("weeks", "days", "hours", "minutes", "seconds")
        span = timedelta(**{unit: payload[unit] for unit in units})
        result = span.total_seconds()

    return result
|