Stream API requests to the supervisor (#53909)
parent
2105419a4e
commit
56360feb9a
|
@ -8,9 +8,14 @@ import re
|
|||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE
|
||||
from aiohttp.client import ClientTimeout
|
||||
from aiohttp.hdrs import (
|
||||
CONTENT_ENCODING,
|
||||
CONTENT_LENGTH,
|
||||
CONTENT_TYPE,
|
||||
TRANSFER_ENCODING,
|
||||
)
|
||||
from aiohttp.web_exceptions import HTTPBadGateway
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
|
||||
from homeassistant.components.onboarding import async_is_onboarded
|
||||
|
@ -75,14 +80,11 @@ class HassIOView(HomeAssistantView):
|
|||
|
||||
async def _command_proxy(
|
||||
self, path: str, request: web.Request
|
||||
) -> web.Response | web.StreamResponse:
|
||||
) -> web.StreamResponse:
|
||||
"""Return a client request with proxy origin for Hass.io supervisor.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
read_timeout = _get_timeout(path)
|
||||
client_timeout = 10
|
||||
data = None
|
||||
headers = _init_header(request)
|
||||
if path in ("snapshots/new/upload", "backups/new/upload"):
|
||||
# We need to reuse the full content type that includes the boundary
|
||||
|
@ -90,34 +92,20 @@ class HassIOView(HomeAssistantView):
|
|||
"Content-Type"
|
||||
] = request._stored_content_type # pylint: disable=protected-access
|
||||
|
||||
# Backups are big, so we need to adjust the allowed size
|
||||
request._client_max_size = ( # pylint: disable=protected-access
|
||||
MAX_UPLOAD_SIZE
|
||||
)
|
||||
client_timeout = 300
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(client_timeout):
|
||||
data = await request.read()
|
||||
|
||||
method = getattr(self._websession, request.method.lower())
|
||||
client = await method(
|
||||
f"http://{self._host}/{path}",
|
||||
data=data,
|
||||
client = await self._websession.request(
|
||||
method=request.method,
|
||||
url=f"http://{self._host}/{path}",
|
||||
params=request.query,
|
||||
data=request.content,
|
||||
headers=headers,
|
||||
timeout=read_timeout,
|
||||
timeout=_get_timeout(path),
|
||||
)
|
||||
|
||||
# Simple request
|
||||
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
|
||||
# Return Response
|
||||
body = await client.read()
|
||||
return web.Response(
|
||||
content_type=client.content_type, status=client.status, body=body
|
||||
)
|
||||
|
||||
# Stream response
|
||||
response = web.StreamResponse(status=client.status, headers=client.headers)
|
||||
response = web.StreamResponse(
|
||||
status=client.status, headers=_response_header(client)
|
||||
)
|
||||
response.content_type = client.content_type
|
||||
|
||||
await response.prepare(request)
|
||||
|
@ -151,11 +139,28 @@ def _init_header(request: web.Request) -> dict[str, str]:
|
|||
return headers
|
||||
|
||||
|
||||
def _get_timeout(path: str) -> int:
|
||||
def _response_header(response: aiohttp.ClientResponse) -> dict[str, str]:
|
||||
"""Create response header."""
|
||||
headers = {}
|
||||
|
||||
for name, value in response.headers.items():
|
||||
if name in (
|
||||
TRANSFER_ENCODING,
|
||||
CONTENT_LENGTH,
|
||||
CONTENT_TYPE,
|
||||
CONTENT_ENCODING,
|
||||
):
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def _get_timeout(path: str) -> ClientTimeout:
|
||||
"""Return timeout for a URL path."""
|
||||
if NO_TIMEOUT.match(path):
|
||||
return 0
|
||||
return 300
|
||||
return ClientTimeout(connect=10, total=None)
|
||||
return ClientTimeout(connect=10, total=300)
|
||||
|
||||
|
||||
def _need_auth(hass, path: str) -> bool:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""The tests for the hassio component."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import StreamReader
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.hassio.http import _need_auth
|
||||
|
@ -106,13 +106,11 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
|
|||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
|
||||
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client, aioclient_mock):
|
||||
"""Test we get a bad gateway error if we can't find supervisor."""
|
||||
with patch(
|
||||
"homeassistant.components.hassio.http.async_timeout.timeout",
|
||||
side_effect=asyncio.TimeoutError,
|
||||
):
|
||||
resp = await hassio_client.get("/api/hassio/addons/test/info")
|
||||
aioclient_mock.get("http://127.0.0.1/addons/test/info", exc=asyncio.TimeoutError)
|
||||
|
||||
resp = await hassio_client.get("/api/hassio/addons/test/info")
|
||||
assert resp.status == 502
|
||||
|
||||
|
||||
|
@ -180,3 +178,10 @@ def test_need_auth(hass):
|
|||
hass.data["onboarding"] = False
|
||||
assert not _need_auth(hass, "backups/new/upload")
|
||||
assert not _need_auth(hass, "supervisor/logs")
|
||||
|
||||
|
||||
async def test_stream(hassio_client, aioclient_mock):
|
||||
"""Verify that the request is a stream."""
|
||||
aioclient_mock.get("http://127.0.0.1/test")
|
||||
await hassio_client.get("/api/hassio/test", data="test")
|
||||
assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader)
|
||||
|
|
Loading…
Reference in New Issue