Add ability to re-auth WattTime (#56582)
* Tests cleanup * Still store the abbreviation * Code review * Remove unused attribute * Add ability to re-auth WattTime * Consolidate logic for entry unique ID * Fix tests * Fix docstringpull/57498/head
parent
6a39119ccc
commit
0c04ca20c6
|
@ -5,7 +5,7 @@ from datetime import timedelta
|
|||
|
||||
from aiowatttime import Client
|
||||
from aiowatttime.emissions import RealTimeEmissionsResponseType
|
||||
from aiowatttime.errors import WattTimeError
|
||||
from aiowatttime.errors import InvalidCredentialsError, WattTimeError
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
|
@ -15,6 +15,7 @@ from homeassistant.const import (
|
|||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
from homeassistant.helpers import aiohttp_client
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
|
||||
|
@ -36,6 +37,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
client = await Client.async_login(
|
||||
entry.data[CONF_USERNAME], entry.data[CONF_PASSWORD], session=session
|
||||
)
|
||||
except InvalidCredentialsError as err:
|
||||
raise ConfigEntryAuthFailed("Invalid username/password") from err
|
||||
except WattTimeError as err:
|
||||
LOGGER.error("Error while authenticating with WattTime: %s", err)
|
||||
return False
|
||||
|
@ -46,6 +49,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
return await client.emissions.async_get_realtime_emissions(
|
||||
entry.data[CONF_LATITUDE], entry.data[CONF_LONGITUDE]
|
||||
)
|
||||
except InvalidCredentialsError as err:
|
||||
raise ConfigEntryAuthFailed("Invalid username/password") from err
|
||||
except WattTimeError as err:
|
||||
raise UpdateFailed(
|
||||
f"Error while requesting data from WattTime: {err}"
|
||||
|
|
|
@ -14,8 +14,10 @@ from homeassistant.const import (
|
|||
CONF_PASSWORD,
|
||||
CONF_USERNAME,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers import aiohttp_client, config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import (
|
||||
CONF_BALANCING_AUTHORITY,
|
||||
|
@ -44,6 +46,12 @@ STEP_LOCATION_DATA_SCHEMA = vol.Schema(
|
|||
}
|
||||
)
|
||||
|
||||
STEP_REAUTH_CONFIRM_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PASSWORD): str,
|
||||
}
|
||||
)
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_USERNAME): str,
|
||||
|
@ -52,6 +60,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||
)
|
||||
|
||||
|
||||
@callback
|
||||
def get_unique_id(data: dict[str, Any]) -> str:
|
||||
"""Get a unique ID from a data payload."""
|
||||
return f"{data[CONF_LATITUDE]}, {data[CONF_LONGITUDE]}"
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for WattTime."""
|
||||
|
||||
|
@ -60,8 +74,49 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
self._client: Client | None = None
|
||||
self._password: str | None = None
|
||||
self._username: str | None = None
|
||||
self._data: dict[str, Any] = {}
|
||||
|
||||
async def _async_validate_credentials(
|
||||
self, username: str, password: str, error_step_id: str, error_schema: vol.Schema
|
||||
):
|
||||
"""Validate input credentials and proceed accordingly."""
|
||||
session = aiohttp_client.async_get_clientsession(self.hass)
|
||||
|
||||
try:
|
||||
self._client = await Client.async_login(username, password, session=session)
|
||||
except InvalidCredentialsError:
|
||||
return self.async_show_form(
|
||||
step_id=error_step_id,
|
||||
data_schema=error_schema,
|
||||
errors={"base": "invalid_auth"},
|
||||
description_placeholders={CONF_USERNAME: username},
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOGGER.exception("Unexpected exception while logging in: %s", err)
|
||||
return self.async_show_form(
|
||||
step_id=error_step_id,
|
||||
data_schema=error_schema,
|
||||
errors={"base": "unknown"},
|
||||
description_placeholders={CONF_USERNAME: username},
|
||||
)
|
||||
|
||||
if CONF_LATITUDE in self._data:
|
||||
# If coordinates already exist at this stage, we're in an existing flow and
|
||||
# should reauth:
|
||||
entry_unique_id = get_unique_id(self._data)
|
||||
if existing_entry := await self.async_set_unique_id(entry_unique_id):
|
||||
self.hass.config_entries.async_update_entry(
|
||||
existing_entry, data=self._data
|
||||
)
|
||||
self.hass.async_create_task(
|
||||
self.hass.config_entries.async_reload(existing_entry.entry_id)
|
||||
)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
# ...otherwise, we're in a new flow:
|
||||
self._data[CONF_USERNAME] = username
|
||||
self._data[CONF_PASSWORD] = password
|
||||
return await self.async_step_location()
|
||||
|
||||
async def async_step_coordinates(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
|
@ -75,7 +130,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
if TYPE_CHECKING:
|
||||
assert self._client
|
||||
|
||||
unique_id = f"{user_input[CONF_LATITUDE]}, {user_input[CONF_LONGITUDE]}"
|
||||
unique_id = get_unique_id(user_input)
|
||||
await self.async_set_unique_id(unique_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
|
@ -100,8 +155,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
return self.async_create_entry(
|
||||
title=unique_id,
|
||||
data={
|
||||
CONF_USERNAME: self._username,
|
||||
CONF_PASSWORD: self._password,
|
||||
CONF_USERNAME: self._data[CONF_USERNAME],
|
||||
CONF_PASSWORD: self._data[CONF_PASSWORD],
|
||||
CONF_LATITUDE: user_input[CONF_LATITUDE],
|
||||
CONF_LONGITUDE: user_input[CONF_LONGITUDE],
|
||||
CONF_BALANCING_AUTHORITY: grid_region["name"],
|
||||
|
@ -127,6 +182,31 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
)
|
||||
return await self.async_step_coordinates()
|
||||
|
||||
async def async_step_reauth(self, config: ConfigType) -> FlowResult:
|
||||
"""Handle configuration by re-auth."""
|
||||
self._data = {**config}
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle re-auth completion."""
|
||||
if not user_input:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
data_schema=STEP_REAUTH_CONFIRM_DATA_SCHEMA,
|
||||
description_placeholders={CONF_USERNAME: self._data[CONF_USERNAME]},
|
||||
)
|
||||
|
||||
self._data[CONF_PASSWORD] = user_input[CONF_PASSWORD]
|
||||
|
||||
return await self._async_validate_credentials(
|
||||
self._data[CONF_USERNAME],
|
||||
self._data[CONF_PASSWORD],
|
||||
"reauth_confirm",
|
||||
STEP_REAUTH_CONFIRM_DATA_SCHEMA,
|
||||
)
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
|
@ -136,28 +216,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||
)
|
||||
|
||||
session = aiohttp_client.async_get_clientsession(self.hass)
|
||||
|
||||
try:
|
||||
self._client = await Client.async_login(
|
||||
user_input[CONF_USERNAME],
|
||||
user_input[CONF_PASSWORD],
|
||||
session=session,
|
||||
)
|
||||
except InvalidCredentialsError:
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=STEP_USER_DATA_SCHEMA,
|
||||
errors={CONF_USERNAME: "invalid_auth"},
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOGGER.exception("Unexpected exception while logging in: %s", err)
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=STEP_USER_DATA_SCHEMA,
|
||||
errors={"base": "unknown"},
|
||||
)
|
||||
|
||||
self._username = user_input[CONF_USERNAME]
|
||||
self._password = user_input[CONF_PASSWORD]
|
||||
return await self.async_step_location()
|
||||
return await self._async_validate_credentials(
|
||||
user_input[CONF_USERNAME],
|
||||
user_input[CONF_PASSWORD],
|
||||
"user",
|
||||
STEP_USER_DATA_SCHEMA,
|
||||
)
|
||||
|
|
|
@ -14,6 +14,13 @@
|
|||
"location_type": "[%key:common::config_flow::data::location%]"
|
||||
}
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "Please re-enter the password for {username}:",
|
||||
"data": {
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
},
|
||||
"user": {
|
||||
"description": "Input your username and password:",
|
||||
"data": {
|
||||
|
@ -28,7 +35,8 @@
|
|||
"unknown_coordinates": "No data for latitude/longitude"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
{
|
||||
"config": {
|
||||
"abort": {
|
||||
"already_configured": "Device is already configured"
|
||||
"already_configured": "Device is already configured",
|
||||
"reauth_successful": "Re-authentication was successful"
|
||||
},
|
||||
"error": {
|
||||
"invalid_auth": "Invalid authentication",
|
||||
|
@ -22,6 +23,13 @@
|
|||
},
|
||||
"description": "Pick a location to monitor:"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"data": {
|
||||
"password": "Password"
|
||||
},
|
||||
"description": "Please re-enter the password for {username}.",
|
||||
"title": "Reauthenticate Integration"
|
||||
},
|
||||
"user": {
|
||||
"data": {
|
||||
"password": "Password",
|
||||
|
|
|
@ -42,9 +42,11 @@ def client_fixture(get_grid_region):
|
|||
@pytest.fixture(name="client_login")
|
||||
def client_login_fixture(client):
|
||||
"""Define a fixture for patching the aiowatttime coroutine to get a client."""
|
||||
with patch("homeassistant.components.watttime.config_flow.Client.async_login") as m:
|
||||
m.return_value = client
|
||||
yield m
|
||||
with patch(
|
||||
"homeassistant.components.watttime.config_flow.Client.async_login"
|
||||
) as mock_client:
|
||||
mock_client.return_value = client
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture(name="get_grid_region")
|
||||
|
@ -162,7 +164,92 @@ async def test_step_coordinates_unknown_error(
|
|||
assert result["errors"] == {"base": "unknown"}
|
||||
|
||||
|
||||
async def test_step_login_coordinates(hass: HomeAssistant, client_login) -> None:
|
||||
async def test_step_reauth(hass: HomeAssistant, client_login) -> None:
|
||||
"""Test a full reauth flow."""
|
||||
MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id="51.528308, -0.3817765",
|
||||
data={
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "password",
|
||||
CONF_LATITUDE: 51.528308,
|
||||
CONF_LONGITUDE: -0.3817765,
|
||||
CONF_BALANCING_AUTHORITY: "Authority 1",
|
||||
CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1",
|
||||
},
|
||||
).add_to_hass(hass)
|
||||
|
||||
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||
with patch(
|
||||
"homeassistant.components.watttime.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_REAUTH},
|
||||
data={
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "password",
|
||||
CONF_LATITUDE: 51.528308,
|
||||
CONF_LONGITUDE: -0.3817765,
|
||||
CONF_BALANCING_AUTHORITY: "Authority 1",
|
||||
CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1",
|
||||
},
|
||||
)
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={CONF_PASSWORD: "password"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
assert len(hass.config_entries.async_entries()) == 1
|
||||
|
||||
|
||||
async def test_step_reauth_invalid_credentials(hass: HomeAssistant) -> None:
|
||||
"""Test that invalid credentials during reauth are handled."""
|
||||
MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id="51.528308, -0.3817765",
|
||||
data={
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "password",
|
||||
CONF_LATITUDE: 51.528308,
|
||||
CONF_LONGITUDE: -0.3817765,
|
||||
CONF_BALANCING_AUTHORITY: "Authority 1",
|
||||
CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1",
|
||||
},
|
||||
).add_to_hass(hass)
|
||||
|
||||
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||
with patch(
|
||||
"homeassistant.components.watttime.config_flow.Client.async_login",
|
||||
AsyncMock(side_effect=InvalidCredentialsError),
|
||||
):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_REAUTH},
|
||||
data={
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "password",
|
||||
CONF_LATITUDE: 51.528308,
|
||||
CONF_LONGITUDE: -0.3817765,
|
||||
CONF_BALANCING_AUTHORITY: "Authority 1",
|
||||
CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1",
|
||||
},
|
||||
)
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={CONF_PASSWORD: "password"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
|
||||
async def test_step_user_coordinates(hass: HomeAssistant, client_login) -> None:
|
||||
"""Test a full login flow (inputting custom coordinates)."""
|
||||
|
||||
with patch(
|
||||
|
@ -241,7 +328,7 @@ async def test_step_user_invalid_credentials(hass: HomeAssistant) -> None:
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"username": "invalid_auth"}
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("get_grid_region", [AsyncMock(side_effect=Exception)])
|
||||
|
|
Loading…
Reference in New Issue