Azure Event Hub code improvements (#62584)
* code improvements to AEH * moved hub backpull/62651/head
parent
e9c69682c7
commit
259e454c3e
|
@ -3,23 +3,25 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from azure.eventhub import EventData, EventDataBatch
|
from azure.eventhub import EventData, EventDataBatch
|
||||||
|
from azure.eventhub.aio import EventHubProducerClient
|
||||||
from azure.eventhub.exceptions import EventHubError
|
from azure.eventhub.exceptions import EventHubError
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryNotReady
|
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryNotReady
|
||||||
from homeassistant.const import MATCH_ALL, STATE_UNAVAILABLE, STATE_UNKNOWN
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.core import Event, HomeAssistant
|
from homeassistant.core import Event, HomeAssistant, State
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.entityfilter import FILTER_SCHEMA
|
from homeassistant.helpers.entityfilter import FILTER_SCHEMA
|
||||||
from homeassistant.helpers.event import async_call_later
|
from homeassistant.helpers.event import async_call_later
|
||||||
from homeassistant.helpers.json import JSONEncoder
|
from homeassistant.helpers.json import JSONEncoder
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from .client import AzureEventHubClient
|
from .client import AzureEventHubClient
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -35,6 +37,7 @@ from .const import (
|
||||||
DATA_HUB,
|
DATA_HUB,
|
||||||
DEFAULT_MAX_DELAY,
|
DEFAULT_MAX_DELAY,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
FILTER_STATES,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
@ -91,10 +94,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
hass.data.setdefault(DOMAIN, {DATA_FILTER: FILTER_SCHEMA({})})
|
hass.data.setdefault(DOMAIN, {DATA_FILTER: FILTER_SCHEMA({})})
|
||||||
hub = AzureEventHub(
|
hub = AzureEventHub(
|
||||||
hass,
|
hass,
|
||||||
AzureEventHubClient.from_input(**entry.data),
|
entry,
|
||||||
hass.data[DOMAIN][DATA_FILTER],
|
hass.data[DOMAIN][DATA_FILTER],
|
||||||
entry.options[CONF_SEND_INTERVAL],
|
|
||||||
entry.options.get(CONF_MAX_DELAY),
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await hub.async_test_connection()
|
await hub.async_test_connection()
|
||||||
|
@ -124,139 +125,128 @@ class AzureEventHub:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
client: AzureEventHubClient,
|
entry: ConfigEntry,
|
||||||
entities_filter: vol.Schema,
|
entities_filter: vol.Schema,
|
||||||
send_interval: int,
|
|
||||||
max_delay: int | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the listener."""
|
"""Initialize the listener."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.queue: asyncio.PriorityQueue[ # pylint: disable=unsubscriptable-object
|
self._entry = entry
|
||||||
tuple[int, tuple[float, Event | None]]
|
|
||||||
] = asyncio.PriorityQueue()
|
|
||||||
self._client = client
|
|
||||||
self._entities_filter = entities_filter
|
self._entities_filter = entities_filter
|
||||||
self._send_interval = send_interval
|
|
||||||
self._max_delay = max_delay if max_delay else DEFAULT_MAX_DELAY
|
self._client = AzureEventHubClient.from_input(**self._entry.data)
|
||||||
|
self._send_interval = self._entry.options[CONF_SEND_INTERVAL]
|
||||||
|
self._max_delay = self._entry.options.get(CONF_MAX_DELAY, DEFAULT_MAX_DELAY)
|
||||||
|
|
||||||
|
self._shutdown = False
|
||||||
|
self._queue: asyncio.PriorityQueue[ # pylint: disable=unsubscriptable-object
|
||||||
|
tuple[int, tuple[datetime, State | None]]
|
||||||
|
] = asyncio.PriorityQueue()
|
||||||
self._listener_remover: Callable[[], None] | None = None
|
self._listener_remover: Callable[[], None] | None = None
|
||||||
self._next_send_remover: Callable[[], None] | None = None
|
self._next_send_remover: Callable[[], None] | None = None
|
||||||
self.shutdown = False
|
|
||||||
|
|
||||||
async def async_start(self) -> None:
|
async def async_start(self) -> None:
|
||||||
"""Start the hub.
|
"""Start the hub.
|
||||||
|
|
||||||
This suppresses logging and register the listener and
|
This suppresses logging and register the listener and
|
||||||
schedules the first send.
|
schedules the first send.
|
||||||
|
|
||||||
|
Suppress the INFO and below logging on the underlying packages,
|
||||||
|
they are very verbose, even at INFO.
|
||||||
"""
|
"""
|
||||||
# suppress the INFO and below logging on the underlying packages,
|
|
||||||
# they are very verbose, even at INFO
|
|
||||||
logging.getLogger("uamqp").setLevel(logging.WARNING)
|
logging.getLogger("uamqp").setLevel(logging.WARNING)
|
||||||
logging.getLogger("azure.eventhub").setLevel(logging.WARNING)
|
logging.getLogger("azure.eventhub").setLevel(logging.WARNING)
|
||||||
|
|
||||||
self._listener_remover = self.hass.bus.async_listen(
|
self._listener_remover = self.hass.bus.async_listen(
|
||||||
MATCH_ALL, self.async_listen
|
MATCH_ALL, self.async_listen
|
||||||
)
|
)
|
||||||
# schedule the first send after 10 seconds to capture startup events,
|
self._schedule_next_send()
|
||||||
# after that each send will schedule the next after the interval.
|
|
||||||
self._next_send_remover = async_call_later(
|
|
||||||
self.hass, self._send_interval, self.async_send
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_stop(self) -> None:
|
async def async_stop(self) -> None:
|
||||||
"""Shut down the AEH by queueing None and calling send."""
|
"""Shut down the AEH by queueing None, calling send, join queue."""
|
||||||
if self._next_send_remover:
|
if self._next_send_remover:
|
||||||
self._next_send_remover()
|
self._next_send_remover()
|
||||||
if self._listener_remover:
|
if self._listener_remover:
|
||||||
self._listener_remover()
|
self._listener_remover()
|
||||||
await self.queue.put((3, (time.monotonic(), None)))
|
await self._queue.put((3, (utcnow(), None)))
|
||||||
await self.async_send(None)
|
await self.async_send(None)
|
||||||
|
await self._queue.join()
|
||||||
|
|
||||||
|
def update_options(self, new_options: dict[str, Any]) -> None:
|
||||||
|
"""Update options."""
|
||||||
|
self._send_interval = new_options[CONF_SEND_INTERVAL]
|
||||||
|
|
||||||
async def async_test_connection(self) -> None:
|
async def async_test_connection(self) -> None:
|
||||||
"""Test the connection to the event hub."""
|
"""Test the connection to the event hub."""
|
||||||
await self._client.test_connection()
|
await self._client.test_connection()
|
||||||
|
|
||||||
async def async_listen(self, event: Event) -> None:
|
def _schedule_next_send(self) -> None:
|
||||||
"""Listen for new messages on the bus and queue them for AEH."""
|
"""Schedule the next send."""
|
||||||
await self.queue.put((2, (time.monotonic(), event)))
|
if not self._shutdown:
|
||||||
|
|
||||||
async def async_send(self, _) -> None:
|
|
||||||
"""Write preprocessed events to eventhub, with retry."""
|
|
||||||
async with self._client.client as client:
|
|
||||||
while not self.queue.empty():
|
|
||||||
data_batch, dequeue_count = await self.fill_batch(client)
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Sending %d event(s), out of %d events in the queue",
|
|
||||||
len(data_batch),
|
|
||||||
dequeue_count,
|
|
||||||
)
|
|
||||||
if data_batch:
|
|
||||||
try:
|
|
||||||
await client.send_batch(data_batch)
|
|
||||||
except EventHubError as exc:
|
|
||||||
_LOGGER.error("Error in sending events to Event Hub: %s", exc)
|
|
||||||
finally:
|
|
||||||
for _ in range(dequeue_count):
|
|
||||||
self.queue.task_done()
|
|
||||||
|
|
||||||
if not self.shutdown:
|
|
||||||
self._next_send_remover = async_call_later(
|
self._next_send_remover = async_call_later(
|
||||||
self.hass, self._send_interval, self.async_send
|
self.hass, self._send_interval, self.async_send
|
||||||
)
|
)
|
||||||
|
|
||||||
async def fill_batch(self, client) -> tuple[EventDataBatch, int]:
|
async def async_listen(self, event: Event) -> None:
|
||||||
"""Return a batch of events formatted for writing.
|
"""Listen for new messages on the bus and queue them for AEH."""
|
||||||
|
if state := event.data.get("new_state"):
|
||||||
|
await self._queue.put((2, (event.time_fired, state)))
|
||||||
|
|
||||||
|
async def async_send(self, _) -> None:
|
||||||
|
"""Write preprocessed events to eventhub, with retry."""
|
||||||
|
async with self._client.client as client:
|
||||||
|
while not self._queue.empty():
|
||||||
|
if event_batch := await self.fill_batch(client):
|
||||||
|
_LOGGER.debug("Sending %d event(s)", len(event_batch))
|
||||||
|
try:
|
||||||
|
await client.send_batch(event_batch)
|
||||||
|
except EventHubError as exc:
|
||||||
|
_LOGGER.error("Error in sending events to Event Hub: %s", exc)
|
||||||
|
self._schedule_next_send()
|
||||||
|
|
||||||
|
async def fill_batch(self, client: EventHubProducerClient) -> EventDataBatch:
|
||||||
|
"""Return a batch of events formatted for sending to Event Hub.
|
||||||
|
|
||||||
Uses get_nowait instead of await get, because the functions batches and
|
Uses get_nowait instead of await get, because the functions batches and
|
||||||
doesn't wait for each single event, the send function is called.
|
doesn't wait for each single event.
|
||||||
|
|
||||||
Throws ValueError on add to batch when the EventDataBatch object reaches
|
Throws ValueError on add to batch when the EventDataBatch object reaches
|
||||||
max_size. Put the item back in the queue and the next batch will include
|
max_size. Put the item back in the queue and the next batch will include
|
||||||
it.
|
it.
|
||||||
"""
|
"""
|
||||||
event_batch = await client.create_batch()
|
event_batch = await client.create_batch()
|
||||||
dequeue_count = 0
|
|
||||||
dropped = 0
|
dropped = 0
|
||||||
while not self.shutdown:
|
while not self._shutdown:
|
||||||
try:
|
try:
|
||||||
_, (timestamp, event) = self.queue.get_nowait()
|
_, event = self._queue.get_nowait()
|
||||||
except asyncio.QueueEmpty:
|
except asyncio.QueueEmpty:
|
||||||
break
|
break
|
||||||
dequeue_count += 1
|
event_data, dropped = self._parse_event(*event, dropped)
|
||||||
if not event:
|
|
||||||
self.shutdown = True
|
|
||||||
break
|
|
||||||
event_data = self._event_to_filtered_event_data(event)
|
|
||||||
if not event_data:
|
if not event_data:
|
||||||
continue
|
continue
|
||||||
if time.monotonic() - timestamp <= self._max_delay + self._send_interval:
|
try:
|
||||||
try:
|
event_batch.add(event_data)
|
||||||
event_batch.add(event_data)
|
except ValueError:
|
||||||
except ValueError:
|
self._queue.put_nowait((1, event))
|
||||||
dequeue_count -= 1
|
break
|
||||||
self.queue.task_done()
|
|
||||||
self.queue.put_nowait((1, (timestamp, event)))
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
dropped += 1
|
|
||||||
|
|
||||||
if dropped:
|
if dropped:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Dropped %d old events, consider filtering messages", dropped
|
"Dropped %d old events, consider filtering messages", dropped
|
||||||
)
|
)
|
||||||
|
return event_batch
|
||||||
|
|
||||||
return event_batch, dequeue_count
|
def _parse_event(
|
||||||
|
self, time_fired: datetime, state: State | None, dropped: int
|
||||||
def _event_to_filtered_event_data(self, event: Event) -> EventData | None:
|
) -> tuple[EventData | None, int]:
|
||||||
"""Filter event states and create EventData object."""
|
"""Parse event by checking if it needs to be sent, and format it."""
|
||||||
state = event.data.get("new_state")
|
self._queue.task_done()
|
||||||
if (
|
if not state:
|
||||||
state is None
|
self._shutdown = True
|
||||||
or state.state in (STATE_UNKNOWN, "", STATE_UNAVAILABLE)
|
return None, dropped
|
||||||
or not self._entities_filter(state.entity_id)
|
if state.state in FILTER_STATES or not self._entities_filter(state.entity_id):
|
||||||
):
|
return None, dropped
|
||||||
return None
|
if (utcnow() - time_fired).seconds > self._max_delay + self._send_interval:
|
||||||
return EventData(json.dumps(obj=state, cls=JSONEncoder).encode("utf-8"))
|
return None, dropped + 1
|
||||||
|
return (
|
||||||
def update_options(self, new_options: dict[str, Any]) -> None:
|
EventData(json.dumps(obj=state, cls=JSONEncoder).encode("utf-8")),
|
||||||
"""Update options."""
|
dropped,
|
||||||
self._send_interval = new_options[CONF_SEND_INTERVAL]
|
)
|
||||||
|
|
|
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
|
|
||||||
DOMAIN = "azure_event_hub"
|
DOMAIN = "azure_event_hub"
|
||||||
|
|
||||||
CONF_USE_CONN_STRING = "use_connection_string"
|
CONF_USE_CONN_STRING = "use_connection_string"
|
||||||
|
@ -27,3 +29,4 @@ DEFAULT_OPTIONS: dict[str, Any] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ADDITIONAL_ARGS: dict[str, Any] = {"logging_enable": False}
|
ADDITIONAL_ARGS: dict[str, Any] = {"logging_enable": False}
|
||||||
|
FILTER_STATES = (STATE_UNKNOWN, STATE_UNAVAILABLE, "")
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""Test the init functions for AEH."""
|
"""Test the init functions for AEH."""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
from time import monotonic
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from azure.eventhub.exceptions import EventHubError
|
from azure.eventhub.exceptions import EventHubError
|
||||||
|
@ -96,7 +95,7 @@ async def test_send_batch_error(hass, entry_with_one_event, mock_send_batch):
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
mock_send_batch.assert_called_once()
|
mock_send_batch.assert_called_once()
|
||||||
mock_send_batch.reset_mock()
|
mock_send_batch.reset_mock()
|
||||||
|
hass.states.async_set("sensor.test2", STATE_ON)
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
hass,
|
hass,
|
||||||
utcnow() + timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
utcnow() + timedelta(seconds=entry_with_one_event.options[CONF_SEND_INTERVAL]),
|
||||||
|
@ -108,8 +107,8 @@ async def test_send_batch_error(hass, entry_with_one_event, mock_send_batch):
|
||||||
async def test_late_event(hass, entry_with_one_event, mock_create_batch):
|
async def test_late_event(hass, entry_with_one_event, mock_create_batch):
|
||||||
"""Test the check on late events."""
|
"""Test the check on late events."""
|
||||||
with patch(
|
with patch(
|
||||||
f"{AZURE_EVENT_HUB_PATH}.time.monotonic",
|
f"{AZURE_EVENT_HUB_PATH}.utcnow",
|
||||||
return_value=monotonic() + timedelta(hours=1).seconds,
|
return_value=utcnow() + timedelta(hours=1),
|
||||||
):
|
):
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
hass,
|
hass,
|
||||||
|
|
Loading…
Reference in New Issue