"""AWS platform for notify component.""" import asyncio import base64 import json import logging import aiobotocore from homeassistant.components.notify import ( ATTR_TARGET, ATTR_TITLE, ATTR_TITLE_DEFAULT, BaseNotificationService, ) from homeassistant.const import CONF_NAME, CONF_PLATFORM from homeassistant.helpers.json import JSONEncoder from .const import ( CONF_CONTEXT, CONF_CREDENTIAL_NAME, CONF_PROFILE_NAME, CONF_REGION, CONF_SERVICE, DATA_SESSIONS, ) _LOGGER = logging.getLogger(__name__) async def get_available_regions(hass, service): """Get available regions for a service.""" session = aiobotocore.get_session() # get_available_regions is not a coroutine since it does not perform # network I/O. But it still perform file I/O heavily, so put it into # an executor thread to unblock event loop return await hass.async_add_executor_job(session.get_available_regions, service) async def async_get_service(hass, config, discovery_info=None): """Get the AWS notification service.""" if discovery_info is None: _LOGGER.error("Please config aws notify platform in aws component") return None session = None conf = discovery_info service = conf[CONF_SERVICE] region_name = conf[CONF_REGION] available_regions = await get_available_regions(hass, service) if region_name not in available_regions: _LOGGER.error( "Region %s is not available for %s service, must in %s", region_name, service, available_regions, ) return None aws_config = conf.copy() del aws_config[CONF_SERVICE] del aws_config[CONF_REGION] if CONF_PLATFORM in aws_config: del aws_config[CONF_PLATFORM] if CONF_NAME in aws_config: del aws_config[CONF_NAME] if CONF_CONTEXT in aws_config: del aws_config[CONF_CONTEXT] if not aws_config: # no platform config, use the first aws component credential instead if hass.data[DATA_SESSIONS]: session = next(iter(hass.data[DATA_SESSIONS].values())) else: _LOGGER.error("Missing aws credential for %s", config[CONF_NAME]) return None if session is None: credential_name = aws_config.get(CONF_CREDENTIAL_NAME) if credential_name is not None: session = hass.data[DATA_SESSIONS].get(credential_name) if session is None: _LOGGER.warning("No available aws session for %s", credential_name) del aws_config[CONF_CREDENTIAL_NAME] if session is None: profile = aws_config.get(CONF_PROFILE_NAME) if profile is not None: session = aiobotocore.AioSession(profile=profile) del aws_config[CONF_PROFILE_NAME] else: session = aiobotocore.AioSession() aws_config[CONF_REGION] = region_name if service == "lambda": context_str = json.dumps( {"custom": conf.get(CONF_CONTEXT, {})}, cls=JSONEncoder ) context_b64 = base64.b64encode(context_str.encode("utf-8")) context = context_b64.decode("utf-8") return AWSLambda(session, aws_config, context) if service == "sns": return AWSSNS(session, aws_config) if service == "sqs": return AWSSQS(session, aws_config) # should not reach here since service was checked in schema return None class AWSNotify(BaseNotificationService): """Implement the notification service for the AWS service.""" def __init__(self, session, aws_config): """Initialize the service.""" self.session = session self.aws_config = aws_config class AWSLambda(AWSNotify): """Implement the notification service for the AWS Lambda service.""" service = "lambda" def __init__(self, session, aws_config, context): """Initialize the service.""" super().__init__(session, aws_config) self.context = context async def async_send_message(self, message="", **kwargs): """Send notification to specified LAMBDA ARN.""" if not kwargs.get(ATTR_TARGET): _LOGGER.error("At least one target is required") return cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None} payload = {"message": message} payload.update(cleaned_kwargs) json_payload = json.dumps(payload) async with self.session.create_client( self.service, **self.aws_config ) as client: tasks = [] for target in kwargs.get(ATTR_TARGET, []): tasks.append( client.invoke( FunctionName=target, Payload=json_payload, ClientContext=self.context, ) ) if tasks: await asyncio.gather(*tasks) class AWSSNS(AWSNotify): """Implement the notification service for the AWS SNS service.""" service = "sns" async def async_send_message(self, message="", **kwargs): """Send notification to specified SNS ARN.""" if not kwargs.get(ATTR_TARGET): _LOGGER.error("At least one target is required") return message_attributes = { k: {"StringValue": json.dumps(v), "DataType": "String"} for k, v in kwargs.items() if v is not None } subject = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT) async with self.session.create_client( self.service, **self.aws_config ) as client: tasks = [] for target in kwargs.get(ATTR_TARGET, []): tasks.append( client.publish( TargetArn=target, Message=message, Subject=subject, MessageAttributes=message_attributes, ) ) if tasks: await asyncio.gather(*tasks) class AWSSQS(AWSNotify): """Implement the notification service for the AWS SQS service.""" service = "sqs" async def async_send_message(self, message="", **kwargs): """Send notification to specified SQS ARN.""" if not kwargs.get(ATTR_TARGET): _LOGGER.error("At least one target is required") return cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None} message_body = {"message": message} message_body.update(cleaned_kwargs) json_body = json.dumps(message_body) message_attributes = {} for key, val in cleaned_kwargs.items(): message_attributes[key] = { "StringValue": json.dumps(val), "DataType": "String", } async with self.session.create_client( self.service, **self.aws_config ) as client: tasks = [] for target in kwargs.get(ATTR_TARGET, []): tasks.append( client.send_message( QueueUrl=target, MessageBody=json_body, MessageAttributes=message_attributes, ) ) if tasks: await asyncio.gather(*tasks)