Add is_host_valid util (#76589)

pull/78232/head
Artem Draft 2022-09-11 19:12:04 +03:00 committed by GitHub
parent b0777e6280
commit 29be6d17b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 56 deletions

View File

@ -1,9 +1,6 @@
"""Config flow to configure the Bravia TV integration."""
from __future__ import annotations
from contextlib import suppress
import ipaddress
import re
from typing import Any
from aiohttp import CookieJar
@ -17,6 +14,7 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_create_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.util.network import is_host_valid
from . import BraviaTVCoordinator
from .const import (
@ -30,15 +28,6 @@ from .const import (
)
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
with suppress(ValueError):
if ipaddress.ip_address(host).version in [4, 6]:
return True
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
class BraviaTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Bravia TV integration."""
@ -82,7 +71,7 @@ class BraviaTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None:
host = user_input[CONF_HOST]
if host_valid(host):
if is_host_valid(host):
session = async_create_clientsession(
self.hass,
cookie_jar=CookieJar(unsafe=True, quote_cookie=False),

View File

@ -1,8 +1,6 @@
"""Adds config flow for Brother Printer."""
from __future__ import annotations
import ipaddress
import re
from typing import Any
from brother import Brother, SnmpError, UnsupportedModel
@ -12,6 +10,7 @@ from homeassistant import config_entries, exceptions
from homeassistant.components import zeroconf
from homeassistant.const import CONF_HOST, CONF_TYPE
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util.network import is_host_valid
from .const import DOMAIN, PRINTER_TYPES
from .utils import get_snmp_engine
@ -24,17 +23,6 @@ DATA_SCHEMA = vol.Schema(
)
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in [4, 6]:
return True
except ValueError:
pass
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
class BrotherConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Brother Printer."""
@ -53,7 +41,7 @@ class BrotherConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None:
try:
if not host_valid(user_input[CONF_HOST]):
if not is_host_valid(user_input[CONF_HOST]):
raise InvalidHost()
snmp_engine = get_snmp_engine(self.hass)

View File

@ -1,8 +1,6 @@
"""Adds config flow for Dune HD integration."""
from __future__ import annotations
import ipaddress
import re
from typing import Any
from pdunehd import DuneHDPlayer
@ -11,23 +9,11 @@ import voluptuous as vol
from homeassistant import config_entries, exceptions
from homeassistant.const import CONF_HOST
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util.network import is_host_valid
from .const import DOMAIN
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in (4, 6):
return True
except ValueError:
pass
if len(host) > 253:
return False
allowed = re.compile(r"(?!-)[A-Z\d\-\_]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(x) for x in host.split("."))
class DuneHDConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Dune HD integration."""
@ -47,7 +33,7 @@ class DuneHDConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors = {}
if user_input is not None:
if host_valid(user_input[CONF_HOST]):
if is_host_valid(user_input[CONF_HOST]):
host: str = user_input[CONF_HOST]
try:

View File

@ -1,7 +1,5 @@
"""Config flow for Vilfo Router integration."""
import ipaddress
import logging
import re
from vilfo import Client as VilfoClient
from vilfo.exceptions import (
@ -12,6 +10,7 @@ import voluptuous as vol
from homeassistant import config_entries, core, exceptions
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_HOST, CONF_ID, CONF_MAC
from homeassistant.util.network import is_host_valid
from .const import DOMAIN, ROUTER_DEFAULT_HOST
@ -29,16 +28,6 @@ RESULT_CANNOT_CONNECT = "cannot_connect"
RESULT_INVALID_AUTH = "invalid_auth"
def host_valid(host):
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in (4, 6):
return True
except ValueError:
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
def _try_connect_and_fetch_basic_info(host, token):
"""Attempt to connect and call the ping endpoint and, if successful, fetch basic information."""
@ -80,7 +69,7 @@ async def validate_input(hass: core.HomeAssistant, data):
"""
# Validate the host before doing anything else.
if not host_valid(data[CONF_HOST]):
if not is_host_valid(data[CONF_HOST]):
raise InvalidHost
config = {}

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
import re
import yarl
@ -86,6 +87,20 @@ def is_ipv6_address(address: str) -> bool:
return True
def is_host_valid(host: str) -> bool:
"""Check if a given string is an IP address or valid hostname."""
if is_ip_address(host):
return True
if len(host) > 255:
return False
if re.match(r"^[0-9\.]+$", host): # reject invalid IPv4
return False
if host.endswith("."): # dot at the end is correct
host = host[:-1]
allowed = re.compile(r"(?!-)[A-Z\d\-]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(x) for x in host.split("."))
def normalize_url(address: str) -> str:
"""Normalize a given URL."""
url = yarl.URL(address.rstrip("/"))

View File

@ -80,6 +80,30 @@ def test_is_ipv6_address():
assert network_util.is_ipv6_address("8.8.8.8") is False
def test_is_valid_host():
"""Test if strings are IPv6 addresses."""
assert network_util.is_host_valid("::1")
assert network_util.is_host_valid("::ffff:127.0.0.0")
assert network_util.is_host_valid("2001:0db8:85a3:0000:0000:8a2e:0370:7334")
assert network_util.is_host_valid("8.8.8.8")
assert network_util.is_host_valid("local")
assert network_util.is_host_valid("host-host")
assert network_util.is_host_valid("example.com")
assert network_util.is_host_valid("example.com.")
assert network_util.is_host_valid("Example123.com")
assert not network_util.is_host_valid("")
assert not network_util.is_host_valid("192.168.0.1:8080")
assert not network_util.is_host_valid("192.168.0.999")
assert not network_util.is_host_valid("2001:hb8::1:0:0:1")
assert not network_util.is_host_valid("-host-host")
assert not network_util.is_host_valid("host-host-")
assert not network_util.is_host_valid("host_host")
assert not network_util.is_host_valid("example.com/path")
assert not network_util.is_host_valid("example.com:8080")
assert not network_util.is_host_valid("verylonghostname" * 4)
assert not network_util.is_host_valid("verydeepdomain." * 18)
def test_normalize_url():
"""Test the normalizing of URLs."""
assert network_util.normalize_url("http://example.com") == "http://example.com"