"""Hass.io Add-on ingress service.""" import asyncio import logging import os from ipaddress import ip_address from typing import Dict, Union import aiohttp from aiohttp import hdrs, web from aiohttp.web_exceptions import HTTPBadGateway from multidict import CIMultiDict from homeassistant.components.http import HomeAssistantView from homeassistant.core import callback from homeassistant.helpers.typing import HomeAssistantType from .const import X_HASSIO, X_INGRESS_PATH _LOGGER = logging.getLogger(__name__) @callback def async_setup_ingress(hass: HomeAssistantType, host: str): """Auth setup.""" websession = hass.helpers.aiohttp_client.async_get_clientsession() hassio_ingress = HassIOIngress(host, websession) hass.http.register_view(hassio_ingress) class HassIOIngress(HomeAssistantView): """Hass.io view to handle base part.""" name = "api:hassio:ingress" url = "/api/hassio_ingress/{token}/{path:.*}" requires_auth = False def __init__(self, host: str, websession: aiohttp.ClientSession): """Initialize a Hass.io ingress view.""" self._host = host self._websession = websession def _create_url(self, token: str, path: str) -> str: """Create URL to service.""" return "http://{}/ingress/{}/{}".format(self._host, token, path) async def _handle( self, request: web.Request, token: str, path: str ) -> Union[web.Response, web.StreamResponse, web.WebSocketResponse]: """Route data to Hass.io ingress service.""" try: # Websocket if _is_websocket(request): return await self._handle_websocket(request, token, path) # Request return await self._handle_request(request, token, path) except aiohttp.ClientError as err: _LOGGER.debug("Ingress error with %s / %s: %s", token, path, err) raise HTTPBadGateway() from None get = _handle post = _handle put = _handle delete = _handle patch = _handle options = _handle async def _handle_websocket( self, request: web.Request, token: str, path: str ) -> web.WebSocketResponse: """Ingress route for websocket.""" if hdrs.SEC_WEBSOCKET_PROTOCOL in request.headers: req_protocols = [ str(proto.strip()) for proto in request.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") ] else: req_protocols = () ws_server = web.WebSocketResponse( protocols=req_protocols, autoclose=False, autoping=False ) await ws_server.prepare(request) # Preparing url = self._create_url(token, path) source_header = _init_header(request, token) # Support GET query if request.query_string: url = "{}?{}".format(url, request.query_string) # Start proxy async with self._websession.ws_connect( url, headers=source_header, protocols=req_protocols, autoclose=False, autoping=False, ) as ws_client: # Proxy requests await asyncio.wait( [ _websocket_forward(ws_server, ws_client), _websocket_forward(ws_client, ws_server), ], return_when=asyncio.FIRST_COMPLETED ) return ws_server async def _handle_request( self, request: web.Request, token: str, path: str ) -> Union[web.Response, web.StreamResponse]: """Ingress route for request.""" url = self._create_url(token, path) data = await request.read() source_header = _init_header(request, token) async with self._websession.request( request.method, url, headers=source_header, params=request.query, data=data ) as result: headers = _response_header(result) # Simple request if hdrs.CONTENT_LENGTH in result.headers and \ int(result.headers.get(hdrs.CONTENT_LENGTH, 0)) < 4194000: # Return Response body = await result.read() return web.Response( headers=headers, status=result.status, content_type=result.content_type, body=body ) # Stream response response = web.StreamResponse( status=result.status, headers=headers) response.content_type = result.content_type try: await response.prepare(request) async for data in result.content.iter_chunked(4096): await response.write(data) except (aiohttp.ClientError, aiohttp.ClientPayloadError) as err: _LOGGER.debug("Stream error %s / %s: %s", token, path, err) return response def _init_header( request: web.Request, token: str ) -> Union[CIMultiDict, Dict[str, str]]: """Create initial header.""" headers = {} # filter flags for name, value in request.headers.items(): if name in (hdrs.CONTENT_LENGTH, hdrs.CONTENT_ENCODING): continue headers[name] = value # Inject token / cleanup later on Supervisor headers[X_HASSIO] = os.environ.get('HASSIO_TOKEN', "") # Ingress information headers[X_INGRESS_PATH] = "/api/hassio_ingress/{}".format(token) # Set X-Forwarded-For forward_for = request.headers.get(hdrs.X_FORWARDED_FOR) connected_ip = ip_address(request.transport.get_extra_info('peername')[0]) if forward_for: forward_for = "{}, {!s}".format(forward_for, connected_ip) else: forward_for = "{!s}".format(connected_ip) headers[hdrs.X_FORWARDED_FOR] = forward_for # Set X-Forwarded-Host forward_host = request.headers.get(hdrs.X_FORWARDED_HOST) if not forward_host: forward_host = request.host headers[hdrs.X_FORWARDED_HOST] = forward_host # Set X-Forwarded-Proto forward_proto = request.headers.get(hdrs.X_FORWARDED_PROTO) if not forward_proto: forward_proto = request.url.scheme headers[hdrs.X_FORWARDED_PROTO] = forward_proto return headers def _response_header(response: aiohttp.ClientResponse) -> Dict[str, str]: """Create response header.""" headers = {} for name, value in response.headers.items(): if name in (hdrs.TRANSFER_ENCODING, hdrs.CONTENT_LENGTH, hdrs.CONTENT_TYPE, hdrs.CONTENT_ENCODING): continue headers[name] = value return headers def _is_websocket(request: web.Request) -> bool: """Return True if request is a websocket.""" headers = request.headers if "upgrade" in headers.get(hdrs.CONNECTION, "").lower() and \ headers.get(hdrs.UPGRADE, "").lower() == "websocket": return True return False async def _websocket_forward(ws_from, ws_to): """Handle websocket message directly.""" try: async for msg in ws_from: if msg.type == aiohttp.WSMsgType.TEXT: await ws_to.send_str(msg.data) elif msg.type == aiohttp.WSMsgType.BINARY: await ws_to.send_bytes(msg.data) elif msg.type == aiohttp.WSMsgType.PING: await ws_to.ping() elif msg.type == aiohttp.WSMsgType.PONG: await ws_to.pong() elif ws_to.closed: await ws_to.close(code=ws_to.close_code, message=msg.extra) except RuntimeError: _LOGGER.debug("Ingress Websocket runtime error")