Add config flow for cloudflare (#41167)
* add config flow for cloudflare * Create const.py * work on flow. * remove const. * lint. * Apply suggestions from code review Co-authored-by: J. Nick Koston <nick@koston.org> * Update config_flows.py * Update homeassistant/components/cloudflare/strings.json * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * Update strings.json * Apply suggestions from code review * Update __init__.py * Update __init__.py Co-authored-by: J. Nick Koston <nick@koston.org>pull/42167/head
parent
e203896638
commit
d8577a1550
|
@ -131,7 +131,6 @@ omit =
|
|||
homeassistant/components/clickatell/notify.py
|
||||
homeassistant/components/clicksend/notify.py
|
||||
homeassistant/components/clicksend_tts/notify.py
|
||||
homeassistant/components/cloudflare/*
|
||||
homeassistant/components/cmus/media_player.py
|
||||
homeassistant/components/co2signal/*
|
||||
homeassistant/components/coinbase/*
|
||||
|
|
|
@ -75,7 +75,7 @@ homeassistant/components/cisco_ios/* @fbradyirl
|
|||
homeassistant/components/cisco_mobility_express/* @fbradyirl
|
||||
homeassistant/components/cisco_webex_teams/* @fbradyirl
|
||||
homeassistant/components/cloud/* @home-assistant/cloud
|
||||
homeassistant/components/cloudflare/* @ludeeus
|
||||
homeassistant/components/cloudflare/* @ludeeus @ctalkington
|
||||
homeassistant/components/comfoconnect/* @michaelarnauts
|
||||
homeassistant/components/config/* @home-assistant/core
|
||||
homeassistant/components/configurator/* @home-assistant/core
|
||||
|
|
|
@ -1,74 +1,130 @@
|
|||
"""Update the IP addresses of your Cloudflare DNS records."""
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from pycfdns import CloudflareUpdater
|
||||
from pycfdns.exceptions import (
|
||||
CloudflareAuthenticationException,
|
||||
CloudflareConnectionException,
|
||||
CloudflareException,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_API_KEY, CONF_EMAIL, CONF_ZONE
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_API_KEY, CONF_API_TOKEN, CONF_EMAIL, CONF_ZONE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import track_time_interval
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
|
||||
from .const import (
|
||||
CONF_RECORDS,
|
||||
DATA_UNDO_UPDATE_INTERVAL,
|
||||
DEFAULT_UPDATE_INTERVAL,
|
||||
DOMAIN,
|
||||
SERVICE_UPDATE_RECORDS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_RECORDS = "records"
|
||||
|
||||
DOMAIN = "cloudflare"
|
||||
|
||||
INTERVAL = timedelta(minutes=60)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_EMAIL): cv.string,
|
||||
vol.Required(CONF_API_KEY): cv.string,
|
||||
vol.Required(CONF_ZONE): cv.string,
|
||||
vol.Required(CONF_RECORDS): vol.All(cv.ensure_list, [cv.string]),
|
||||
}
|
||||
DOMAIN: vol.All(
|
||||
cv.deprecated(CONF_EMAIL, invalidation_version="0.119"),
|
||||
cv.deprecated(CONF_API_KEY, invalidation_version="0.119"),
|
||||
cv.deprecated(CONF_ZONE, invalidation_version="0.119"),
|
||||
cv.deprecated(CONF_RECORDS, invalidation_version="0.119"),
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_EMAIL): cv.string,
|
||||
vol.Optional(CONF_API_KEY): cv.string,
|
||||
vol.Optional(CONF_ZONE): cv.string,
|
||||
vol.Optional(CONF_RECORDS): vol.All(cv.ensure_list, [cv.string]),
|
||||
}
|
||||
),
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
def setup(hass, config):
|
||||
"""Set up the Cloudflare component."""
|
||||
async def async_setup(hass: HomeAssistant, config: Dict) -> bool:
|
||||
"""Set up the component."""
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
|
||||
cfupdate = CloudflareUpdater()
|
||||
email = config[DOMAIN][CONF_EMAIL]
|
||||
key = config[DOMAIN][CONF_API_KEY]
|
||||
zone = config[DOMAIN][CONF_ZONE]
|
||||
records = config[DOMAIN][CONF_RECORDS]
|
||||
if len(hass.config_entries.async_entries(DOMAIN)) > 0:
|
||||
return True
|
||||
|
||||
def update_records_interval(now):
|
||||
"""Set up recurring update."""
|
||||
_update_cloudflare(cfupdate, email, key, zone, records)
|
||||
if DOMAIN in config and CONF_API_KEY in config[DOMAIN]:
|
||||
persistent_notification.async_create(
|
||||
hass,
|
||||
"Cloudflare integration now requires an API Token. Please go to the integrations page to setup.",
|
||||
"Cloudflare Setup",
|
||||
"cloudflare_setup",
|
||||
)
|
||||
|
||||
def update_records_service(now):
|
||||
"""Set up service for manual trigger."""
|
||||
_update_cloudflare(cfupdate, email, key, zone, records)
|
||||
|
||||
track_time_interval(hass, update_records_interval, INTERVAL)
|
||||
hass.services.register(DOMAIN, "update_records", update_records_service)
|
||||
return True
|
||||
|
||||
|
||||
def _update_cloudflare(cfupdate, email, key, zone, records):
|
||||
"""Update DNS records for a given zone."""
|
||||
_LOGGER.debug("Starting update for zone %s", zone)
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up Cloudflare from a config entry."""
|
||||
cfupdate = CloudflareUpdater(
|
||||
async_get_clientsession(hass),
|
||||
entry.data[CONF_API_TOKEN],
|
||||
entry.data[CONF_ZONE],
|
||||
entry.data[CONF_RECORDS],
|
||||
)
|
||||
|
||||
headers = cfupdate.set_header(email, key)
|
||||
_LOGGER.debug("Header data defined as: %s", headers)
|
||||
try:
|
||||
zone_id = await cfupdate.get_zone_id()
|
||||
except CloudflareAuthenticationException:
|
||||
_LOGGER.error("API access forbidden. Please reauthenticate")
|
||||
return False
|
||||
except CloudflareConnectionException as error:
|
||||
raise ConfigEntryNotReady from error
|
||||
|
||||
zoneid = cfupdate.get_zoneID(headers, zone)
|
||||
_LOGGER.debug("Zone ID is set to: %s", zoneid)
|
||||
async def update_records(now):
|
||||
"""Set up recurring update."""
|
||||
try:
|
||||
await _async_update_cloudflare(cfupdate, zone_id)
|
||||
except CloudflareException as error:
|
||||
_LOGGER.error("Error updating zone %s: %s", entry.data[CONF_ZONE], error)
|
||||
|
||||
update_records = cfupdate.get_recordInfo(headers, zoneid, zone, records)
|
||||
_LOGGER.debug("Records: %s", update_records)
|
||||
async def update_records_service(call):
|
||||
"""Set up service for manual trigger."""
|
||||
try:
|
||||
await _async_update_cloudflare(cfupdate, zone_id)
|
||||
except CloudflareException as error:
|
||||
_LOGGER.error("Error updating zone %s: %s", entry.data[CONF_ZONE], error)
|
||||
|
||||
result = cfupdate.update_records(headers, zoneid, update_records)
|
||||
_LOGGER.debug("Update for zone %s is complete", zone)
|
||||
update_interval = timedelta(minutes=DEFAULT_UPDATE_INTERVAL)
|
||||
undo_interval = async_track_time_interval(hass, update_records, update_interval)
|
||||
|
||||
if result is not True:
|
||||
_LOGGER.warning(result)
|
||||
hass.data[DOMAIN][entry.entry_id] = {
|
||||
DATA_UNDO_UPDATE_INTERVAL: undo_interval,
|
||||
}
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_UPDATE_RECORDS, update_records_service)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Cloudflare config entry."""
|
||||
hass.data[DOMAIN][entry.entry_id][DATA_UNDO_UPDATE_INTERVAL]()
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _async_update_cloudflare(cfupdate: CloudflareUpdater, zone_id: str):
|
||||
_LOGGER.debug("Starting update for zone %s", cfupdate.zone)
|
||||
|
||||
records = await cfupdate.get_record_info(zone_id)
|
||||
_LOGGER.debug("Records: %s", records)
|
||||
|
||||
await cfupdate.update_records(zone_id, records)
|
||||
_LOGGER.debug("Update for zone %s is complete", cfupdate.zone)
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
"""Config flow for Cloudflare integration."""
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pycfdns import CloudflareUpdater
|
||||
from pycfdns.exceptions import (
|
||||
CloudflareAuthenticationException,
|
||||
CloudflareConnectionException,
|
||||
CloudflareZoneException,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config_entries import CONN_CLASS_CLOUD_PUSH, ConfigFlow
|
||||
from homeassistant.const import CONF_API_TOKEN, CONF_ZONE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
from .const import CONF_RECORDS
|
||||
from .const import DOMAIN # pylint:disable=unused-import
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_API_TOKEN): str,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _zone_schema(zones: Optional[List] = None):
|
||||
"""Zone selection schema."""
|
||||
zones_list = []
|
||||
|
||||
if zones is not None:
|
||||
zones_list = zones
|
||||
|
||||
return vol.Schema({vol.Required(CONF_ZONE): vol.In(zones_list)})
|
||||
|
||||
|
||||
def _records_schema(records: Optional[List] = None):
|
||||
"""Zone records selection schema."""
|
||||
records_dict = {}
|
||||
|
||||
if records:
|
||||
records_dict = {name: name for name in records}
|
||||
|
||||
return vol.Schema({vol.Required(CONF_RECORDS): cv.multi_select(records_dict)})
|
||||
|
||||
|
||||
async def validate_input(hass: HomeAssistant, data: Dict):
|
||||
"""Validate the user input allows us to connect.
|
||||
|
||||
Data has the keys from DATA_SCHEMA with values provided by the user.
|
||||
"""
|
||||
zone = data.get(CONF_ZONE)
|
||||
records = None
|
||||
|
||||
cfupdate = CloudflareUpdater(
|
||||
async_get_clientsession(hass),
|
||||
data[CONF_API_TOKEN],
|
||||
zone,
|
||||
[],
|
||||
)
|
||||
|
||||
try:
|
||||
zones = await cfupdate.get_zones()
|
||||
if zone:
|
||||
zone_id = await cfupdate.get_zone_id()
|
||||
records = await cfupdate.get_zone_records(zone_id, "A")
|
||||
except CloudflareConnectionException as error:
|
||||
raise CannotConnect from error
|
||||
except CloudflareAuthenticationException as error:
|
||||
raise InvalidAuth from error
|
||||
except CloudflareZoneException as error:
|
||||
raise InvalidZone from error
|
||||
|
||||
return {"zones": zones, "records": records}
|
||||
|
||||
|
||||
class CloudflareConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Cloudflare."""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = CONN_CLASS_CLOUD_PUSH
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Cloudflare config flow."""
|
||||
self.cloudflare_config = {}
|
||||
self.zones = None
|
||||
self.records = None
|
||||
|
||||
async def async_step_user(self, user_input: Optional[Dict] = None):
|
||||
"""Handle a flow initiated by the user."""
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
assert self.hass
|
||||
persistent_notification.async_dismiss(self.hass, "cloudflare_setup")
|
||||
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
info, errors = await self._async_validate_or_error(user_input)
|
||||
|
||||
if not errors:
|
||||
self.cloudflare_config.update(user_input)
|
||||
self.zones = info["zones"]
|
||||
return await self.async_step_zone()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
async def async_step_zone(self, user_input: Optional[Dict] = None):
|
||||
"""Handle the picking the zone."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self.cloudflare_config.update(user_input)
|
||||
info, errors = await self._async_validate_or_error(self.cloudflare_config)
|
||||
|
||||
if not errors:
|
||||
await self.async_set_unique_id(user_input[CONF_ZONE])
|
||||
self.records = info["records"]
|
||||
|
||||
return await self.async_step_records()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="zone",
|
||||
data_schema=_zone_schema(self.zones),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_records(self, user_input: Optional[Dict] = None):
|
||||
"""Handle the picking the zone records."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self.cloudflare_config.update(user_input)
|
||||
title = self.cloudflare_config[CONF_ZONE]
|
||||
return self.async_create_entry(title=title, data=self.cloudflare_config)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="records",
|
||||
data_schema=_records_schema(self.records),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def _async_validate_or_error(self, config):
|
||||
errors = {}
|
||||
info = {}
|
||||
|
||||
try:
|
||||
info = await validate_input(self.hass, config)
|
||||
except CannotConnect:
|
||||
errors["base"] = "cannot_connect"
|
||||
except InvalidAuth:
|
||||
errors["base"] = "invalid_auth"
|
||||
except InvalidZone:
|
||||
errors["base"] = "invalid_zone"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return info, errors
|
||||
|
||||
|
||||
class CannotConnect(HomeAssistantError):
|
||||
"""Error to indicate we cannot connect."""
|
||||
|
||||
|
||||
class InvalidAuth(HomeAssistantError):
|
||||
"""Error to indicate there is invalid auth."""
|
||||
|
||||
|
||||
class InvalidZone(HomeAssistantError):
|
||||
"""Error to indicate we cannot validate zone exists in account."""
|
|
@ -0,0 +1,15 @@
|
|||
"""Constants for Cloudflare."""
|
||||
|
||||
DOMAIN = "cloudflare"
|
||||
|
||||
# Config
|
||||
CONF_RECORDS = "records"
|
||||
|
||||
# Data
|
||||
DATA_UNDO_UPDATE_INTERVAL = "undo_update_interval"
|
||||
|
||||
# Defaults
|
||||
DEFAULT_UPDATE_INTERVAL = 60 # in minutes
|
||||
|
||||
# Services
|
||||
SERVICE_UPDATE_RECORDS = "update_records"
|
|
@ -2,6 +2,7 @@
|
|||
"domain": "cloudflare",
|
||||
"name": "Cloudflare",
|
||||
"documentation": "https://www.home-assistant.io/integrations/cloudflare",
|
||||
"requirements": ["pycfdns==0.0.1"],
|
||||
"codeowners": ["@ludeeus"]
|
||||
"requirements": ["pycfdns==1.1.1"],
|
||||
"codeowners": ["@ludeeus", "@ctalkington"],
|
||||
"config_flow": true
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
{
|
||||
"config": {
|
||||
"flow_title": "Cloudflare: {name}",
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "Connect to Cloudflare",
|
||||
"description": "This integration requires an API Token created with Zone:Zone:Read and Zone:DNS:Edit permissions for all zones in your account.",
|
||||
"data": {
|
||||
"api_token": "[%key:common::config_flow::data::api_token%]"
|
||||
}
|
||||
},
|
||||
"zone": {
|
||||
"title": "Choose the Zone to Update",
|
||||
"data": {
|
||||
"zone": "Zone"
|
||||
}
|
||||
},
|
||||
"records": {
|
||||
"title": "Choose the Records to Update",
|
||||
"data": {
|
||||
"records": "Records"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
|
||||
"invalid_zone": "Invalid zone"
|
||||
},
|
||||
"abort": {
|
||||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -35,6 +35,7 @@ FLOWS = [
|
|||
"canary",
|
||||
"cast",
|
||||
"cert_expiry",
|
||||
"cloudflare",
|
||||
"control4",
|
||||
"coolmaster",
|
||||
"coronavirus",
|
||||
|
|
|
@ -1288,7 +1288,7 @@ pybotvac==0.0.17
|
|||
pycarwings2==2.9
|
||||
|
||||
# homeassistant.components.cloudflare
|
||||
pycfdns==0.0.1
|
||||
pycfdns==1.1.1
|
||||
|
||||
# homeassistant.components.channels
|
||||
pychannels==1.0.0
|
||||
|
|
|
@ -635,6 +635,9 @@ pyblackbird==0.5
|
|||
# homeassistant.components.neato
|
||||
pybotvac==0.0.17
|
||||
|
||||
# homeassistant.components.cloudflare
|
||||
pycfdns==1.1.1
|
||||
|
||||
# homeassistant.components.cast
|
||||
pychromecast==7.5.1
|
||||
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
"""Tests for the Cloudflare integration."""
|
||||
from typing import List
|
||||
|
||||
from pycfdns import CFRecord
|
||||
|
||||
from homeassistant.components.cloudflare.const import CONF_RECORDS, DOMAIN
|
||||
from homeassistant.const import CONF_API_TOKEN, CONF_ZONE
|
||||
|
||||
from tests.async_mock import AsyncMock, patch
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
ENTRY_CONFIG = {
|
||||
CONF_API_TOKEN: "mock-api-token",
|
||||
CONF_ZONE: "mock.com",
|
||||
CONF_RECORDS: ["ha.mock.com", "homeassistant.mock.com"],
|
||||
}
|
||||
|
||||
ENTRY_OPTIONS = {}
|
||||
|
||||
USER_INPUT = {
|
||||
CONF_API_TOKEN: "mock-api-token",
|
||||
}
|
||||
|
||||
USER_INPUT_ZONE = {CONF_ZONE: "mock.com"}
|
||||
|
||||
USER_INPUT_RECORDS = {CONF_RECORDS: ["ha.mock.com", "homeassistant.mock.com"]}
|
||||
|
||||
MOCK_ZONE = "mock.com"
|
||||
MOCK_ZONE_ID = "mock-zone-id"
|
||||
MOCK_ZONE_RECORDS = [
|
||||
{
|
||||
"id": "zone-record-id",
|
||||
"type": "A",
|
||||
"name": "ha.mock.com",
|
||||
"proxied": True,
|
||||
"content": "127.0.0.1",
|
||||
},
|
||||
{
|
||||
"id": "zone-record-id-2",
|
||||
"type": "A",
|
||||
"name": "homeassistant.mock.com",
|
||||
"proxied": True,
|
||||
"content": "127.0.0.1",
|
||||
},
|
||||
{
|
||||
"id": "zone-record-id-3",
|
||||
"type": "A",
|
||||
"name": "mock.com",
|
||||
"proxied": True,
|
||||
"content": "127.0.0.1",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def init_integration(
|
||||
hass,
|
||||
*,
|
||||
data: dict = ENTRY_CONFIG,
|
||||
options: dict = ENTRY_OPTIONS,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up the Cloudflare integration in Home Assistant."""
|
||||
entry = MockConfigEntry(domain=DOMAIN, data=data, options=options)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def _get_mock_cfupdate(
|
||||
zone: str = MOCK_ZONE,
|
||||
zone_id: str = MOCK_ZONE_ID,
|
||||
records: List = MOCK_ZONE_RECORDS,
|
||||
):
|
||||
client = AsyncMock()
|
||||
|
||||
zone_records = [record["name"] for record in records]
|
||||
cf_records = [CFRecord(record) for record in records]
|
||||
|
||||
client.get_zones = AsyncMock(return_value=[zone])
|
||||
client.get_zone_records = AsyncMock(return_value=zone_records)
|
||||
client.get_record_info = AsyncMock(return_value=cf_records)
|
||||
client.get_zone_id = AsyncMock(return_value=zone_id)
|
||||
client.update_records = AsyncMock(return_value=None)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def _patch_async_setup(return_value=True):
|
||||
return patch(
|
||||
"homeassistant.components.cloudflare.async_setup",
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
|
||||
def _patch_async_setup_entry(return_value=True):
|
||||
return patch(
|
||||
"homeassistant.components.cloudflare.async_setup_entry",
|
||||
return_value=return_value,
|
||||
)
|
|
@ -0,0 +1,28 @@
|
|||
"""Define fixtures available for all tests."""
|
||||
from pytest import fixture
|
||||
|
||||
from . import _get_mock_cfupdate
|
||||
|
||||
from tests.async_mock import patch
|
||||
|
||||
|
||||
@fixture
|
||||
def cfupdate(hass):
|
||||
"""Mock the CloudflareUpdater for easier testing."""
|
||||
mock_cfupdate = _get_mock_cfupdate()
|
||||
with patch(
|
||||
"homeassistant.components.cloudflare.CloudflareUpdater",
|
||||
return_value=mock_cfupdate,
|
||||
) as mock_api:
|
||||
yield mock_api
|
||||
|
||||
|
||||
@fixture
|
||||
def cfupdate_flow(hass):
|
||||
"""Mock the CloudflareUpdater for easier config flow testing."""
|
||||
mock_cfupdate = _get_mock_cfupdate()
|
||||
with patch(
|
||||
"homeassistant.components.cloudflare.config_flow.CloudflareUpdater",
|
||||
return_value=mock_cfupdate,
|
||||
) as mock_api:
|
||||
yield mock_api
|
|
@ -0,0 +1,166 @@
|
|||
"""Test the Cloudflare config flow."""
|
||||
from pycfdns.exceptions import (
|
||||
CloudflareAuthenticationException,
|
||||
CloudflareConnectionException,
|
||||
CloudflareZoneException,
|
||||
)
|
||||
|
||||
from homeassistant.components.cloudflare.const import CONF_RECORDS, DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_TOKEN, CONF_SOURCE, CONF_ZONE
|
||||
from homeassistant.data_entry_flow import (
|
||||
RESULT_TYPE_ABORT,
|
||||
RESULT_TYPE_CREATE_ENTRY,
|
||||
RESULT_TYPE_FORM,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import (
|
||||
ENTRY_CONFIG,
|
||||
USER_INPUT,
|
||||
USER_INPUT_RECORDS,
|
||||
USER_INPUT_ZONE,
|
||||
_patch_async_setup,
|
||||
_patch_async_setup_entry,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_user_form(hass, cfupdate_flow):
|
||||
"""Test we get the user initiated form."""
|
||||
await async_setup_component(hass, "persistent_notification", {})
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={CONF_SOURCE: SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "user"
|
||||
assert result["errors"] == {}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "zone"
|
||||
assert result["errors"] == {}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT_ZONE,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "records"
|
||||
assert result["errors"] == {}
|
||||
|
||||
with _patch_async_setup() as mock_setup, _patch_async_setup_entry() as mock_setup_entry:
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT_RECORDS,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["title"] == USER_INPUT_ZONE[CONF_ZONE]
|
||||
|
||||
assert result["data"]
|
||||
assert result["data"][CONF_API_TOKEN] == USER_INPUT[CONF_API_TOKEN]
|
||||
assert result["data"][CONF_ZONE] == USER_INPUT_ZONE[CONF_ZONE]
|
||||
assert result["data"][CONF_RECORDS] == USER_INPUT_RECORDS[CONF_RECORDS]
|
||||
|
||||
assert result["result"]
|
||||
assert result["result"].unique_id == USER_INPUT_ZONE[CONF_ZONE]
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_user_form_cannot_connect(hass, cfupdate_flow):
|
||||
"""Test we handle cannot connect error."""
|
||||
instance = cfupdate_flow.return_value
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={CONF_SOURCE: SOURCE_USER}
|
||||
)
|
||||
|
||||
instance.get_zones.side_effect = CloudflareConnectionException()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT,
|
||||
)
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"base": "cannot_connect"}
|
||||
|
||||
|
||||
async def test_user_form_invalid_auth(hass, cfupdate_flow):
|
||||
"""Test we handle invalid auth error."""
|
||||
instance = cfupdate_flow.return_value
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={CONF_SOURCE: SOURCE_USER}
|
||||
)
|
||||
|
||||
instance.get_zones.side_effect = CloudflareAuthenticationException()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT,
|
||||
)
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
|
||||
async def test_user_form_invalid_zone(hass, cfupdate_flow):
|
||||
"""Test we handle invalid zone error."""
|
||||
instance = cfupdate_flow.return_value
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={CONF_SOURCE: SOURCE_USER}
|
||||
)
|
||||
|
||||
instance.get_zones.side_effect = CloudflareZoneException()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT,
|
||||
)
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"base": "invalid_zone"}
|
||||
|
||||
|
||||
async def test_user_form_unexpected_exception(hass, cfupdate_flow):
|
||||
"""Test we handle unexpected exception."""
|
||||
instance = cfupdate_flow.return_value
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={CONF_SOURCE: SOURCE_USER}
|
||||
)
|
||||
|
||||
instance.get_zones.side_effect = Exception()
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
USER_INPUT,
|
||||
)
|
||||
|
||||
assert result["type"] == RESULT_TYPE_FORM
|
||||
assert result["errors"] == {"base": "unknown"}
|
||||
|
||||
|
||||
async def test_user_form_single_instance_allowed(hass):
|
||||
"""Test that configuring more than one instance is rejected."""
|
||||
entry = MockConfigEntry(domain=DOMAIN, data=ENTRY_CONFIG)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={CONF_SOURCE: SOURCE_USER},
|
||||
data=USER_INPUT,
|
||||
)
|
||||
assert result["type"] == RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "single_instance_allowed"
|
|
@ -0,0 +1,58 @@
|
|||
"""Test the Cloudflare integration."""
|
||||
from pycfdns.exceptions import CloudflareConnectionException
|
||||
|
||||
from homeassistant.components.cloudflare.const import DOMAIN, SERVICE_UPDATE_RECORDS
|
||||
from homeassistant.config_entries import (
|
||||
ENTRY_STATE_LOADED,
|
||||
ENTRY_STATE_NOT_LOADED,
|
||||
ENTRY_STATE_SETUP_RETRY,
|
||||
)
|
||||
|
||||
from . import ENTRY_CONFIG, init_integration
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_unload_entry(hass, cfupdate):
|
||||
"""Test successful unload of entry."""
|
||||
entry = await init_integration(hass)
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert entry.state == ENTRY_STATE_LOADED
|
||||
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state == ENTRY_STATE_NOT_LOADED
|
||||
assert not hass.data.get(DOMAIN)
|
||||
|
||||
|
||||
async def test_async_setup_raises_entry_not_ready(hass, cfupdate):
|
||||
"""Test that it throws ConfigEntryNotReady when exception occurs during setup."""
|
||||
instance = cfupdate.return_value
|
||||
|
||||
entry = MockConfigEntry(domain=DOMAIN, data=ENTRY_CONFIG)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
instance.get_zone_id.side_effect = CloudflareConnectionException()
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
|
||||
assert entry.state == ENTRY_STATE_SETUP_RETRY
|
||||
|
||||
|
||||
async def test_integration_services(hass, cfupdate):
|
||||
"""Test integration services."""
|
||||
instance = cfupdate.return_value
|
||||
|
||||
entry = await init_integration(hass)
|
||||
assert entry.state == ENTRY_STATE_LOADED
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_UPDATE_RECORDS,
|
||||
{},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
instance.update_records.assert_called_once()
|
Loading…
Reference in New Issue