Add MFA login flow support for cloud component (#132497)
* Add MFA login flow support for cloud component * Add tests for cloud MFA login * Update code to reflect used package changes * Update code to use underlying package changes * Remove unused change * Fix login required parameters * Fix parameter validation * Use cv.has_at_least_one_key for param validation --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/133431/head
parent
5b1c5bf9f6
commit
a14aca31e5
|
@ -88,3 +88,5 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||||
|
|
||||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||||
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||||
|
|
||||||
|
LOGIN_MFA_TIMEOUT = 60
|
||||||
|
|
|
@ -9,6 +9,7 @@ import dataclasses
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any, Concatenate
|
from typing import Any, Concatenate
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
@ -31,6 +32,7 @@ from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
|
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.util.location import async_detect_location_info
|
from homeassistant.util.location import async_detect_location_info
|
||||||
|
|
||||||
|
@ -39,6 +41,7 @@ from .assist_pipeline import async_create_cloud_pipeline
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import (
|
from .const import (
|
||||||
DATA_CLOUD,
|
DATA_CLOUD,
|
||||||
|
LOGIN_MFA_TIMEOUT,
|
||||||
PREF_ALEXA_REPORT_STATE,
|
PREF_ALEXA_REPORT_STATE,
|
||||||
PREF_DISABLE_2FA,
|
PREF_DISABLE_2FA,
|
||||||
PREF_ENABLE_ALEXA,
|
PREF_ENABLE_ALEXA,
|
||||||
|
@ -69,6 +72,10 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MFAExpiredOrNotStarted(auth.CloudError):
|
||||||
|
"""Multi-factor authentication expired, or not started."""
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_setup(hass: HomeAssistant) -> None:
|
def async_setup(hass: HomeAssistant) -> None:
|
||||||
"""Initialize the HTTP API."""
|
"""Initialize the HTTP API."""
|
||||||
|
@ -101,6 +108,11 @@ def async_setup(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
_CLOUD_ERRORS.update(
|
_CLOUD_ERRORS.update(
|
||||||
{
|
{
|
||||||
|
auth.InvalidTotpCode: (HTTPStatus.BAD_REQUEST, "Invalid TOTP code."),
|
||||||
|
auth.MFARequired: (
|
||||||
|
HTTPStatus.UNAUTHORIZED,
|
||||||
|
"Multi-factor authentication required.",
|
||||||
|
),
|
||||||
auth.UserNotFound: (HTTPStatus.BAD_REQUEST, "User does not exist."),
|
auth.UserNotFound: (HTTPStatus.BAD_REQUEST, "User does not exist."),
|
||||||
auth.UserNotConfirmed: (HTTPStatus.BAD_REQUEST, "Email not confirmed."),
|
auth.UserNotConfirmed: (HTTPStatus.BAD_REQUEST, "Email not confirmed."),
|
||||||
auth.UserExists: (
|
auth.UserExists: (
|
||||||
|
@ -112,6 +124,10 @@ def async_setup(hass: HomeAssistant) -> None:
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"Password change required.",
|
"Password change required.",
|
||||||
),
|
),
|
||||||
|
MFAExpiredOrNotStarted: (
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"Multi-factor authentication expired, or not started. Please try again.",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -206,19 +222,57 @@ class GoogleActionsSyncView(HomeAssistantView):
|
||||||
class CloudLoginView(HomeAssistantView):
|
class CloudLoginView(HomeAssistantView):
|
||||||
"""Login to Home Assistant cloud."""
|
"""Login to Home Assistant cloud."""
|
||||||
|
|
||||||
|
_mfa_tokens: dict[str, str] = {}
|
||||||
|
_mfa_tokens_set_time: float = 0
|
||||||
|
|
||||||
url = "/api/cloud/login"
|
url = "/api/cloud/login"
|
||||||
name = "api:cloud:login"
|
name = "api:cloud:login"
|
||||||
|
|
||||||
@require_admin
|
@require_admin
|
||||||
@_handle_cloud_errors
|
@_handle_cloud_errors
|
||||||
@RequestDataValidator(
|
@RequestDataValidator(
|
||||||
vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
|
vol.Schema(
|
||||||
|
vol.All(
|
||||||
|
{
|
||||||
|
vol.Required("email"): str,
|
||||||
|
vol.Exclusive("password", "login"): str,
|
||||||
|
vol.Exclusive("code", "login"): str,
|
||||||
|
},
|
||||||
|
cv.has_at_least_one_key("password", "code"),
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||||
"""Handle login request."""
|
"""Handle login request."""
|
||||||
hass = request.app[KEY_HASS]
|
hass = request.app[KEY_HASS]
|
||||||
cloud = hass.data[DATA_CLOUD]
|
cloud = hass.data[DATA_CLOUD]
|
||||||
await cloud.login(data["email"], data["password"])
|
|
||||||
|
try:
|
||||||
|
email = data["email"]
|
||||||
|
password = data.get("password")
|
||||||
|
code = data.get("code")
|
||||||
|
|
||||||
|
if email and password:
|
||||||
|
await cloud.login(email, password)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not self._mfa_tokens
|
||||||
|
or time.time() - self._mfa_tokens_set_time > LOGIN_MFA_TIMEOUT
|
||||||
|
):
|
||||||
|
raise MFAExpiredOrNotStarted
|
||||||
|
|
||||||
|
# Voluptuous should ensure that code is not None because password is
|
||||||
|
assert code is not None
|
||||||
|
|
||||||
|
await cloud.login_verify_totp(email, code, self._mfa_tokens)
|
||||||
|
self._mfa_tokens = {}
|
||||||
|
self._mfa_tokens_set_time = 0
|
||||||
|
|
||||||
|
except auth.MFARequired as mfa_err:
|
||||||
|
self._mfa_tokens = mfa_err.mfa_tokens
|
||||||
|
self._mfa_tokens_set_time = time.time()
|
||||||
|
raise
|
||||||
|
|
||||||
if "assist_pipeline" in hass.config.components:
|
if "assist_pipeline" in hass.config.components:
|
||||||
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
||||||
|
|
|
@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from hass_nabucasa import thingtalk
|
from hass_nabucasa import thingtalk
|
||||||
from hass_nabucasa.auth import Unauthenticated, UnknownError
|
from hass_nabucasa.auth import (
|
||||||
|
InvalidTotpCode,
|
||||||
|
MFARequired,
|
||||||
|
Unauthenticated,
|
||||||
|
UnknownError,
|
||||||
|
)
|
||||||
from hass_nabucasa.const import STATE_CONNECTED
|
from hass_nabucasa.const import STATE_CONNECTED
|
||||||
from hass_nabucasa.voice import TTS_VOICES
|
from hass_nabucasa.voice import TTS_VOICES
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -378,6 +383,128 @@ async def test_login_view_invalid_credentials(
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_mfa_required(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test logging in when MFA is required."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
|
||||||
|
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "mfarequired"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_mfa_required_tokens_missing(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test logging in when MFA is required, code is provided, but session tokens are missing."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
cloud.login.side_effect = MFARequired(mfa_tokens={})
|
||||||
|
|
||||||
|
# Login with password and get MFA required error
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "mfarequired"
|
||||||
|
|
||||||
|
# Login with TOTP code and get MFA expired error
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login",
|
||||||
|
json={"email": "my_username", "code": "123346"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.BAD_REQUEST
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "mfaexpiredornotstarted"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_mfa_password_and_totp_provided(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test logging in when password and TOTP code provided at once."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login",
|
||||||
|
json={"email": "my_username", "password": "my_password", "code": "123346"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.BAD_REQUEST
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_invalid_totp_code(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test logging in when MFA is required and invalid code is provided."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
|
||||||
|
cloud.login_verify_totp.side_effect = InvalidTotpCode
|
||||||
|
|
||||||
|
# Login with password and get MFA required error
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "mfarequired"
|
||||||
|
|
||||||
|
# Login with TOTP code and get invalid TOTP code error
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login",
|
||||||
|
json={"email": "my_username", "code": "123346"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.BAD_REQUEST
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "invalidtotpcode"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_view_valid_totp_provided(
|
||||||
|
cloud: MagicMock,
|
||||||
|
setup_cloud: None,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test logging in with valid TOTP code."""
|
||||||
|
cloud_client = await hass_client()
|
||||||
|
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
|
||||||
|
|
||||||
|
# Login with password and get MFA required error
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
res = await req.json()
|
||||||
|
assert res["code"] == "mfarequired"
|
||||||
|
|
||||||
|
# Login with TOTP code and get success response
|
||||||
|
req = await cloud_client.post(
|
||||||
|
"/api/cloud/login",
|
||||||
|
json={"email": "my_username", "code": "123346"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert req.status == HTTPStatus.OK
|
||||||
|
result = await req.json()
|
||||||
|
assert result == {"success": True, "cloud_pipeline": None}
|
||||||
|
|
||||||
|
|
||||||
async def test_login_view_unknown_error(
|
async def test_login_view_unknown_error(
|
||||||
cloud: MagicMock,
|
cloud: MagicMock,
|
||||||
setup_cloud: None,
|
setup_cloud: None,
|
||||||
|
|
Loading…
Reference in New Issue