Default to recorder db for SQL integration (#85436)

Co-authored-by: J. Nick Koston <nick@koston.org>
pull/89169/head
G Johansson 2023-03-14 04:41:32 +01:00 committed by GitHub
parent 2f4e9c8ef3
commit afa58b80bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 120 additions and 54 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL
from homeassistant.components.recorder import CONF_DB_URL, get_instance
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
DEVICE_CLASSES_SCHEMA,
@ -53,6 +53,18 @@ CONFIG_SCHEMA = vol.Schema(
)
def remove_configured_db_url_if_not_needed(
hass: HomeAssistant, entry: ConfigEntry
) -> None:
"""Remove db url from config if it matches recorder database."""
hass.config_entries.async_update_entry(
entry,
options={
key: value for key, value in entry.options.items() if key != CONF_DB_URL
},
)
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Update listener for options."""
await hass.config_entries.async_reload(entry.entry_id)
@ -73,6 +85,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up SQL from a config entry."""
if entry.options.get(CONF_DB_URL) == get_instance(hass).db_url:
remove_configured_db_url_if_not_needed(hass, entry)
entry.async_on_unload(entry.add_update_listener(async_update_listener))
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

View File

@ -11,13 +11,14 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.components.recorder import CONF_DB_URL
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import selector
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import resolve_db_url
_LOGGER = logging.getLogger(__name__)
@ -85,34 +86,37 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
) -> FlowResult:
"""Handle the user step."""
errors = {}
db_url_default = DEFAULT_URL.format(
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
)
if user_input is not None:
db_url = user_input.get(CONF_DB_URL, db_url_default)
db_url = user_input.get(CONF_DB_URL)
query = user_input[CONF_QUERY]
column = user_input[CONF_COLUMN_NAME]
uom = user_input.get(CONF_UNIT_OF_MEASUREMENT)
value_template = user_input.get(CONF_VALUE_TEMPLATE)
name = user_input[CONF_NAME]
db_url_for_validation = None
try:
validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job(
validate_query, db_url, query, column
validate_query, db_url_for_validation, query, column
)
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except ValueError:
errors["query"] = "query_invalid"
add_db_url = (
{CONF_DB_URL: db_url} if db_url == db_url_for_validation else {}
)
if not errors:
return self.async_create_entry(
title=name,
data={},
options={
CONF_DB_URL: db_url,
**add_db_url,
CONF_QUERY: query,
CONF_COLUMN_NAME: column,
CONF_UNIT_OF_MEASUREMENT: uom,
@ -140,32 +144,32 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
) -> FlowResult:
"""Manage SQL options."""
errors = {}
db_url_default = DEFAULT_URL.format(
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
)
if user_input is not None:
db_url = user_input.get(CONF_DB_URL, db_url_default)
db_url = user_input.get(CONF_DB_URL)
query = user_input[CONF_QUERY]
column = user_input[CONF_COLUMN_NAME]
name = self.entry.options.get(CONF_NAME, self.entry.title)
try:
validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job(
validate_query, db_url, query, column
validate_query, db_url_for_validation, query, column
)
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except ValueError:
errors["query"] = "query_invalid"
else:
new_user_input = user_input
if new_user_input.get(CONF_DB_URL) and db_url == db_url_for_validation:
new_user_input.pop(CONF_DB_URL)
return self.async_create_entry(
title="",
data={
CONF_NAME: name,
CONF_DB_URL: db_url,
**user_input,
**new_user_input,
},
)
@ -176,7 +180,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
vol.Optional(
CONF_DB_URL,
description={
"suggested_value": self.entry.options[CONF_DB_URL]
"suggested_value": self.entry.options.get(CONF_DB_URL)
},
): selector.TextSelector(),
vol.Required(

View File

@ -10,7 +10,7 @@ from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
from homeassistant.components.recorder import CONF_DB_URL
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
SensorDeviceClass,
@ -34,6 +34,7 @@ from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN
from .util import resolve_db_url
_LOGGER = logging.getLogger(__name__)
@ -59,7 +60,7 @@ async def async_setup_platform(
value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE)
column_name: str = conf[CONF_COLUMN_NAME]
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
db_url: str | None = conf.get(CONF_DB_URL)
db_url: str = resolve_db_url(hass, conf.get(CONF_DB_URL))
device_class: SensorDeviceClass | None = conf.get(CONF_DEVICE_CLASS)
state_class: SensorStateClass | None = conf.get(CONF_STATE_CLASS)
@ -87,7 +88,7 @@ async def async_setup_entry(
) -> None:
"""Set up the SQL sensor from config entry."""
db_url: str = entry.options[CONF_DB_URL]
db_url: str = resolve_db_url(hass, entry.options.get(CONF_DB_URL))
name: str = entry.options[CONF_NAME]
query_str: str = entry.options[CONF_QUERY]
unit: str | None = entry.options.get(CONF_UNIT_OF_MEASUREMENT)
@ -128,7 +129,7 @@ async def async_setup_sensor(
unit: str | None,
value_template: Template | None,
unique_id: str | None,
db_url: str | None,
db_url: str,
yaml: bool,
device_class: SensorDeviceClass | None,
state_class: SensorStateClass | None,
@ -136,16 +137,12 @@ async def async_setup_sensor(
) -> None:
"""Set up the SQL sensor."""
if not db_url:
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
sess: Session | None = None
try:
engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
# Run a dummy query just to test the db_url
sess = sessmaker()
sess: Session = sessmaker()
sess.execute(sqlalchemy.text("SELECT 1;"))
except SQLAlchemyError as err:

View File

@ -0,0 +1,12 @@
"""Utils for sql."""
from __future__ import annotations
from homeassistant.components.recorder import get_instance
from homeassistant.core import HomeAssistant
def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str:
"""Return the db_url provided if not empty, otherwise return the recorder db_url."""
if db_url and not db_url.isspace():
return db_url
return get_instance(hass).db_url

View File

@ -23,7 +23,6 @@ from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
ENTRY_CONFIG = {
CONF_DB_URL: "sqlite://",
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
@ -31,7 +30,6 @@ ENTRY_CONFIG = {
}
ENTRY_CONFIG_INVALID_QUERY = {
CONF_DB_URL: "sqlite://",
CONF_NAME: "Get Value",
CONF_QUERY: "UPDATE 5 as value",
CONF_COLUMN_NAME: "size",
@ -39,14 +37,12 @@ ENTRY_CONFIG_INVALID_QUERY = {
}
ENTRY_CONFIG_INVALID_QUERY_OPT = {
CONF_DB_URL: "sqlite://",
CONF_QUERY: "UPDATE 5 as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
}
ENTRY_CONFIG_NO_RESULTS = {
CONF_DB_URL: "sqlite://",
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT kalle as value from no_table;",
CONF_COLUMN_NAME: "value",
@ -69,7 +65,6 @@ YAML_CONFIG = {
YAML_CONFIG_INVALID = {
"sql": {
CONF_DB_URL: "sqlite://",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",

View File

@ -6,7 +6,7 @@ from unittest.mock import patch
from sqlalchemy.exc import SQLAlchemyError
from homeassistant import config_entries
from homeassistant.components.recorder import DEFAULT_DB_FILE, DEFAULT_URL, Recorder
from homeassistant.components.recorder import Recorder
from homeassistant.components.sql.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@ -43,7 +43,6 @@ async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None:
assert result2["type"] == FlowResultType.CREATE_ENTRY
assert result2["title"] == "Get Value"
assert result2["options"] == {
"db_url": "sqlite://",
"name": "Get Value",
"query": "SELECT 5 as value",
"column": "value",
@ -113,7 +112,6 @@ async def test_flow_fails_invalid_query(
assert result5["type"] == FlowResultType.CREATE_ENTRY
assert result5["title"] == "Get Value"
assert result5["options"] == {
"db_url": "sqlite://",
"name": "Get Value",
"query": "SELECT 5 as value",
"column": "value",
@ -163,7 +161,6 @@ async def test_options_flow(recorder_mock: Recorder, hass: HomeAssistant) -> Non
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {
"name": "Get Value",
"db_url": "sqlite://",
"query": "SELECT 5 as size",
"column": "size",
"unit_of_measurement": "MiB",
@ -215,7 +212,6 @@ async def test_options_flow_name_previously_removed(
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {
"name": "Get Value Title",
"db_url": "sqlite://",
"query": "SELECT 5 as size",
"column": "size",
"unit_of_measurement": "MiB",
@ -316,7 +312,6 @@ async def test_options_flow_fails_invalid_query(
assert result4["type"] == FlowResultType.CREATE_ENTRY
assert result4["data"] == {
"name": "Get Value",
"db_url": "sqlite://",
"query": "SELECT 5 as size",
"column": "size",
"unit_of_measurement": "MiB",
@ -369,12 +364,9 @@ async def test_options_flow_db_url_empty(
)
await hass.async_block_till_done()
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {
"name": "Get Value",
"db_url": db_url,
"query": "SELECT 5 as size",
"column": "size",
"unit_of_measurement": "MiB",

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components.recorder import Recorder
from homeassistant.components.recorder.util import get_instance
from homeassistant.components.sql import validate_sql_select
from homeassistant.components.sql.const import DOMAIN
from homeassistant.core import HomeAssistant
@ -56,3 +57,41 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
"""Test invalid query."""
with pytest.raises(vol.Invalid):
validate_sql_select("DROP TABLE *")
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test configured db_url is replaced with None if matching the recorder db."""
recorder_db_url = get_instance(hass).db_url
config = {
"db_url": recorder_db_url,
"query": "SELECT 5 as value",
"column": "value",
"name": "count_tables",
}
config_entry = await init_integration(hass, config)
assert config_entry.options.get("db_url") is None
async def test_remove_configured_db_url_if_not_needed_when_needed(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test configured db_url is not replaced if it differs from the recorder db."""
db_url = "mssql://"
config = {
"db_url": db_url,
"query": "SELECT 5 as value",
"column": "value",
"name": "count_tables",
}
config_entry = await init_integration(hass, config)
assert config_entry.options.get("db_url") == db_url

View File

@ -182,6 +182,7 @@ async def test_invalid_url_setup(
async def test_invalid_url_on_update(
recorder_mock: Recorder,
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
@ -192,22 +193,9 @@ async def test_invalid_url_on_update(
"column": "value",
"name": "count_tables",
}
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
data={},
options=config,
entry_id="1",
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
await init_integration(hass, config)
with patch(
"homeassistant.components.recorder",
), patch(
"homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult",
side_effect=SQLAlchemyError(
"sqlite://homeassistant:hunter2@homeassistant.local"
@ -219,7 +207,6 @@ async def test_invalid_url_on_update(
)
await hass.async_block_till_done()
assert "sqlite://homeassistant:hunter2@homeassistant.local" not in caplog.text
assert "sqlite://****:****@homeassistant.local" in caplog.text

View File

@ -0,0 +1,25 @@
"""Test the sql utils."""
from unittest.mock import AsyncMock
from homeassistant.components.recorder import get_instance
from homeassistant.components.sql.util import resolve_db_url
from homeassistant.core import HomeAssistant
async def test_resolve_db_url_when_none_configured(
recorder_mock: AsyncMock,
hass: HomeAssistant,
):
"""Test return recorder db_url if provided db_url is None."""
db_url = None
resolved_url = resolve_db_url(hass, db_url)
assert resolved_url == get_instance(hass).db_url
async def test_resolve_db_url_when_configured(hass: HomeAssistant):
"""Test return provided db_url if it's set."""
db_url = "mssql://"
resolved_url = resolve_db_url(hass, db_url)
assert resolved_url == db_url