238 lines
7.0 KiB
Python
238 lines
7.0 KiB
Python
"""Config flow to configure the Nextbus integration."""
|
|
|
|
from collections import Counter
|
|
import logging
|
|
|
|
from py_nextbus import NextBusClient
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
|
from homeassistant.const import CONF_NAME, CONF_STOP
|
|
from homeassistant.helpers.selector import (
|
|
SelectOptionDict,
|
|
SelectSelector,
|
|
SelectSelectorConfig,
|
|
SelectSelectorMode,
|
|
)
|
|
|
|
from .const import CONF_AGENCY, CONF_ROUTE, DOMAIN
|
|
from .util import listify
|
|
|
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
def _dict_to_select_selector(options: dict[str, str]) -> SelectSelector:
    """Build a dropdown selector from a ``{value: label}`` mapping.

    Each mapping key becomes an option value and each mapping value its
    label; the options are ordered by label (plain string ordering).
    """
    choices = [
        SelectOptionDict(value=tag, label=title) for tag, title in options.items()
    ]
    choices.sort(key=lambda choice: choice["label"])
    return SelectSelector(
        SelectSelectorConfig(options=choices, mode=SelectSelectorMode.DROPDOWN)
    )
|
|
|
|
|
|
def _get_agency_tags(client: NextBusClient) -> dict[str, str]:
|
|
return {a["tag"]: a["title"] for a in client.get_agency_list()["agency"]}
|
|
|
|
|
|
def _get_route_tags(client: NextBusClient, agency_tag: str) -> dict[str, str]:
    """Return a mapping of route tag to route title for one agency.

    The ``route`` field is passed through ``listify`` before iterating, the
    same normalization this module applies to ``direction`` entries.
    NOTE(review): this assumes the feed returns a bare object rather than a
    one-element list when an agency has a single route; without the
    normalization, iterating that object would yield its keys and
    ``route["tag"]`` would fail.
    """
    routes = listify(client.get_route_list(agency_tag)["route"])
    return {route["tag"]: route["title"] for route in routes}
|
|
|
|
|
|
def _get_stop_tags(
    client: NextBusClient, agency_tag: str, route_tag: str
) -> dict[str, str]:
    """Return a mapping of stop tag to display label for one route.

    The label is normally the stop title. When several stops on the route
    share a title, the label becomes ``"<title> (<direction name>)"``. The
    stop tag is used in place of the direction name when no direction lists
    that stop.

    Both the route's ``stop`` field and each direction's ``stop`` field are
    normalized with ``listify``, as ``direction`` already was.
    NOTE(review): this assumes a single stop may arrive as a bare object
    rather than a one-element list — confirm against the feed.
    """
    route = client.get_route_config(route_tag, agency_tag)["route"]
    tags = {stop["tag"]: stop["title"] for stop in listify(route["stop"])}
    title_counts = Counter(tags.values())

    # Map each stop tag to the name of a direction that lists it. If a stop
    # appears in several directions, the last one iterated wins.
    stop_directions: dict[str, str] = {}
    for direction in listify(route["direction"]):
        for stop in listify(direction["stop"]):
            stop_directions[stop["tag"]] = direction["name"]

    # Append directions for stops with shared titles. Only values of existing
    # keys are replaced, so updating `tags` while iterating it is safe.
    for tag, title in tags.items():
        if title_counts[title] > 1:
            tags[tag] = f"{title} ({stop_directions.get(tag, tag)})"

    return tags
|
|
|
|
|
|
def _validate_import(
    client: NextBusClient, agency_tag: str, route_tag: str, stop_tag: str
) -> str | tuple[str, str, str]:
    """Check that an imported agency/route/stop combination exists.

    Lookups run in order (agency, then route, then stop) and stop at the
    first failure. A tag counts as missing if it is absent or maps to an
    empty title.

    Returns:
        ``(agency_title, route_title, stop_label)`` when all three resolve,
        otherwise one of ``"invalid_agency"``, ``"invalid_route"`` or
        ``"invalid_stop"`` naming the first lookup that failed.
    """
    if not (agency := _get_agency_tags(client).get(agency_tag)):
        return "invalid_agency"

    if not (route := _get_route_tags(client, agency_tag).get(route_tag)):
        return "invalid_route"

    if not (stop := _get_stop_tags(client, agency_tag, route_tag).get(stop_tag)):
        return "invalid_stop"

    return agency, route, stop
|
|
|
|
|
|
def _unique_id_from_data(data: dict[str, str]) -> str:
    """Build the config entry unique ID as ``<agency>_<route>_<stop>``."""
    parts = (data[CONF_AGENCY], data[CONF_ROUTE], data[CONF_STOP])
    return "_".join(f"{part}" for part in parts)
|
|
|
|
|
|
class NextBusFlowHandler(ConfigFlow, domain=DOMAIN):
    """Handle Nextbus configuration.

    The interactive flow has three steps (agency, route, stop). Each step
    shows a dropdown populated from the NextBus API. The import step instead
    validates an agency/route/stop combination supplied in ``config_input``.
    """

    VERSION = 1

    # Tag -> display-title maps, fetched when each form is rendered. The
    # final step reads them back to build the entry title.
    _agency_tags: dict[str, str]
    _route_tags: dict[str, str]
    _stop_tags: dict[str, str]

    def __init__(self) -> None:
        """Initialize NextBus config flow."""
        # Selections accumulated across the agency/route/stop steps.
        self.data: dict[str, str] = {}
        self._client = NextBusClient(output_format="json")
        # Routine event; debug level keeps it out of default logs.
        _LOGGER.debug("Init new config flow")

    async def async_step_import(self, config_input: dict[str, str]) -> ConfigFlowResult:
        """Handle import of config.

        Aborts if an entry with the same agency/route/stop unique ID already
        exists. Otherwise validates the tags against the API and aborts with
        the validator's reason (``invalid_agency``, ``invalid_route`` or
        ``invalid_stop``) on failure. On success, creates an entry titled
        with the resolved titles.
        """
        agency_tag = config_input[CONF_AGENCY]
        route_tag = config_input[CONF_ROUTE]
        stop_tag = config_input[CONF_STOP]

        data = {
            CONF_AGENCY: agency_tag,
            CONF_ROUTE: route_tag,
            CONF_STOP: stop_tag,
            CONF_NAME: config_input.get(CONF_NAME, f"{agency_tag} {route_tag}"),
        }

        # Check for an existing entry before validating, so an import for an
        # already-configured stop returns without calling the NextBus API.
        await self.async_set_unique_id(_unique_id_from_data(data))
        self._abort_if_unique_id_configured()

        # The lookups are blocking client calls, so run them in the executor.
        validation_result = await self.hass.async_add_executor_job(
            _validate_import, self._client, agency_tag, route_tag, stop_tag
        )
        if isinstance(validation_result, str):
            # On failure the validator returns the abort reason string.
            return self.async_abort(reason=validation_result)

        # On success it returns the (agency, route, stop) titles.
        return self.async_create_entry(
            title=" ".join(validation_result),
            data=data,
        )

    async def async_step_user(
        self,
        user_input: dict[str, str] | None = None,
    ) -> ConfigFlowResult:
        """Handle a flow initiated by the user by starting at agency selection."""
        return await self.async_step_agency(user_input)

    async def async_step_agency(
        self,
        user_input: dict[str, str] | None = None,
    ) -> ConfigFlowResult:
        """Select agency.

        Without input, fetch the agency list and show the dropdown. With
        input, store the chosen agency tag and continue to route selection.
        """
        if user_input is not None:
            self.data[CONF_AGENCY] = user_input[CONF_AGENCY]
            return await self.async_step_route()

        self._agency_tags = await self.hass.async_add_executor_job(
            _get_agency_tags, self._client
        )

        return self.async_show_form(
            step_id="agency",
            data_schema=vol.Schema(
                {
                    vol.Required(CONF_AGENCY): _dict_to_select_selector(
                        self._agency_tags
                    ),
                }
            ),
        )

    async def async_step_route(
        self,
        user_input: dict[str, str] | None = None,
    ) -> ConfigFlowResult:
        """Select route.

        Without input, fetch the routes of the chosen agency and show the
        dropdown. With input, store the chosen route tag and continue to stop
        selection.
        """
        if user_input is not None:
            self.data[CONF_ROUTE] = user_input[CONF_ROUTE]
            return await self.async_step_stop()

        self._route_tags = await self.hass.async_add_executor_job(
            _get_route_tags, self._client, self.data[CONF_AGENCY]
        )

        return self.async_show_form(
            step_id="route",
            data_schema=vol.Schema(
                {
                    vol.Required(CONF_ROUTE): _dict_to_select_selector(
                        self._route_tags
                    ),
                }
            ),
        )

    async def async_step_stop(
        self,
        user_input: dict[str, str] | None = None,
    ) -> ConfigFlowResult:
        """Select stop.

        Without input, fetch the stops of the chosen route and show the
        dropdown. With input, abort if the agency/route/stop combination is
        already configured, otherwise create the entry. The entry title is
        built from the titles cached when the earlier forms were shown.
        """
        if user_input is not None:
            self.data[CONF_STOP] = user_input[CONF_STOP]

            await self.async_set_unique_id(_unique_id_from_data(self.data))
            self._abort_if_unique_id_configured()

            agency_tag = self.data[CONF_AGENCY]
            route_tag = self.data[CONF_ROUTE]
            stop_tag = self.data[CONF_STOP]

            agency_name = self._agency_tags[agency_tag]
            route_name = self._route_tags[route_tag]
            stop_name = self._stop_tags[stop_tag]

            return self.async_create_entry(
                title=f"{agency_name} {route_name} {stop_name}",
                data=self.data,
            )

        self._stop_tags = await self.hass.async_add_executor_job(
            _get_stop_tags,
            self._client,
            self.data[CONF_AGENCY],
            self.data[CONF_ROUTE],
        )

        return self.async_show_form(
            step_id="stop",
            data_schema=vol.Schema(
                {
                    vol.Required(CONF_STOP): _dict_to_select_selector(self._stop_tags),
                }
            ),
        )
|