Stream API requests to the supervisor (#53909)

pull/53914/head
Joakim Sørensen 2021-08-03 16:48:22 +02:00 committed by GitHub
parent 2105419a4e
commit 56360feb9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 39 deletions

View File

@ -8,9 +8,14 @@ import re
import aiohttp
from aiohttp import web
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE
from aiohttp.client import ClientTimeout
from aiohttp.hdrs import (
CONTENT_ENCODING,
CONTENT_LENGTH,
CONTENT_TYPE,
TRANSFER_ENCODING,
)
from aiohttp.web_exceptions import HTTPBadGateway
import async_timeout
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
from homeassistant.components.onboarding import async_is_onboarded
@ -75,14 +80,11 @@ class HassIOView(HomeAssistantView):
async def _command_proxy(
self, path: str, request: web.Request
) -> web.Response | web.StreamResponse:
) -> web.StreamResponse:
"""Return a client request with proxy origin for Hass.io supervisor.
This method is a coroutine.
"""
read_timeout = _get_timeout(path)
client_timeout = 10
data = None
headers = _init_header(request)
if path in ("snapshots/new/upload", "backups/new/upload"):
# We need to reuse the full content type that includes the boundary
@ -90,34 +92,20 @@ class HassIOView(HomeAssistantView):
"Content-Type"
] = request._stored_content_type # pylint: disable=protected-access
# Backups are big, so we need to adjust the allowed size
request._client_max_size = ( # pylint: disable=protected-access
MAX_UPLOAD_SIZE
)
client_timeout = 300
try:
with async_timeout.timeout(client_timeout):
data = await request.read()
method = getattr(self._websession, request.method.lower())
client = await method(
f"http://{self._host}/{path}",
data=data,
client = await self._websession.request(
method=request.method,
url=f"http://{self._host}/{path}",
params=request.query,
data=request.content,
headers=headers,
timeout=read_timeout,
timeout=_get_timeout(path),
)
# Simple request
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
# Return Response
body = await client.read()
return web.Response(
content_type=client.content_type, status=client.status, body=body
)
# Stream response
response = web.StreamResponse(status=client.status, headers=client.headers)
response = web.StreamResponse(
status=client.status, headers=_response_header(client)
)
response.content_type = client.content_type
await response.prepare(request)
@ -151,11 +139,28 @@ def _init_header(request: web.Request) -> dict[str, str]:
return headers
def _get_timeout(path: str) -> int:
def _response_header(response: aiohttp.ClientResponse) -> dict[str, str]:
"""Create response header."""
headers = {}
for name, value in response.headers.items():
if name in (
TRANSFER_ENCODING,
CONTENT_LENGTH,
CONTENT_TYPE,
CONTENT_ENCODING,
):
continue
headers[name] = value
return headers
def _get_timeout(path: str) -> ClientTimeout:
"""Return timeout for a URL path."""
if NO_TIMEOUT.match(path):
return 0
return 300
return ClientTimeout(connect=10, total=None)
return ClientTimeout(connect=10, total=300)
def _need_auth(hass, path: str) -> bool:

View File

@ -1,7 +1,7 @@
"""The tests for the hassio component."""
import asyncio
from unittest.mock import patch
from aiohttp import StreamReader
import pytest
from homeassistant.components.hassio.http import _need_auth
@ -106,13 +106,11 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
assert len(aioclient_mock.mock_calls) == 1
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client, aioclient_mock):
"""Test we get a bad gateway error if we can't find supervisor."""
with patch(
"homeassistant.components.hassio.http.async_timeout.timeout",
side_effect=asyncio.TimeoutError,
):
resp = await hassio_client.get("/api/hassio/addons/test/info")
aioclient_mock.get("http://127.0.0.1/addons/test/info", exc=asyncio.TimeoutError)
resp = await hassio_client.get("/api/hassio/addons/test/info")
assert resp.status == 502
@ -180,3 +178,10 @@ def test_need_auth(hass):
hass.data["onboarding"] = False
assert not _need_auth(hass, "backups/new/upload")
assert not _need_auth(hass, "supervisor/logs")
async def test_stream(hassio_client, aioclient_mock):
"""Verify that the request is a stream."""
aioclient_mock.get("http://127.0.0.1/test")
await hassio_client.get("/api/hassio/test", data="test")
assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader)