Show site state in Amberelectric config flow (#104702)

pull/109692/head
Myles Eftos 2024-02-05 20:53:42 +11:00 committed by GitHub
parent bfebde0f79
commit 41a256a3ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 220 additions and 50 deletions

View File

@ -3,18 +3,46 @@ from __future__ import annotations
import amberelectric
from amberelectric.api import amber_api
from amberelectric.model.site import Site
from amberelectric.model.site import Site, SiteStatus
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import CONF_API_TOKEN
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
)
from .const import CONF_SITE_ID, CONF_SITE_NAME, CONF_SITE_NMI, DOMAIN
from .const import CONF_SITE_ID, CONF_SITE_NAME, DOMAIN
API_URL = "https://app.amber.com.au/developers"
def generate_site_selector_name(site: Site) -> str:
"""Generate the name to show in the site drop down in the configuration flow."""
if site.status == SiteStatus.CLOSED:
return site.nmi + " (Closed: " + site.closed_on.isoformat() + ")" # type: ignore[no-any-return]
if site.status == SiteStatus.PENDING:
return site.nmi + " (Pending)" # type: ignore[no-any-return]
return site.nmi # type: ignore[no-any-return]
def filter_sites(sites: list[Site]) -> list[Site]:
"""Deduplicates the list of sites."""
filtered: list[Site] = []
filtered_nmi: set[str] = set()
for site in sorted(sites, key=lambda site: site.status.value):
if site.status == SiteStatus.ACTIVE or site.nmi not in filtered_nmi:
filtered.append(site)
filtered_nmi.add(site.nmi)
return filtered
class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow."""
@ -31,7 +59,7 @@ class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
api: amber_api.AmberApi = amber_api.AmberApi.create(configuration)
try:
sites: list[Site] = api.get_sites()
sites: list[Site] = filter_sites(api.get_sites())
if len(sites) == 0:
self._errors[CONF_API_TOKEN] = "no_site"
return None
@ -86,38 +114,31 @@ class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
assert self._sites is not None
assert self._api_token is not None
api_token = self._api_token
if user_input is not None:
site_nmi = user_input[CONF_SITE_NMI]
sites = [site for site in self._sites if site.nmi == site_nmi]
site = sites[0]
site_id = site.id
site_id = user_input[CONF_SITE_ID]
name = user_input.get(CONF_SITE_NAME, site_id)
return self.async_create_entry(
title=name,
data={
CONF_SITE_ID: site_id,
CONF_API_TOKEN: api_token,
CONF_SITE_NMI: site.nmi,
},
data={CONF_SITE_ID: site_id, CONF_API_TOKEN: self._api_token},
)
user_input = {
CONF_API_TOKEN: api_token,
CONF_SITE_NMI: "",
CONF_SITE_NAME: "",
}
return self.async_show_form(
step_id="site",
data_schema=vol.Schema(
{
vol.Required(
CONF_SITE_NMI, default=user_input[CONF_SITE_NMI]
): vol.In([site.nmi for site in self._sites]),
vol.Optional(
CONF_SITE_NAME, default=user_input[CONF_SITE_NAME]
): str,
vol.Required(CONF_SITE_ID): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(
value=site.id,
label=generate_site_selector_name(site),
)
for site in self._sites
],
mode=SelectSelectorMode.DROPDOWN,
)
),
vol.Optional(CONF_SITE_NAME): str,
}
),
errors=self._errors,

View File

@ -6,7 +6,6 @@ from homeassistant.const import Platform
DOMAIN = "amberelectric"
CONF_SITE_NAME = "site_name"
CONF_SITE_ID = "site_id"
CONF_SITE_NMI = "site_nmi"
ATTRIBUTION = "Data provided by Amber Electric"

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/amberelectric",
"iot_class": "cloud_polling",
"loggers": ["amberelectric"],
"requirements": ["amberelectric==1.0.4"]
"requirements": ["amberelectric==1.1.0"]
}

View File

@ -425,7 +425,7 @@ airtouch5py==0.2.8
alpha-vantage==2.3.1
# homeassistant.components.amberelectric
amberelectric==1.0.4
amberelectric==1.1.0
# homeassistant.components.amcrest
amcrest==1.9.8

View File

@ -395,7 +395,7 @@ airtouch4pyapi==1.0.5
airtouch5py==0.2.8
# homeassistant.components.amberelectric
amberelectric==1.0.4
amberelectric==1.1.0
# homeassistant.components.androidtv
androidtv[async]==0.0.73

View File

@ -1,17 +1,18 @@
"""Tests for the Amber config flow."""
from collections.abc import Generator
from datetime import date
from unittest.mock import Mock, patch
from amberelectric import ApiException
from amberelectric.model.site import Site
from amberelectric.model.site import Site, SiteStatus
import pytest
from homeassistant import data_entry_flow
from homeassistant.components.amberelectric.config_flow import filter_sites
from homeassistant.components.amberelectric.const import (
CONF_SITE_ID,
CONF_SITE_NAME,
CONF_SITE_NMI,
DOMAIN,
)
from homeassistant.config_entries import SOURCE_USER
@ -26,29 +27,88 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry")
@pytest.fixture(name="invalid_key_api")
def mock_invalid_key_api() -> Generator:
"""Return an authentication error."""
instance = Mock()
instance.get_sites.side_effect = ApiException(status=403)
with patch("amberelectric.api.AmberApi.create", return_value=instance):
yield instance
with patch("amberelectric.api.AmberApi.create") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=403)
yield mock
@pytest.fixture(name="api_error")
def mock_api_error() -> Generator:
"""Return an authentication error."""
instance = Mock()
instance.get_sites.side_effect = ApiException(status=500)
with patch("amberelectric.api.AmberApi.create", return_value=instance):
yield instance
with patch("amberelectric.api.AmberApi.create") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=500)
yield mock
@pytest.fixture(name="single_site_api")
def mock_single_site_api() -> Generator:
"""Return a single site."""
site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.ACTIVE,
date(2002, 1, 1),
None,
)
with patch("amberelectric.api.AmberApi.create") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@pytest.fixture(name="single_site_pending_api")
def mock_single_site_pending_api() -> Generator:
"""Return a single site."""
site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.PENDING,
None,
None,
)
with patch("amberelectric.api.AmberApi.create") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@pytest.fixture(name="single_site_rejoin_api")
def mock_single_site_rejoin_api() -> Generator:
"""Return a single site."""
instance = Mock()
site = Site("01FG0AGP818PXK0DWHXJRRT2DH", "11111111111", [])
instance.get_sites.return_value = [site]
site_1 = Site(
"01HGD9QB72HB3DWQNJ6SSCGXGV",
"11111111111",
[],
"Jemena",
SiteStatus.CLOSED,
date(2002, 1, 1),
date(2002, 6, 1),
)
site_2 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111",
[],
"Jemena",
SiteStatus.ACTIVE,
date(2003, 1, 1),
None,
)
site_3 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH",
"11111111112",
[],
"Jemena",
SiteStatus.CLOSED,
date(2003, 1, 1),
date(2003, 6, 1),
)
instance.get_sites.return_value = [site_1, site_2, site_3]
with patch("amberelectric.api.AmberApi.create", return_value=instance):
yield instance
@ -64,6 +124,39 @@ def mock_no_site_api() -> Generator:
yield instance
async def test_single_pending_site(
hass: HomeAssistant, single_site_pending_api: Mock
) -> None:
"""Test single site."""
initial_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert initial_result.get("type") == data_entry_flow.FlowResultType.FORM
assert initial_result.get("step_id") == "user"
# Test filling in API key
enter_api_key_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data={CONF_API_TOKEN: API_KEY},
)
assert enter_api_key_result.get("type") == data_entry_flow.FlowResultType.FORM
assert enter_api_key_result.get("step_id") == "site"
select_site_result = await hass.config_entries.flow.async_configure(
enter_api_key_result["flow_id"],
{CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"},
)
# Show available sites
assert select_site_result.get("type") == data_entry_flow.FlowResultType.CREATE_ENTRY
assert select_site_result.get("title") == "Home"
data = select_site_result.get("data")
assert data
assert data[CONF_API_TOKEN] == API_KEY
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None:
"""Test single site."""
initial_result = await hass.config_entries.flow.async_init(
@ -83,7 +176,40 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None:
select_site_result = await hass.config_entries.flow.async_configure(
enter_api_key_result["flow_id"],
{CONF_SITE_NMI: "11111111111", CONF_SITE_NAME: "Home"},
{CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"},
)
# Show available sites
assert select_site_result.get("type") == data_entry_flow.FlowResultType.CREATE_ENTRY
assert select_site_result.get("title") == "Home"
data = select_site_result.get("data")
assert data
assert data[CONF_API_TOKEN] == API_KEY
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_site_rejoin(
hass: HomeAssistant, single_site_rejoin_api: Mock
) -> None:
"""Test single site."""
initial_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert initial_result.get("type") == data_entry_flow.FlowResultType.FORM
assert initial_result.get("step_id") == "user"
# Test filling in API key
enter_api_key_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data={CONF_API_TOKEN: API_KEY},
)
assert enter_api_key_result.get("type") == data_entry_flow.FlowResultType.FORM
assert enter_api_key_result.get("step_id") == "site"
select_site_result = await hass.config_entries.flow.async_configure(
enter_api_key_result["flow_id"],
{CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"},
)
# Show available sites
@ -93,7 +219,6 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None:
assert data
assert data[CONF_API_TOKEN] == API_KEY
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
assert data[CONF_SITE_NMI] == "11111111111"
async def test_no_site(hass: HomeAssistant, no_site_api: Mock) -> None:
@ -148,3 +273,15 @@ async def test_unknown_error(hass: HomeAssistant, api_error: Mock) -> None:
# Goes back to the user step
assert result.get("step_id") == "user"
assert result.get("errors") == {"api_token": "unknown_error"}
async def test_site_deduplication(single_site_rejoin_api: Mock) -> None:
"""Test site deduplication."""
filtered = filter_sites(single_site_rejoin_api.get_sites())
assert len(filtered) == 2
assert (
next(s for s in filtered if s.nmi == "11111111111").status == SiteStatus.ACTIVE
)
assert (
next(s for s in filtered if s.nmi == "11111111112").status == SiteStatus.CLOSED
)

View File

@ -2,13 +2,14 @@
from __future__ import annotations
from collections.abc import Generator
from datetime import date
from unittest.mock import Mock, patch
from amberelectric import ApiException
from amberelectric.model.channel import Channel, ChannelType
from amberelectric.model.current_interval import CurrentInterval
from amberelectric.model.interval import Descriptor, SpikeStatus
from amberelectric.model.site import Site
from amberelectric.model.site import Site, SiteStatus
from dateutil import parser
import pytest
@ -38,23 +39,35 @@ def mock_api_current_price() -> Generator:
general_site = Site(
GENERAL_ONLY_SITE_ID,
"11111111111",
[Channel(identifier="E1", type=ChannelType.GENERAL)],
[Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
)
general_and_controlled_load = Site(
GENERAL_AND_CONTROLLED_SITE_ID,
"11111111112",
[
Channel(identifier="E1", type=ChannelType.GENERAL),
Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD),
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD, tariff="A180"),
],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
)
general_and_feed_in = Site(
GENERAL_AND_FEED_IN_SITE_ID,
"11111111113",
[
Channel(identifier="E1", type=ChannelType.GENERAL),
Channel(identifier="E2", type=ChannelType.FEED_IN),
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.FEED_IN, tariff="A100"),
],
"Jemena",
SiteStatus.ACTIVE,
date(2021, 1, 1),
None,
)
instance.get_sites.return_value = [
general_site,