Add strict typing for auth (#75586)

pull/76875/head
Marc Mueller 2022-08-16 16:10:37 +02:00 committed by GitHub
parent 735dec8dde
commit 563ec67d39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 174 additions and 85 deletions

View File

@ -59,6 +59,7 @@ homeassistant.components.ampio.*
homeassistant.components.anthemav.*
homeassistant.components.aseko_pool_live.*
homeassistant.components.asuswrt.*
homeassistant.components.auth.*
homeassistant.components.automation.*
homeassistant.components.backup.*
homeassistant.components.baf.*

View File

@ -124,15 +124,22 @@ as part of a config flow.
"""
from __future__ import annotations
from datetime import timedelta
from collections.abc import Callable
from datetime import datetime, timedelta
from http import HTTPStatus
from typing import Any, Optional, cast
import uuid
from aiohttp import web
from multidict import MultiDictProxy
import voluptuous as vol
from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
from homeassistant.auth.models import (
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
Credentials,
User,
)
from homeassistant.components import websocket_api
from homeassistant.components.http.auth import (
async_sign_path,
@ -151,11 +158,16 @@ from . import indieauth, login_flow, mfa_setup_flow
DOMAIN = "auth"
StoreResultType = Callable[[str, Credentials], str]
RetrieveResultType = Callable[[str, str], Optional[Credentials]]
@bind_hass
def create_auth_code(hass, client_id: str, credential: Credentials) -> str:
def create_auth_code(
hass: HomeAssistant, client_id: str, credential: Credentials
) -> str:
"""Create an authorization code to fetch tokens."""
return hass.data[DOMAIN](client_id, credential)
return cast(StoreResultType, hass.data[DOMAIN])(client_id, credential)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -188,15 +200,15 @@ class TokenView(HomeAssistantView):
requires_auth = False
cors_allowed = True
def __init__(self, retrieve_auth):
def __init__(self, retrieve_auth: RetrieveResultType) -> None:
"""Initialize the token view."""
self._retrieve_auth = retrieve_auth
@log_invalid_auth
async def post(self, request):
async def post(self, request: web.Request) -> web.Response:
"""Grant a token."""
hass = request.app["hass"]
data = await request.post()
hass: HomeAssistant = request.app["hass"]
data = cast(MultiDictProxy[str], await request.post())
grant_type = data.get("grant_type")
@ -217,7 +229,11 @@ class TokenView(HomeAssistantView):
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
)
async def _async_handle_revoke_token(self, hass, data):
async def _async_handle_revoke_token(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
) -> web.Response:
"""Handle revoke token request."""
# OAuth 2.0 Token Revocation [RFC7009]
@ -235,7 +251,12 @@ class TokenView(HomeAssistantView):
await hass.auth.async_remove_refresh_token(refresh_token)
return web.Response(status=HTTPStatus.OK)
async def _async_handle_auth_code(self, hass, data, remote_addr):
async def _async_handle_auth_code(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
) -> web.Response:
"""Handle authorization code request."""
client_id = data.get("client_id")
if client_id is None or not indieauth.verify_client_id(client_id):
@ -298,7 +319,12 @@ class TokenView(HomeAssistantView):
},
)
async def _async_handle_refresh_token(self, hass, data, remote_addr):
async def _async_handle_refresh_token(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
) -> web.Response:
"""Handle authorization code request."""
client_id = data.get("client_id")
if client_id is not None and not indieauth.verify_client_id(client_id):
@ -366,15 +392,15 @@ class LinkUserView(HomeAssistantView):
url = "/auth/link_user"
name = "api:auth:link_user"
def __init__(self, retrieve_credentials):
def __init__(self, retrieve_credentials: RetrieveResultType) -> None:
"""Initialize the link user view."""
self._retrieve_credentials = retrieve_credentials
@RequestDataValidator(vol.Schema({"code": str, "client_id": str}))
async def post(self, request, data):
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Link a user."""
hass = request.app["hass"]
user = request["hass_user"]
hass: HomeAssistant = request.app["hass"]
user: User = request["hass_user"]
credentials = self._retrieve_credentials(data["client_id"], data["code"])
@ -394,12 +420,12 @@ class LinkUserView(HomeAssistantView):
@callback
def _create_auth_code_store():
def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
"""Create an in memory store."""
temp_results = {}
temp_results: dict[tuple[str, str], tuple[datetime, Credentials]] = {}
@callback
def store_result(client_id, result):
def store_result(client_id: str, result: Credentials) -> str:
"""Store flow result and return a code to retrieve it."""
if not isinstance(result, Credentials):
raise ValueError("result has to be a Credentials instance")
@ -412,7 +438,7 @@ def _create_auth_code_store():
return code
@callback
def retrieve_result(client_id, code):
def retrieve_result(client_id: str, code: str) -> Credentials | None:
"""Retrieve flow result."""
key = (client_id, code)
@ -437,8 +463,8 @@ def _create_auth_code_store():
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_current_user(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Return the current user."""
user = connection.user
enabled_modules = await hass.auth.async_get_enabled_mfa(user)
@ -482,8 +508,8 @@ async def websocket_current_user(
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_create_long_lived_access_token(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Create or a long-lived access token."""
refresh_token = await hass.auth.async_create_refresh_token(
connection.user,
@ -506,12 +532,12 @@ async def websocket_create_long_lived_access_token(
@websocket_api.ws_require_user()
@callback
def websocket_refresh_tokens(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Return metadata of users refresh tokens."""
current_id = connection.refresh_token_id
tokens = []
tokens: list[dict[str, Any]] = []
for refresh in connection.user.refresh_tokens.values():
if refresh.credential:
auth_provider_type = refresh.credential.auth_provider_type
@ -545,8 +571,8 @@ def websocket_refresh_tokens(
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_delete_refresh_token(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle a delete refresh token request."""
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
@ -569,8 +595,8 @@ async def websocket_delete_refresh_token(
@websocket_api.ws_require_user()
@callback
def websocket_sign_path(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle a sign path request."""
connection.send_message(
websocket_api.result_message(

View File

@ -1,18 +1,24 @@
"""Helpers to resolve client ID/secret."""
from __future__ import annotations
import asyncio
from html.parser import HTMLParser
from ipaddress import ip_address
import logging
from urllib.parse import urljoin, urlparse
from urllib.parse import ParseResult, urljoin, urlparse
import aiohttp
import aiohttp.client_exceptions
from homeassistant.core import HomeAssistant
from homeassistant.util.network import is_local
_LOGGER = logging.getLogger(__name__)
async def verify_redirect_uri(hass, client_id, redirect_uri):
async def verify_redirect_uri(
hass: HomeAssistant, client_id: str, redirect_uri: str
) -> bool:
"""Verify that the client and redirect uri match."""
try:
client_id_parts = _parse_client_id(client_id)
@ -47,24 +53,24 @@ async def verify_redirect_uri(hass, client_id, redirect_uri):
class LinkTagParser(HTMLParser):
"""Parser to find link tags."""
def __init__(self, rel):
def __init__(self, rel: str) -> None:
"""Initialize a link tag parser."""
super().__init__()
self.rel = rel
self.found = []
self.found: list[str | None] = []
def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
"""Handle finding a start tag."""
if tag != "link":
return
attrs = dict(attrs)
attributes: dict[str, str | None] = dict(attrs)
if attrs.get("rel") == self.rel:
self.found.append(attrs.get("href"))
if attributes.get("rel") == self.rel:
self.found.append(attributes.get("href"))
async def fetch_redirect_uris(hass, url):
async def fetch_redirect_uris(hass: HomeAssistant, url: str) -> list[str]:
"""Find link tag with redirect_uri values.
IndieAuth 4.2.2
@ -108,7 +114,7 @@ async def fetch_redirect_uris(hass, url):
return [urljoin(url, found) for found in parser.found]
def verify_client_id(client_id):
def verify_client_id(client_id: str) -> bool:
"""Verify that the client id is valid."""
try:
_parse_client_id(client_id)
@ -117,7 +123,7 @@ def verify_client_id(client_id):
return False
def _parse_url(url):
def _parse_url(url: str) -> ParseResult:
"""Parse a url in parts and canonicalize according to IndieAuth."""
parts = urlparse(url)
@ -134,7 +140,7 @@ def _parse_url(url):
return parts
def _parse_client_id(client_id):
def _parse_client_id(client_id: str) -> ParseResult:
"""Test if client id is a valid URL according to IndieAuth section 3.2.
https://indieauth.spec.indieweb.org/#client-identifier

View File

@ -66,14 +66,19 @@ associate with an credential if "type" set to "link_user" in
"version": 1
}
"""
from __future__ import annotations
from collections.abc import Callable
from http import HTTPStatus
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
from aiohttp import web
import voluptuous as vol
import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager
from homeassistant.auth.models import Credentials
from homeassistant.components import onboarding
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
@ -88,8 +93,13 @@ from homeassistant.core import HomeAssistant
from . import indieauth
if TYPE_CHECKING:
from . import StoreResultType
async def async_setup(hass, store_result):
async def async_setup(
hass: HomeAssistant, store_result: Callable[[str, Credentials], str]
) -> None:
"""Component to allow users to login."""
hass.http.register_view(AuthProvidersView)
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow, store_result))
@ -103,9 +113,9 @@ class AuthProvidersView(HomeAssistantView):
name = "api:auth:providers"
requires_auth = False
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Get available auth providers."""
hass = request.app["hass"]
hass: HomeAssistant = request.app["hass"]
if not onboarding.async_is_user_onboarded(hass):
return self.json_message(
message="Onboarding not finished",
@ -121,7 +131,9 @@ class AuthProvidersView(HomeAssistantView):
)
def _prepare_result_json(result):
def _prepare_result_json(
result: data_entry_flow.FlowResult,
) -> data_entry_flow.FlowResult:
"""Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy()
@ -147,12 +159,21 @@ class LoginFlowBaseView(HomeAssistantView):
requires_auth = False
def __init__(self, flow_mgr, store_result):
def __init__(
self,
flow_mgr: AuthManagerFlowManager,
store_result: StoreResultType,
) -> None:
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr
self._store_result = store_result
async def _async_flow_result_to_response(self, request, client_id, result):
async def _async_flow_result_to_response(
self,
request: web.Request,
client_id: str,
result: data_entry_flow.FlowResult,
) -> web.Response:
"""Convert the flow result to a response."""
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# @log_invalid_auth does not work here since it returns HTTP 200.
@ -196,7 +217,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
url = "/auth/login_flow"
name = "api:auth:login_flow"
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Do not allow index of flows in progress."""
return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
@ -211,15 +232,18 @@ class LoginFlowIndexView(LoginFlowBaseView):
)
)
@log_invalid_auth
async def post(self, request, data):
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Create a new login flow."""
if not await indieauth.verify_redirect_uri(
request.app["hass"], data["client_id"], data["redirect_uri"]
):
hass: HomeAssistant = request.app["hass"]
client_id: str = data["client_id"]
redirect_uri: str = data["redirect_uri"]
if not await indieauth.verify_redirect_uri(hass, client_id, redirect_uri):
return self.json_message(
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
)
handler: tuple[str, ...] | str
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
else:
@ -227,9 +251,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
try:
result = await self._flow_mgr.async_init(
handler,
handler, # type: ignore[arg-type]
context={
"ip_address": ip_address(request.remote),
"ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user",
},
)
@ -240,9 +264,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
"Handler does not support init", HTTPStatus.BAD_REQUEST
)
return await self._async_flow_result_to_response(
request, data["client_id"], result
)
return await self._async_flow_result_to_response(request, client_id, result)
class LoginFlowResourceView(LoginFlowBaseView):
@ -251,13 +273,15 @@ class LoginFlowResourceView(LoginFlowBaseView):
url = "/auth/login_flow/{flow_id}"
name = "api:auth:login_flow:resource"
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Do not allow getting status of a flow in progress."""
return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND)
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth
async def post(self, request, data, flow_id):
async def post(
self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response:
"""Handle progressing a login flow request."""
client_id = data.pop("client_id")
@ -267,7 +291,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
try:
# do not allow change ip during login flow
flow = self._flow_mgr.async_get(flow_id)
if flow["context"]["ip_address"] != ip_address(request.remote):
if flow["context"]["ip_address"] != ip_address(request.remote): # type: ignore[arg-type]
return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST)
result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow:
@ -277,7 +301,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
return await self._async_flow_result_to_response(request, client_id, result)
async def delete(self, request, flow_id):
async def delete(self, request: web.Request, flow_id: str) -> web.Response:
"""Cancel a flow in progress."""
try:
self._flow_mgr.async_abort(flow_id)

View File

@ -1,5 +1,8 @@
"""Helpers to setup multi-factor auth module."""
from __future__ import annotations
import logging
from typing import Any
import voluptuous as vol
import voluptuous_serialize
@ -7,15 +10,19 @@ import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
WS_TYPE_SETUP_MFA = "auth/setup_mfa"
SCHEMA_WS_SETUP_MFA = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
SCHEMA_WS_SETUP_MFA = vol.All(
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): WS_TYPE_SETUP_MFA,
vol.Exclusive("mfa_module_id", "module_or_flow_id"): str,
vol.Exclusive("flow_id", "module_or_flow_id"): str,
vol.Optional("user_input"): object,
}
),
cv.has_at_least_one_key("mfa_module_id", "flow_id"),
)
WS_TYPE_DEPOSE_MFA = "auth/depose_mfa"
@ -31,7 +38,13 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows."""
async def async_create_flow(self, handler_key, *, context, data):
async def async_create_flow( # type: ignore[override]
self,
handler_key: Any,
*,
context: dict[str, Any],
data: dict[str, Any],
) -> data_entry_flow.FlowHandler:
"""Create a setup flow. handler is a mfa module."""
mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
if mfa_module is None:
@ -40,13 +53,15 @@ class MfaFlowManager(data_entry_flow.FlowManager):
user_id = data.pop("user_id")
return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow(self, flow, result):
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result)
return result
async def async_setup(hass):
async def async_setup(hass: HomeAssistant) -> None:
"""Init mfa setup flow manager."""
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
@ -62,13 +77,13 @@ async def async_setup(hass):
@callback
@websocket_api.ws_require_user(allow_system_user=False)
def websocket_setup_mfa(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Return a setup flow for mfa auth module."""
async def async_setup_flow(msg):
async def async_setup_flow(msg: dict[str, Any]) -> None:
"""Return a setup flow for mfa auth module."""
flow_manager = hass.data[DATA_SETUP_FLOW_MGR]
flow_manager: MfaFlowManager = hass.data[DATA_SETUP_FLOW_MGR]
if (flow_id := msg.get("flow_id")) is not None:
result = await flow_manager.async_configure(flow_id, msg.get("user_input"))
@ -77,9 +92,8 @@ def websocket_setup_mfa(
)
return
mfa_module_id = msg.get("mfa_module_id")
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id)
if mfa_module is None:
mfa_module_id = msg["mfa_module_id"]
if hass.auth.get_auth_mfa_module(mfa_module_id) is None:
connection.send_message(
websocket_api.error_message(
msg["id"], "no_module", f"MFA module {mfa_module_id} is not found"
@ -101,11 +115,11 @@ def websocket_setup_mfa(
@callback
@websocket_api.ws_require_user(allow_system_user=False)
def websocket_depose_mfa(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg
):
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Remove user from mfa module."""
async def async_depose(msg):
async def async_depose(msg: dict[str, Any]) -> None:
"""Remove user from mfa auth module."""
mfa_module_id = msg["mfa_module_id"]
try:
@ -127,7 +141,9 @@ def websocket_depose_mfa(
hass.async_create_task(async_depose(msg))
def _prepare_result_json(result):
def _prepare_result_json(
result: data_entry_flow.FlowResult,
) -> data_entry_flow.FlowResult:
"""Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy()

View File

@ -175,7 +175,7 @@ class FlowManager(abc.ABC):
)
@callback
def async_get(self, flow_id: str) -> FlowResult | None:
def async_get(self, flow_id: str) -> FlowResult:
"""Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow

View File

@ -349,6 +349,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.auth.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.automation.*]
check_untyped_defs = true
disallow_incomplete_defs = true

View File

@ -44,7 +44,13 @@ async def test_ws_setup_depose_mfa(hass, hass_ws_client):
client = await hass_ws_client(hass, access_token)
await client.send_json({"id": 10, "type": mfa_setup_flow.WS_TYPE_SETUP_MFA})
await client.send_json(
{
"id": 10,
"type": mfa_setup_flow.WS_TYPE_SETUP_MFA,
"mfa_module_id": "invalid_module",
}
)
result = await client.receive_json()
assert result["id"] == 10