"""Utility functions to combine state attributes from multiple entities.""" from __future__ import annotations from collections.abc import Callable, Iterator from itertools import groupby from typing import Any from homeassistant.core import State def find_state_attributes(states: list[State], key: str) -> Iterator[Any]: """Find attributes with matching key from states.""" for state in states: if (value := state.attributes.get(key)) is not None: yield value def find_state(states: list[State]) -> Iterator[Any]: """Find state from states.""" for state in states: yield state.state def mean_int(*args: Any) -> int: """Return the mean of the supplied values.""" return int(sum(args) / len(args)) def mean_tuple(*args: Any) -> tuple[float | Any, ...]: """Return the mean values along the columns of the supplied values.""" return tuple(sum(x) / len(x) for x in zip(*args)) def attribute_equal(states: list[State], key: str) -> bool: """Return True if all attributes found matching key from states are equal. Note: Returns True if no matching attribute is found. """ return _values_equal(find_state_attributes(states, key)) def most_frequent_attribute(states: list[State], key: str) -> Any | None: """Find attributes with matching key from states.""" if attrs := list(find_state_attributes(states, key)): return max(set(attrs), key=attrs.count) return None def states_equal(states: list[State]) -> bool: """Return True if all states are equal. Note: Returns True if no matching attribute is found. """ return _values_equal(find_state(states)) def _values_equal(values: Iterator[Any]) -> bool: """Return True if all values are equal. Note: Returns True if no matching attribute is found. """ grp = groupby(values) return bool(next(grp, True) and not next(grp, False)) def reduce_attribute( states: list[State], key: str, default: Any | None = None, reduce: Callable[..., Any] = mean_int, ) -> Any: """Find the first attribute matching key from states. If none are found, return default. """ attrs = list(find_state_attributes(states, key)) if not attrs: return default if len(attrs) == 1: return attrs[0] return reduce(*attrs)