Allow storing arbitrary data in repairs issues ()

pull/76299/head
Erik Montnemery 2022-08-05 13:16:29 +02:00 committed by GitHub
parent b366090175
commit 9aa8838479
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 76 additions and 23 deletions

View File

@ -6,6 +6,7 @@ import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.repairs import ConfirmRepairFlow, RepairsFlow
from homeassistant.core import HomeAssistant
class DemoFixFlow(RepairsFlow):
@ -28,7 +29,11 @@ class DemoFixFlow(RepairsFlow):
return self.async_show_form(step_id="confirm", data_schema=vol.Schema({}))
async def async_create_fix_flow(hass, issue_id):
async def async_create_fix_flow(
hass: HomeAssistant,
issue_id: str,
data: dict[str, str | int | float | None] | None,
) -> RepairsFlow:
"""Create flow."""
if issue_id == "bad_psu":
# The bad_psu issue doesn't have its own flow

View File

@ -36,7 +36,9 @@ class FluNearYouFixFlow(RepairsFlow):
async def async_create_fix_flow(
hass: HomeAssistant, issue_id: str
) -> FluNearYouFixFlow:
hass: HomeAssistant,
issue_id: str,
data: dict[str, str | int | float | None] | None,
) -> RepairsFlow:
"""Create flow."""
return FluNearYouFixFlow()

View File

@ -64,10 +64,14 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
platforms: dict[str, RepairsProtocol] = self.hass.data[DOMAIN]["platforms"]
if handler_key not in platforms:
return ConfirmRepairFlow()
platform = platforms[handler_key]
flow: RepairsFlow = ConfirmRepairFlow()
else:
platform = platforms[handler_key]
flow = await platform.async_create_fix_flow(self.hass, issue_id, issue.data)
return await platform.async_create_fix_flow(self.hass, issue_id)
flow.issue_id = issue_id
flow.data = issue.data
return flow
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
@ -109,6 +113,7 @@ def async_create_issue(
*,
issue_domain: str | None = None,
breaks_in_ha_version: str | None = None,
data: dict[str, str | int | float | None] | None = None,
is_fixable: bool,
is_persistent: bool = False,
learn_more_url: str | None = None,
@ -131,6 +136,7 @@ def async_create_issue(
issue_id,
issue_domain=issue_domain,
breaks_in_ha_version=breaks_in_ha_version,
data=data,
is_fixable=is_fixable,
is_persistent=is_persistent,
learn_more_url=learn_more_url,
@ -146,6 +152,7 @@ def create_issue(
issue_id: str,
*,
breaks_in_ha_version: str | None = None,
data: dict[str, str | int | float | None] | None = None,
is_fixable: bool,
is_persistent: bool = False,
learn_more_url: str | None = None,
@ -162,6 +169,7 @@ def create_issue(
domain,
issue_id,
breaks_in_ha_version=breaks_in_ha_version,
data=data,
is_fixable=is_fixable,
is_persistent=is_persistent,
learn_more_url=learn_more_url,

View File

@ -27,6 +27,7 @@ class IssueEntry:
active: bool
breaks_in_ha_version: str | None
created: datetime
data: dict[str, str | int | float | None] | None
dismissed_version: str | None
domain: str
is_fixable: bool | None
@ -53,6 +54,7 @@ class IssueEntry:
return {
**result,
"breaks_in_ha_version": self.breaks_in_ha_version,
"data": self.data,
"is_fixable": self.is_fixable,
"is_persistent": True,
"issue_domain": self.issue_domain,
@ -106,6 +108,7 @@ class IssueRegistry:
*,
issue_domain: str | None = None,
breaks_in_ha_version: str | None = None,
data: dict[str, str | int | float | None] | None = None,
is_fixable: bool,
is_persistent: bool,
learn_more_url: str | None = None,
@ -120,6 +123,7 @@ class IssueRegistry:
active=True,
breaks_in_ha_version=breaks_in_ha_version,
created=dt_util.utcnow(),
data=data,
dismissed_version=None,
domain=domain,
is_fixable=is_fixable,
@ -142,6 +146,7 @@ class IssueRegistry:
issue,
active=True,
breaks_in_ha_version=breaks_in_ha_version,
data=data,
is_fixable=is_fixable,
is_persistent=is_persistent,
issue_domain=issue_domain,
@ -204,6 +209,7 @@ class IssueRegistry:
active=True,
breaks_in_ha_version=issue["breaks_in_ha_version"],
created=created,
data=issue["data"],
dismissed_version=issue["dismissed_version"],
domain=issue["domain"],
is_fixable=issue["is_fixable"],
@ -220,6 +226,7 @@ class IssueRegistry:
active=False,
breaks_in_ha_version=None,
created=created,
data=None,
dismissed_version=issue["dismissed_version"],
domain=issue["domain"],
is_fixable=None,

View File

@ -19,11 +19,17 @@ class IssueSeverity(StrEnum):
class RepairsFlow(data_entry_flow.FlowHandler):
"""Handle a flow for fixing an issue."""
issue_id: str
data: dict[str, str | int | float | None] | None
class RepairsProtocol(Protocol):
"""Define the format of repairs platforms."""
async def async_create_fix_flow(
self, hass: HomeAssistant, issue_id: str
self,
hass: HomeAssistant,
issue_id: str,
data: dict[str, str | int | float | None] | None,
) -> RepairsFlow:
"""Create a flow to fix a fixable issue."""

View File

@ -64,7 +64,8 @@ def ws_list_issues(
"""Return a list of issues."""
def ws_dict(kv_pairs: list[tuple[Any, Any]]) -> dict[Any, Any]:
result = {k: v for k, v in kv_pairs if k not in ("active", "is_persistent")}
excluded_keys = ("active", "data", "is_persistent")
result = {k: v for k, v in kv_pairs if k not in excluded_keys}
result["ignored"] = result["dismissed_version"] is not None
result["created"] = result["created"].isoformat()
return result

View File

@ -51,6 +51,7 @@ async def test_load_issues(hass: HomeAssistant) -> None:
},
{
"breaks_in_ha_version": "2022.6",
"data": {"entry_id": "123"},
"domain": "test",
"issue_id": "issue_4",
"is_fixable": True,
@ -141,6 +142,7 @@ async def test_load_issues(hass: HomeAssistant) -> None:
active=False,
breaks_in_ha_version=None,
created=issue1.created,
data=None,
dismissed_version=issue1.dismissed_version,
domain=issue1.domain,
is_fixable=None,
@ -157,6 +159,7 @@ async def test_load_issues(hass: HomeAssistant) -> None:
active=False,
breaks_in_ha_version=None,
created=issue2.created,
data=None,
dismissed_version=issue2.dismissed_version,
domain=issue2.domain,
is_fixable=None,
@ -196,6 +199,7 @@ async def test_loading_issues_from_storage(hass: HomeAssistant, hass_storage) ->
{
"breaks_in_ha_version": "2022.6",
"created": "2022-07-19T19:41:13.746514+00:00",
"data": {"entry_id": "123"},
"dismissed_version": None,
"domain": "test",
"issue_domain": "blubb",

View File

@ -37,6 +37,17 @@ DEFAULT_ISSUES = [
async def create_issues(hass, ws_client, issues=None):
"""Create issues."""
def api_issue(issue):
excluded_keys = ("data",)
return dict(
{key: issue[key] for key in issue if key not in excluded_keys},
created=ANY,
dismissed_version=None,
ignored=False,
issue_domain=None,
)
if issues is None:
issues = DEFAULT_ISSUES
@ -46,6 +57,7 @@ async def create_issues(hass, ws_client, issues=None):
issue["domain"],
issue["issue_id"],
breaks_in_ha_version=issue["breaks_in_ha_version"],
data=issue.get("data"),
is_fixable=issue["is_fixable"],
is_persistent=False,
learn_more_url=issue["learn_more_url"],
@ -58,22 +70,17 @@ async def create_issues(hass, ws_client, issues=None):
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] == {
"issues": [
dict(
issue,
created=ANY,
dismissed_version=None,
ignored=False,
issue_domain=None,
)
for issue in issues
]
}
assert msg["result"] == {"issues": [api_issue(issue) for issue in issues]}
return issues
EXPECTED_DATA = {
"issue_1": None,
"issue_2": {"blah": "bleh"},
}
class MockFixFlow(RepairsFlow):
"""Handler for an issue fixing flow."""
@ -82,6 +89,9 @@ class MockFixFlow(RepairsFlow):
) -> data_entry_flow.FlowResult:
"""Handle the first step of a fix flow."""
assert self.issue_id in EXPECTED_DATA
assert self.data == EXPECTED_DATA[self.issue_id]
return await (self.async_step_custom_step())
async def async_step_custom_step(
@ -99,7 +109,10 @@ async def mock_repairs_integration(hass):
"""Mock a repairs integration."""
hass.config.components.add("fake_integration")
def async_create_fix_flow(hass, issue_id):
def async_create_fix_flow(hass, issue_id, data):
assert issue_id in EXPECTED_DATA
assert data == EXPECTED_DATA[issue_id]
return MockFixFlow()
mock_platform(
@ -256,11 +269,18 @@ async def test_fix_issue(
ws_client = await hass_ws_client(hass)
client = await hass_client()
issues = [{**DEFAULT_ISSUES[0], "domain": domain}]
issues = [
{
**DEFAULT_ISSUES[0],
"data": {"blah": "bleh"},
"domain": domain,
"issue_id": "issue_2",
}
]
await create_issues(hass, ws_client, issues=issues)
url = "/api/repairs/issues/fix"
resp = await client.post(url, json={"handler": domain, "issue_id": "issue_1"})
resp = await client.post(url, json={"handler": domain, "issue_id": "issue_2"})
assert resp.status == HTTPStatus.OK
data = await resp.json()