Add strict typing for auth (#75586)
parent
735dec8dde
commit
563ec67d39
|
@ -59,6 +59,7 @@ homeassistant.components.ampio.*
|
|||
homeassistant.components.anthemav.*
|
||||
homeassistant.components.aseko_pool_live.*
|
||||
homeassistant.components.asuswrt.*
|
||||
homeassistant.components.auth.*
|
||||
homeassistant.components.automation.*
|
||||
homeassistant.components.backup.*
|
||||
homeassistant.components.baf.*
|
||||
|
|
|
@ -124,15 +124,22 @@ as part of a config flow.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Optional, cast
|
||||
import uuid
|
||||
|
||||
from aiohttp import web
|
||||
from multidict import MultiDictProxy
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth import InvalidAuthError
|
||||
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
|
||||
from homeassistant.auth.models import (
|
||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||
Credentials,
|
||||
User,
|
||||
)
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http.auth import (
|
||||
async_sign_path,
|
||||
|
@ -151,11 +158,16 @@ from . import indieauth, login_flow, mfa_setup_flow
|
|||
|
||||
DOMAIN = "auth"
|
||||
|
||||
StoreResultType = Callable[[str, Credentials], str]
|
||||
RetrieveResultType = Callable[[str, str], Optional[Credentials]]
|
||||
|
||||
|
||||
@bind_hass
|
||||
def create_auth_code(hass, client_id: str, credential: Credentials) -> str:
|
||||
def create_auth_code(
|
||||
hass: HomeAssistant, client_id: str, credential: Credentials
|
||||
) -> str:
|
||||
"""Create an authorization code to fetch tokens."""
|
||||
return hass.data[DOMAIN](client_id, credential)
|
||||
return cast(StoreResultType, hass.data[DOMAIN])(client_id, credential)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
@ -188,15 +200,15 @@ class TokenView(HomeAssistantView):
|
|||
requires_auth = False
|
||||
cors_allowed = True
|
||||
|
||||
def __init__(self, retrieve_auth):
|
||||
def __init__(self, retrieve_auth: RetrieveResultType) -> None:
|
||||
"""Initialize the token view."""
|
||||
self._retrieve_auth = retrieve_auth
|
||||
|
||||
@log_invalid_auth
|
||||
async def post(self, request):
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""Grant a token."""
|
||||
hass = request.app["hass"]
|
||||
data = await request.post()
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
data = cast(MultiDictProxy[str], await request.post())
|
||||
|
||||
grant_type = data.get("grant_type")
|
||||
|
||||
|
@ -217,7 +229,11 @@ class TokenView(HomeAssistantView):
|
|||
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
async def _async_handle_revoke_token(self, hass, data):
|
||||
async def _async_handle_revoke_token(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
data: MultiDictProxy[str],
|
||||
) -> web.Response:
|
||||
"""Handle revoke token request."""
|
||||
|
||||
# OAuth 2.0 Token Revocation [RFC7009]
|
||||
|
@ -235,7 +251,12 @@ class TokenView(HomeAssistantView):
|
|||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
return web.Response(status=HTTPStatus.OK)
|
||||
|
||||
async def _async_handle_auth_code(self, hass, data, remote_addr):
|
||||
async def _async_handle_auth_code(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
data: MultiDictProxy[str],
|
||||
remote_addr: str | None,
|
||||
) -> web.Response:
|
||||
"""Handle authorization code request."""
|
||||
client_id = data.get("client_id")
|
||||
if client_id is None or not indieauth.verify_client_id(client_id):
|
||||
|
@ -298,7 +319,12 @@ class TokenView(HomeAssistantView):
|
|||
},
|
||||
)
|
||||
|
||||
async def _async_handle_refresh_token(self, hass, data, remote_addr):
|
||||
async def _async_handle_refresh_token(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
data: MultiDictProxy[str],
|
||||
remote_addr: str | None,
|
||||
) -> web.Response:
|
||||
"""Handle authorization code request."""
|
||||
client_id = data.get("client_id")
|
||||
if client_id is not None and not indieauth.verify_client_id(client_id):
|
||||
|
@ -366,15 +392,15 @@ class LinkUserView(HomeAssistantView):
|
|||
url = "/auth/link_user"
|
||||
name = "api:auth:link_user"
|
||||
|
||||
def __init__(self, retrieve_credentials):
|
||||
def __init__(self, retrieve_credentials: RetrieveResultType) -> None:
|
||||
"""Initialize the link user view."""
|
||||
self._retrieve_credentials = retrieve_credentials
|
||||
|
||||
@RequestDataValidator(vol.Schema({"code": str, "client_id": str}))
|
||||
async def post(self, request, data):
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Link a user."""
|
||||
hass = request.app["hass"]
|
||||
user = request["hass_user"]
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
user: User = request["hass_user"]
|
||||
|
||||
credentials = self._retrieve_credentials(data["client_id"], data["code"])
|
||||
|
||||
|
@ -394,12 +420,12 @@ class LinkUserView(HomeAssistantView):
|
|||
|
||||
|
||||
@callback
|
||||
def _create_auth_code_store():
|
||||
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
|
||||
"""Create an in memory store."""
|
||||
temp_results = {}
|
||||
temp_results: dict[tuple[str, str], tuple[datetime, Credentials]] = {}
|
||||
|
||||
@callback
|
||||
def store_result(client_id, result):
|
||||
def store_result(client_id: str, result: Credentials) -> str:
|
||||
"""Store flow result and return a code to retrieve it."""
|
||||
if not isinstance(result, Credentials):
|
||||
raise ValueError("result has to be a Credentials instance")
|
||||
|
@ -412,7 +438,7 @@ def _create_auth_code_store():
|
|||
return code
|
||||
|
||||
@callback
|
||||
def retrieve_result(client_id, code):
|
||||
def retrieve_result(client_id: str, code: str) -> Credentials | None:
|
||||
"""Retrieve flow result."""
|
||||
key = (client_id, code)
|
||||
|
||||
|
@ -437,8 +463,8 @@ def _create_auth_code_store():
|
|||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_current_user(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Return the current user."""
|
||||
user = connection.user
|
||||
enabled_modules = await hass.auth.async_get_enabled_mfa(user)
|
||||
|
@ -482,8 +508,8 @@ async def websocket_current_user(
|
|||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_create_long_lived_access_token(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Create or a long-lived access token."""
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
connection.user,
|
||||
|
@ -506,12 +532,12 @@ async def websocket_create_long_lived_access_token(
|
|||
@websocket_api.ws_require_user()
|
||||
@callback
|
||||
def websocket_refresh_tokens(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Return metadata of users refresh tokens."""
|
||||
current_id = connection.refresh_token_id
|
||||
|
||||
tokens = []
|
||||
tokens: list[dict[str, Any]] = []
|
||||
for refresh in connection.user.refresh_tokens.values():
|
||||
if refresh.credential:
|
||||
auth_provider_type = refresh.credential.auth_provider_type
|
||||
|
@ -545,8 +571,8 @@ def websocket_refresh_tokens(
|
|||
@websocket_api.ws_require_user()
|
||||
@websocket_api.async_response
|
||||
async def websocket_delete_refresh_token(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle a delete refresh token request."""
|
||||
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
|
||||
|
||||
|
@ -569,8 +595,8 @@ async def websocket_delete_refresh_token(
|
|||
@websocket_api.ws_require_user()
|
||||
@callback
|
||||
def websocket_sign_path(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle a sign path request."""
|
||||
connection.send_message(
|
||||
websocket_api.result_message(
|
||||
|
|
|
@ -1,18 +1,24 @@
|
|||
"""Helpers to resolve client ID/secret."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from html.parser import HTMLParser
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from urllib.parse import ParseResult, urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import aiohttp.client_exceptions
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def verify_redirect_uri(hass, client_id, redirect_uri):
|
||||
async def verify_redirect_uri(
|
||||
hass: HomeAssistant, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Verify that the client and redirect uri match."""
|
||||
try:
|
||||
client_id_parts = _parse_client_id(client_id)
|
||||
|
@ -47,24 +53,24 @@ async def verify_redirect_uri(hass, client_id, redirect_uri):
|
|||
class LinkTagParser(HTMLParser):
|
||||
"""Parser to find link tags."""
|
||||
|
||||
def __init__(self, rel):
|
||||
def __init__(self, rel: str) -> None:
|
||||
"""Initialize a link tag parser."""
|
||||
super().__init__()
|
||||
self.rel = rel
|
||||
self.found = []
|
||||
self.found: list[str | None] = []
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
"""Handle finding a start tag."""
|
||||
if tag != "link":
|
||||
return
|
||||
|
||||
attrs = dict(attrs)
|
||||
attributes: dict[str, str | None] = dict(attrs)
|
||||
|
||||
if attrs.get("rel") == self.rel:
|
||||
self.found.append(attrs.get("href"))
|
||||
if attributes.get("rel") == self.rel:
|
||||
self.found.append(attributes.get("href"))
|
||||
|
||||
|
||||
async def fetch_redirect_uris(hass, url):
|
||||
async def fetch_redirect_uris(hass: HomeAssistant, url: str) -> list[str]:
|
||||
"""Find link tag with redirect_uri values.
|
||||
|
||||
IndieAuth 4.2.2
|
||||
|
@ -108,7 +114,7 @@ async def fetch_redirect_uris(hass, url):
|
|||
return [urljoin(url, found) for found in parser.found]
|
||||
|
||||
|
||||
def verify_client_id(client_id):
|
||||
def verify_client_id(client_id: str) -> bool:
|
||||
"""Verify that the client id is valid."""
|
||||
try:
|
||||
_parse_client_id(client_id)
|
||||
|
@ -117,7 +123,7 @@ def verify_client_id(client_id):
|
|||
return False
|
||||
|
||||
|
||||
def _parse_url(url):
|
||||
def _parse_url(url: str) -> ParseResult:
|
||||
"""Parse a url in parts and canonicalize according to IndieAuth."""
|
||||
parts = urlparse(url)
|
||||
|
||||
|
@ -134,7 +140,7 @@ def _parse_url(url):
|
|||
return parts
|
||||
|
||||
|
||||
def _parse_client_id(client_id):
|
||||
def _parse_client_id(client_id: str) -> ParseResult:
|
||||
"""Test if client id is a valid URL according to IndieAuth section 3.2.
|
||||
|
||||
https://indieauth.spec.indieweb.org/#client-identifier
|
||||
|
|
|
@ -66,14 +66,19 @@ associate with an credential if "type" set to "link_user" in
|
|||
"version": 1
|
||||
}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
import voluptuous_serialize
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth import AuthManagerFlowManager
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.components import onboarding
|
||||
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||
|
@ -88,8 +93,13 @@ from homeassistant.core import HomeAssistant
|
|||
|
||||
from . import indieauth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import StoreResultType
|
||||
|
||||
async def async_setup(hass, store_result):
|
||||
|
||||
async def async_setup(
|
||||
hass: HomeAssistant, store_result: Callable[[str, Credentials], str]
|
||||
) -> None:
|
||||
"""Component to allow users to login."""
|
||||
hass.http.register_view(AuthProvidersView)
|
||||
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow, store_result))
|
||||
|
@ -103,9 +113,9 @@ class AuthProvidersView(HomeAssistantView):
|
|||
name = "api:auth:providers"
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Get available auth providers."""
|
||||
hass = request.app["hass"]
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
if not onboarding.async_is_user_onboarded(hass):
|
||||
return self.json_message(
|
||||
message="Onboarding not finished",
|
||||
|
@ -121,7 +131,9 @@ class AuthProvidersView(HomeAssistantView):
|
|||
)
|
||||
|
||||
|
||||
def _prepare_result_json(result):
|
||||
def _prepare_result_json(
|
||||
result: data_entry_flow.FlowResult,
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Convert result to JSON."""
|
||||
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
data = result.copy()
|
||||
|
@ -147,12 +159,21 @@ class LoginFlowBaseView(HomeAssistantView):
|
|||
|
||||
requires_auth = False
|
||||
|
||||
def __init__(self, flow_mgr, store_result):
|
||||
def __init__(
|
||||
self,
|
||||
flow_mgr: AuthManagerFlowManager,
|
||||
store_result: StoreResultType,
|
||||
) -> None:
|
||||
"""Initialize the flow manager index view."""
|
||||
self._flow_mgr = flow_mgr
|
||||
self._store_result = store_result
|
||||
|
||||
async def _async_flow_result_to_response(self, request, client_id, result):
|
||||
async def _async_flow_result_to_response(
|
||||
self,
|
||||
request: web.Request,
|
||||
client_id: str,
|
||||
result: data_entry_flow.FlowResult,
|
||||
) -> web.Response:
|
||||
"""Convert the flow result to a response."""
|
||||
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
# @log_invalid_auth does not work here since it returns HTTP 200.
|
||||
|
@ -196,7 +217,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
url = "/auth/login_flow"
|
||||
name = "api:auth:login_flow"
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Do not allow index of flows in progress."""
|
||||
return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
|
@ -211,15 +232,18 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
)
|
||||
)
|
||||
@log_invalid_auth
|
||||
async def post(self, request, data):
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Create a new login flow."""
|
||||
if not await indieauth.verify_redirect_uri(
|
||||
request.app["hass"], data["client_id"], data["redirect_uri"]
|
||||
):
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
client_id: str = data["client_id"]
|
||||
redirect_uri: str = data["redirect_uri"]
|
||||
|
||||
if not await indieauth.verify_redirect_uri(hass, client_id, redirect_uri):
|
||||
return self.json_message(
|
||||
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
handler: tuple[str, ...] | str
|
||||
if isinstance(data["handler"], list):
|
||||
handler = tuple(data["handler"])
|
||||
else:
|
||||
|
@ -227,9 +251,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler,
|
||||
handler, # type: ignore[arg-type]
|
||||
context={
|
||||
"ip_address": ip_address(request.remote),
|
||||
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
|
||||
"credential_only": data.get("type") == "link_user",
|
||||
},
|
||||
)
|
||||
|
@ -240,9 +264,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
|
|||
"Handler does not support init", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
return await self._async_flow_result_to_response(
|
||||
request, data["client_id"], result
|
||||
)
|
||||
return await self._async_flow_result_to_response(request, client_id, result)
|
||||
|
||||
|
||||
class LoginFlowResourceView(LoginFlowBaseView):
|
||||
|
@ -251,13 +273,15 @@ class LoginFlowResourceView(LoginFlowBaseView):
|
|||
url = "/auth/login_flow/{flow_id}"
|
||||
name = "api:auth:login_flow:resource"
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Do not allow getting status of a flow in progress."""
|
||||
return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND)
|
||||
|
||||
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
|
||||
@log_invalid_auth
|
||||
async def post(self, request, data, flow_id):
|
||||
async def post(
|
||||
self, request: web.Request, data: dict[str, Any], flow_id: str
|
||||
) -> web.Response:
|
||||
"""Handle progressing a login flow request."""
|
||||
client_id = data.pop("client_id")
|
||||
|
||||
|
@ -267,7 +291,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
|
|||
try:
|
||||
# do not allow change ip during login flow
|
||||
flow = self._flow_mgr.async_get(flow_id)
|
||||
if flow["context"]["ip_address"] != ip_address(request.remote):
|
||||
if flow["context"]["ip_address"] != ip_address(request.remote): # type: ignore[arg-type]
|
||||
return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST)
|
||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||
except data_entry_flow.UnknownFlow:
|
||||
|
@ -277,7 +301,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
|
|||
|
||||
return await self._async_flow_result_to_response(request, client_id, result)
|
||||
|
||||
async def delete(self, request, flow_id):
|
||||
async def delete(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
"""Cancel a flow in progress."""
|
||||
try:
|
||||
self._flow_mgr.async_abort(flow_id)
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
"""Helpers to setup multi-factor auth module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
import voluptuous_serialize
|
||||
|
@ -7,15 +10,19 @@ import voluptuous_serialize
|
|||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
WS_TYPE_SETUP_MFA = "auth/setup_mfa"
|
||||
SCHEMA_WS_SETUP_MFA = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_SETUP_MFA,
|
||||
vol.Exclusive("mfa_module_id", "module_or_flow_id"): str,
|
||||
vol.Exclusive("flow_id", "module_or_flow_id"): str,
|
||||
vol.Optional("user_input"): object,
|
||||
}
|
||||
SCHEMA_WS_SETUP_MFA = vol.All(
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_SETUP_MFA,
|
||||
vol.Exclusive("mfa_module_id", "module_or_flow_id"): str,
|
||||
vol.Exclusive("flow_id", "module_or_flow_id"): str,
|
||||
vol.Optional("user_input"): object,
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key("mfa_module_id", "flow_id"),
|
||||
)
|
||||
|
||||
WS_TYPE_DEPOSE_MFA = "auth/depose_mfa"
|
||||
|
@ -31,7 +38,13 @@ _LOGGER = logging.getLogger(__name__)
|
|||
class MfaFlowManager(data_entry_flow.FlowManager):
|
||||
"""Manage multi factor authentication flows."""
|
||||
|
||||
async def async_create_flow(self, handler_key, *, context, data):
|
||||
async def async_create_flow( # type: ignore[override]
|
||||
self,
|
||||
handler_key: Any,
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
data: dict[str, Any],
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
"""Create a setup flow. handler is a mfa module."""
|
||||
mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
|
||||
if mfa_module is None:
|
||||
|
@ -40,13 +53,15 @@ class MfaFlowManager(data_entry_flow.FlowManager):
|
|||
user_id = data.pop("user_id")
|
||||
return await mfa_module.async_setup_flow(user_id)
|
||||
|
||||
async def async_finish_flow(self, flow, result):
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Complete an mfs setup flow."""
|
||||
_LOGGER.debug("flow_result: %s", result)
|
||||
return result
|
||||
|
||||
|
||||
async def async_setup(hass):
|
||||
async def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Init mfa setup flow manager."""
|
||||
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
|
||||
|
||||
|
@ -62,13 +77,13 @@ async def async_setup(hass):
|
|||
@callback
|
||||
@websocket_api.ws_require_user(allow_system_user=False)
|
||||
def websocket_setup_mfa(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Return a setup flow for mfa auth module."""
|
||||
|
||||
async def async_setup_flow(msg):
|
||||
async def async_setup_flow(msg: dict[str, Any]) -> None:
|
||||
"""Return a setup flow for mfa auth module."""
|
||||
flow_manager = hass.data[DATA_SETUP_FLOW_MGR]
|
||||
flow_manager: MfaFlowManager = hass.data[DATA_SETUP_FLOW_MGR]
|
||||
|
||||
if (flow_id := msg.get("flow_id")) is not None:
|
||||
result = await flow_manager.async_configure(flow_id, msg.get("user_input"))
|
||||
|
@ -77,9 +92,8 @@ def websocket_setup_mfa(
|
|||
)
|
||||
return
|
||||
|
||||
mfa_module_id = msg.get("mfa_module_id")
|
||||
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
|
||||
if mfa_module is None:
|
||||
mfa_module_id = msg["mfa_module_id"]
|
||||
if hass.auth.get_auth_mfa_module(mfa_module_id) is None:
|
||||
connection.send_message(
|
||||
websocket_api.error_message(
|
||||
msg["id"], "no_module", f"MFA module {mfa_module_id} is not found"
|
||||
|
@ -101,11 +115,11 @@ def websocket_setup_mfa(
|
|||
@callback
|
||||
@websocket_api.ws_require_user(allow_system_user=False)
|
||||
def websocket_depose_mfa(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
|
||||
):
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Remove user from mfa module."""
|
||||
|
||||
async def async_depose(msg):
|
||||
async def async_depose(msg: dict[str, Any]) -> None:
|
||||
"""Remove user from mfa auth module."""
|
||||
mfa_module_id = msg["mfa_module_id"]
|
||||
try:
|
||||
|
@ -127,7 +141,9 @@ def websocket_depose_mfa(
|
|||
hass.async_create_task(async_depose(msg))
|
||||
|
||||
|
||||
def _prepare_result_json(result):
|
||||
def _prepare_result_json(
|
||||
result: data_entry_flow.FlowResult,
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Convert result to JSON."""
|
||||
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
|
||||
data = result.copy()
|
||||
|
|
|
@ -175,7 +175,7 @@ class FlowManager(abc.ABC):
|
|||
)
|
||||
|
||||
@callback
|
||||
def async_get(self, flow_id: str) -> FlowResult | None:
|
||||
def async_get(self, flow_id: str) -> FlowResult:
|
||||
"""Return a flow in progress as a partial FlowResult."""
|
||||
if (flow := self._progress.get(flow_id)) is None:
|
||||
raise UnknownFlow
|
||||
|
|
10
mypy.ini
10
mypy.ini
|
@ -349,6 +349,16 @@ disallow_untyped_defs = true
|
|||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.auth.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.automation.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
|
|
@ -44,7 +44,13 @@ async def test_ws_setup_depose_mfa(hass, hass_ws_client):
|
|||
|
||||
client = await hass_ws_client(hass, access_token)
|
||||
|
||||
await client.send_json({"id": 10, "type": mfa_setup_flow.WS_TYPE_SETUP_MFA})
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 10,
|
||||
"type": mfa_setup_flow.WS_TYPE_SETUP_MFA,
|
||||
"mfa_module_id": "invalid_module",
|
||||
}
|
||||
)
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result["id"] == 10
|
||||
|
|
Loading…
Reference in New Issue