Default to recorder db for SQL integration (#85436)
Co-authored-by: J. Nick Koston <nick@koston.org>pull/89169/head
parent
2f4e9c8ef3
commit
afa58b80bd
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL
|
||||
from homeassistant.components.recorder import CONF_DB_URL, get_instance
|
||||
from homeassistant.components.sensor import (
|
||||
CONF_STATE_CLASS,
|
||||
DEVICE_CLASSES_SCHEMA,
|
||||
|
@ -53,6 +53,18 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
)
|
||||
|
||||
|
||||
def remove_configured_db_url_if_not_needed(
|
||||
hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> None:
|
||||
"""Remove db url from config if it matches recorder database."""
|
||||
hass.config_entries.async_update_entry(
|
||||
entry,
|
||||
options={
|
||||
key: value for key, value in entry.options.items() if key != CONF_DB_URL
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Update listener for options."""
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
@ -73,6 +85,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up SQL from a config entry."""
|
||||
if entry.options.get(CONF_DB_URL) == get_instance(hass).db_url:
|
||||
remove_configured_db_url_if_not_needed(hass, entry)
|
||||
|
||||
entry.async_on_unload(entry.add_update_listener(async_update_listener))
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
|
|
@ -11,13 +11,14 @@ from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
||||
from homeassistant.components.recorder import CONF_DB_URL
|
||||
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import resolve_db_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -85,34 +86,37 @@ class SQLConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
) -> FlowResult:
|
||||
"""Handle the user step."""
|
||||
errors = {}
|
||||
db_url_default = DEFAULT_URL.format(
|
||||
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
|
||||
)
|
||||
|
||||
if user_input is not None:
|
||||
db_url = user_input.get(CONF_DB_URL, db_url_default)
|
||||
db_url = user_input.get(CONF_DB_URL)
|
||||
query = user_input[CONF_QUERY]
|
||||
column = user_input[CONF_COLUMN_NAME]
|
||||
uom = user_input.get(CONF_UNIT_OF_MEASUREMENT)
|
||||
value_template = user_input.get(CONF_VALUE_TEMPLATE)
|
||||
name = user_input[CONF_NAME]
|
||||
db_url_for_validation = None
|
||||
|
||||
try:
|
||||
validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url, query, column
|
||||
validate_query, db_url_for_validation, query, column
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except ValueError:
|
||||
errors["query"] = "query_invalid"
|
||||
|
||||
add_db_url = (
|
||||
{CONF_DB_URL: db_url} if db_url == db_url_for_validation else {}
|
||||
)
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=name,
|
||||
data={},
|
||||
options={
|
||||
CONF_DB_URL: db_url,
|
||||
**add_db_url,
|
||||
CONF_QUERY: query,
|
||||
CONF_COLUMN_NAME: column,
|
||||
CONF_UNIT_OF_MEASUREMENT: uom,
|
||||
|
@ -140,32 +144,32 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
) -> FlowResult:
|
||||
"""Manage SQL options."""
|
||||
errors = {}
|
||||
db_url_default = DEFAULT_URL.format(
|
||||
hass_config_path=self.hass.config.path(DEFAULT_DB_FILE)
|
||||
)
|
||||
|
||||
if user_input is not None:
|
||||
db_url = user_input.get(CONF_DB_URL, db_url_default)
|
||||
db_url = user_input.get(CONF_DB_URL)
|
||||
query = user_input[CONF_QUERY]
|
||||
column = user_input[CONF_COLUMN_NAME]
|
||||
name = self.entry.options.get(CONF_NAME, self.entry.title)
|
||||
|
||||
try:
|
||||
validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url, query, column
|
||||
validate_query, db_url_for_validation, query, column
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except ValueError:
|
||||
errors["query"] = "query_invalid"
|
||||
else:
|
||||
new_user_input = user_input
|
||||
if new_user_input.get(CONF_DB_URL) and db_url == db_url_for_validation:
|
||||
new_user_input.pop(CONF_DB_URL)
|
||||
return self.async_create_entry(
|
||||
title="",
|
||||
data={
|
||||
CONF_NAME: name,
|
||||
CONF_DB_URL: db_url,
|
||||
**user_input,
|
||||
**new_user_input,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -176,7 +180,7 @@ class SQLOptionsFlowHandler(config_entries.OptionsFlow):
|
|||
vol.Optional(
|
||||
CONF_DB_URL,
|
||||
description={
|
||||
"suggested_value": self.entry.options[CONF_DB_URL]
|
||||
"suggested_value": self.entry.options.get(CONF_DB_URL)
|
||||
},
|
||||
): selector.TextSelector(),
|
||||
vol.Required(
|
||||
|
|
|
@ -10,7 +10,7 @@ from sqlalchemy.engine import Result
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL, DEFAULT_DB_FILE, DEFAULT_URL
|
||||
from homeassistant.components.recorder import CONF_DB_URL
|
||||
from homeassistant.components.sensor import (
|
||||
CONF_STATE_CLASS,
|
||||
SensorDeviceClass,
|
||||
|
@ -34,6 +34,7 @@ from homeassistant.helpers.template import Template
|
|||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DB_URL_RE, DOMAIN
|
||||
from .util import resolve_db_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -59,7 +60,7 @@ async def async_setup_platform(
|
|||
value_template: Template | None = conf.get(CONF_VALUE_TEMPLATE)
|
||||
column_name: str = conf[CONF_COLUMN_NAME]
|
||||
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
|
||||
db_url: str | None = conf.get(CONF_DB_URL)
|
||||
db_url: str = resolve_db_url(hass, conf.get(CONF_DB_URL))
|
||||
device_class: SensorDeviceClass | None = conf.get(CONF_DEVICE_CLASS)
|
||||
state_class: SensorStateClass | None = conf.get(CONF_STATE_CLASS)
|
||||
|
||||
|
@ -87,7 +88,7 @@ async def async_setup_entry(
|
|||
) -> None:
|
||||
"""Set up the SQL sensor from config entry."""
|
||||
|
||||
db_url: str = entry.options[CONF_DB_URL]
|
||||
db_url: str = resolve_db_url(hass, entry.options.get(CONF_DB_URL))
|
||||
name: str = entry.options[CONF_NAME]
|
||||
query_str: str = entry.options[CONF_QUERY]
|
||||
unit: str | None = entry.options.get(CONF_UNIT_OF_MEASUREMENT)
|
||||
|
@ -128,7 +129,7 @@ async def async_setup_sensor(
|
|||
unit: str | None,
|
||||
value_template: Template | None,
|
||||
unique_id: str | None,
|
||||
db_url: str | None,
|
||||
db_url: str,
|
||||
yaml: bool,
|
||||
device_class: SensorDeviceClass | None,
|
||||
state_class: SensorStateClass | None,
|
||||
|
@ -136,16 +137,12 @@ async def async_setup_sensor(
|
|||
) -> None:
|
||||
"""Set up the SQL sensor."""
|
||||
|
||||
if not db_url:
|
||||
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
|
||||
|
||||
sess: Session | None = None
|
||||
try:
|
||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||
|
||||
# Run a dummy query just to test the db_url
|
||||
sess = sessmaker()
|
||||
sess: Session = sessmaker()
|
||||
sess.execute(sqlalchemy.text("SELECT 1;"))
|
||||
|
||||
except SQLAlchemyError as err:
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
"""Utils for sql."""
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.recorder import get_instance
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str:
|
||||
"""Return the db_url provided if not empty, otherwise return the recorder db_url."""
|
||||
if db_url and not db_url.isspace():
|
||||
return db_url
|
||||
return get_instance(hass).db_url
|
|
@ -23,7 +23,6 @@ from homeassistant.core import HomeAssistant
|
|||
from tests.common import MockConfigEntry
|
||||
|
||||
ENTRY_CONFIG = {
|
||||
CONF_DB_URL: "sqlite://",
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
|
@ -31,7 +30,6 @@ ENTRY_CONFIG = {
|
|||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY = {
|
||||
CONF_DB_URL: "sqlite://",
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "UPDATE 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
|
@ -39,14 +37,12 @@ ENTRY_CONFIG_INVALID_QUERY = {
|
|||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
||||
CONF_DB_URL: "sqlite://",
|
||||
CONF_QUERY: "UPDATE 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_NO_RESULTS = {
|
||||
CONF_DB_URL: "sqlite://",
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT kalle as value from no_table;",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
|
@ -69,7 +65,6 @@ YAML_CONFIG = {
|
|||
|
||||
YAML_CONFIG_INVALID = {
|
||||
"sql": {
|
||||
CONF_DB_URL: "sqlite://",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
|
|
|
@ -6,7 +6,7 @@ from unittest.mock import patch
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.recorder import DEFAULT_DB_FILE, DEFAULT_URL, Recorder
|
||||
from homeassistant.components.recorder import Recorder
|
||||
from homeassistant.components.sql.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
@ -43,7 +43,6 @@ async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None:
|
|||
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Get Value"
|
||||
assert result2["options"] == {
|
||||
"db_url": "sqlite://",
|
||||
"name": "Get Value",
|
||||
"query": "SELECT 5 as value",
|
||||
"column": "value",
|
||||
|
@ -113,7 +112,6 @@ async def test_flow_fails_invalid_query(
|
|||
assert result5["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result5["title"] == "Get Value"
|
||||
assert result5["options"] == {
|
||||
"db_url": "sqlite://",
|
||||
"name": "Get Value",
|
||||
"query": "SELECT 5 as value",
|
||||
"column": "value",
|
||||
|
@ -163,7 +161,6 @@ async def test_options_flow(recorder_mock: Recorder, hass: HomeAssistant) -> Non
|
|||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {
|
||||
"name": "Get Value",
|
||||
"db_url": "sqlite://",
|
||||
"query": "SELECT 5 as size",
|
||||
"column": "size",
|
||||
"unit_of_measurement": "MiB",
|
||||
|
@ -215,7 +212,6 @@ async def test_options_flow_name_previously_removed(
|
|||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {
|
||||
"name": "Get Value Title",
|
||||
"db_url": "sqlite://",
|
||||
"query": "SELECT 5 as size",
|
||||
"column": "size",
|
||||
"unit_of_measurement": "MiB",
|
||||
|
@ -316,7 +312,6 @@ async def test_options_flow_fails_invalid_query(
|
|||
assert result4["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result4["data"] == {
|
||||
"name": "Get Value",
|
||||
"db_url": "sqlite://",
|
||||
"query": "SELECT 5 as size",
|
||||
"column": "size",
|
||||
"unit_of_measurement": "MiB",
|
||||
|
@ -369,12 +364,9 @@ async def test_options_flow_db_url_empty(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
db_url = DEFAULT_URL.format(hass_config_path=hass.config.path(DEFAULT_DB_FILE))
|
||||
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {
|
||||
"name": "Get Value",
|
||||
"db_url": db_url,
|
||||
"query": "SELECT 5 as size",
|
||||
"column": "size",
|
||||
"unit_of_measurement": "MiB",
|
||||
|
|
|
@ -8,6 +8,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.recorder import Recorder
|
||||
from homeassistant.components.recorder.util import get_instance
|
||||
from homeassistant.components.sql import validate_sql_select
|
||||
from homeassistant.components.sql.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -56,3 +57,41 @@ async def test_invalid_query(hass: HomeAssistant) -> None:
|
|||
"""Test invalid query."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
validate_sql_select("DROP TABLE *")
|
||||
|
||||
|
||||
async def test_remove_configured_db_url_if_not_needed_when_not_needed(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test configured db_url is replaced with None if matching the recorder db."""
|
||||
recorder_db_url = get_instance(hass).db_url
|
||||
|
||||
config = {
|
||||
"db_url": recorder_db_url,
|
||||
"query": "SELECT 5 as value",
|
||||
"column": "value",
|
||||
"name": "count_tables",
|
||||
}
|
||||
|
||||
config_entry = await init_integration(hass, config)
|
||||
|
||||
assert config_entry.options.get("db_url") is None
|
||||
|
||||
|
||||
async def test_remove_configured_db_url_if_not_needed_when_needed(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test configured db_url is not replaced if it differs from the recorder db."""
|
||||
db_url = "mssql://"
|
||||
|
||||
config = {
|
||||
"db_url": db_url,
|
||||
"query": "SELECT 5 as value",
|
||||
"column": "value",
|
||||
"name": "count_tables",
|
||||
}
|
||||
|
||||
config_entry = await init_integration(hass, config)
|
||||
|
||||
assert config_entry.options.get("db_url") == db_url
|
||||
|
|
|
@ -182,6 +182,7 @@ async def test_invalid_url_setup(
|
|||
|
||||
|
||||
async def test_invalid_url_on_update(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
|
@ -192,22 +193,9 @@ async def test_invalid_url_on_update(
|
|||
"column": "value",
|
||||
"name": "count_tables",
|
||||
}
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
source=SOURCE_USER,
|
||||
data={},
|
||||
options=config,
|
||||
entry_id="1",
|
||||
)
|
||||
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
await init_integration(hass, config)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.recorder",
|
||||
), patch(
|
||||
"homeassistant.components.sql.sensor.sqlalchemy.engine.cursor.CursorResult",
|
||||
side_effect=SQLAlchemyError(
|
||||
"sqlite://homeassistant:hunter2@homeassistant.local"
|
||||
|
@ -219,7 +207,6 @@ async def test_invalid_url_on_update(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert "sqlite://homeassistant:hunter2@homeassistant.local" not in caplog.text
|
||||
assert "sqlite://****:****@homeassistant.local" in caplog.text
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
"""Test the sql utils."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from homeassistant.components.recorder import get_instance
|
||||
from homeassistant.components.sql.util import resolve_db_url
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_resolve_db_url_when_none_configured(
|
||||
recorder_mock: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
):
|
||||
"""Test return recorder db_url if provided db_url is None."""
|
||||
db_url = None
|
||||
resolved_url = resolve_db_url(hass, db_url)
|
||||
|
||||
assert resolved_url == get_instance(hass).db_url
|
||||
|
||||
|
||||
async def test_resolve_db_url_when_configured(hass: HomeAssistant):
|
||||
"""Test return provided db_url if it's set."""
|
||||
db_url = "mssql://"
|
||||
resolved_url = resolve_db_url(hass, db_url)
|
||||
|
||||
assert resolved_url == db_url
|
Loading…
Reference in New Issue