Add MFA login flow support for cloud component (#132497)

* Add MFA login flow support for cloud component

* Add tests for cloud MFA login

* Update code to reflect used package changes

* Update code to use underlying package changes

* Remove unused change

* Fix login required parameters

* Fix parameter validation

* Use cv.has_at_least_one_key for param validation

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/133431/head
Krisjanis Lejejs 2024-12-17 15:44:50 +00:00 committed by GitHub
parent 5b1c5bf9f6
commit a14aca31e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 186 additions and 3 deletions

View File

@ -88,3 +88,5 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
LOGIN_MFA_TIMEOUT = 60

View File

@ -9,6 +9,7 @@ import dataclasses
from functools import wraps
from http import HTTPStatus
import logging
import time
from typing import Any, Concatenate
import aiohttp
@ -31,6 +32,7 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import CLOUD_NEVER_EXPOSED_ENTITIES
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.location import async_detect_location_info
@ -39,6 +41,7 @@ from .assist_pipeline import async_create_cloud_pipeline
from .client import CloudClient
from .const import (
DATA_CLOUD,
LOGIN_MFA_TIMEOUT,
PREF_ALEXA_REPORT_STATE,
PREF_DISABLE_2FA,
PREF_ENABLE_ALEXA,
@ -69,6 +72,10 @@ _CLOUD_ERRORS: dict[type[Exception], tuple[HTTPStatus, str]] = {
}
class MFAExpiredOrNotStarted(auth.CloudError):
"""Multi-factor authentication expired, or not started."""
@callback
def async_setup(hass: HomeAssistant) -> None:
"""Initialize the HTTP API."""
@ -101,6 +108,11 @@ def async_setup(hass: HomeAssistant) -> None:
_CLOUD_ERRORS.update(
{
auth.InvalidTotpCode: (HTTPStatus.BAD_REQUEST, "Invalid TOTP code."),
auth.MFARequired: (
HTTPStatus.UNAUTHORIZED,
"Multi-factor authentication required.",
),
auth.UserNotFound: (HTTPStatus.BAD_REQUEST, "User does not exist."),
auth.UserNotConfirmed: (HTTPStatus.BAD_REQUEST, "Email not confirmed."),
auth.UserExists: (
@ -112,6 +124,10 @@ def async_setup(hass: HomeAssistant) -> None:
HTTPStatus.BAD_REQUEST,
"Password change required.",
),
MFAExpiredOrNotStarted: (
HTTPStatus.BAD_REQUEST,
"Multi-factor authentication expired, or not started. Please try again.",
),
}
)
@ -206,19 +222,57 @@ class GoogleActionsSyncView(HomeAssistantView):
class CloudLoginView(HomeAssistantView):
"""Login to Home Assistant cloud."""
_mfa_tokens: dict[str, str] = {}
_mfa_tokens_set_time: float = 0
url = "/api/cloud/login"
name = "api:cloud:login"
@require_admin
@_handle_cloud_errors
@RequestDataValidator(
vol.Schema({vol.Required("email"): str, vol.Required("password"): str})
vol.Schema(
vol.All(
{
vol.Required("email"): str,
vol.Exclusive("password", "login"): str,
vol.Exclusive("code", "login"): str,
},
cv.has_at_least_one_key("password", "code"),
)
)
)
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle login request."""
hass = request.app[KEY_HASS]
cloud = hass.data[DATA_CLOUD]
await cloud.login(data["email"], data["password"])
try:
email = data["email"]
password = data.get("password")
code = data.get("code")
if email and password:
await cloud.login(email, password)
else:
if (
not self._mfa_tokens
or time.time() - self._mfa_tokens_set_time > LOGIN_MFA_TIMEOUT
):
raise MFAExpiredOrNotStarted
# Voluptuous should ensure that code is not None because password is
assert code is not None
await cloud.login_verify_totp(email, code, self._mfa_tokens)
self._mfa_tokens = {}
self._mfa_tokens_set_time = 0
except auth.MFARequired as mfa_err:
self._mfa_tokens = mfa_err.mfa_tokens
self._mfa_tokens_set_time = time.time()
raise
if "assist_pipeline" in hass.config.components:
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)

View File

@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch
import aiohttp
from hass_nabucasa import thingtalk
from hass_nabucasa.auth import Unauthenticated, UnknownError
from hass_nabucasa.auth import (
InvalidTotpCode,
MFARequired,
Unauthenticated,
UnknownError,
)
from hass_nabucasa.const import STATE_CONNECTED
from hass_nabucasa.voice import TTS_VOICES
import pytest
@ -378,6 +383,128 @@ async def test_login_view_invalid_credentials(
assert req.status == HTTPStatus.UNAUTHORIZED
async def test_login_view_mfa_required(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when MFA is required."""
cloud_client = await hass_client()
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.UNAUTHORIZED
res = await req.json()
assert res["code"] == "mfarequired"
async def test_login_view_mfa_required_tokens_missing(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when MFA is required, code is provided, but session tokens are missing."""
cloud_client = await hass_client()
cloud.login.side_effect = MFARequired(mfa_tokens={})
# Login with password and get MFA required error
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.UNAUTHORIZED
res = await req.json()
assert res["code"] == "mfarequired"
# Login with TOTP code and get MFA expired error
req = await cloud_client.post(
"/api/cloud/login",
json={"email": "my_username", "code": "123346"},
)
assert req.status == HTTPStatus.BAD_REQUEST
res = await req.json()
assert res["code"] == "mfaexpiredornotstarted"
async def test_login_view_mfa_password_and_totp_provided(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when password and TOTP code provided at once."""
cloud_client = await hass_client()
req = await cloud_client.post(
"/api/cloud/login",
json={"email": "my_username", "password": "my_password", "code": "123346"},
)
assert req.status == HTTPStatus.BAD_REQUEST
async def test_login_view_invalid_totp_code(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in when MFA is required and invalid code is provided."""
cloud_client = await hass_client()
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
cloud.login_verify_totp.side_effect = InvalidTotpCode
# Login with password and get MFA required error
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.UNAUTHORIZED
res = await req.json()
assert res["code"] == "mfarequired"
# Login with TOTP code and get invalid TOTP code error
req = await cloud_client.post(
"/api/cloud/login",
json={"email": "my_username", "code": "123346"},
)
assert req.status == HTTPStatus.BAD_REQUEST
res = await req.json()
assert res["code"] == "invalidtotpcode"
async def test_login_view_valid_totp_provided(
cloud: MagicMock,
setup_cloud: None,
hass_client: ClientSessionGenerator,
) -> None:
"""Test logging in with valid TOTP code."""
cloud_client = await hass_client()
cloud.login.side_effect = MFARequired(mfa_tokens={"session": "tokens"})
# Login with password and get MFA required error
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
)
assert req.status == HTTPStatus.UNAUTHORIZED
res = await req.json()
assert res["code"] == "mfarequired"
# Login with TOTP code and get success response
req = await cloud_client.post(
"/api/cloud/login",
json={"email": "my_username", "code": "123346"},
)
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
async def test_login_view_unknown_error(
cloud: MagicMock,
setup_cloud: None,