diff --git a/homeassistant/components/calendar/__init__.py b/homeassistant/components/calendar/__init__.py index 390e14d1689..c77d6c9c67a 100644 --- a/homeassistant/components/calendar/__init__.py +++ b/homeassistant/components/calendar/__init__.py @@ -66,6 +66,55 @@ SCAN_INTERVAL = datetime.timedelta(seconds=60) # Don't support rrules more often than daily VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"} + +def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that all datetime values have a consistent timezone.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys that are datetime values have the same timezone.""" + tzinfos = [] + for key in keys: + if not (value := obj.get(key)) or not isinstance(value, datetime.datetime): + return obj + tzinfos.append(value.tzinfo) + uniq_values = groupby(tzinfos) + if len(list(uniq_values)) > 1: + raise vol.Invalid("Expected all values to have the same timezone") + return obj + + return validate + + +def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Convert all datetime values to the local timezone.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys that are datetime values have the same timezone.""" + for k in keys: + if (value := obj.get(k)) and isinstance(value, datetime.datetime): + obj[k] = dt.as_local(value) + return obj + + return validate + + +def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that the specified values are sequential.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys in the dict are in order.""" + values = [] + for k in keys: + if not (value := obj.get(k)): + return obj + values.append(value) + if all(values) and values != sorted(values): + raise vol.Invalid(f"Values were not in order: {values}") + return obj + + return validate + + CREATE_EVENT_SERVICE = "create_event" CREATE_EVENT_SCHEMA = vol.All( cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN), @@ -98,6 +147,10 @@ CREATE_EVENT_SCHEMA = vol.All( ), }, ), + _has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), + _as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), + _is_sorted(EVENT_START_DATE, EVENT_END_DATE), + _is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME), ) @@ -441,36 +494,6 @@ def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: return validate -def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: - """Verify that all datetime values have a consistent timezone.""" - - def validate(obj: dict[str, Any]) -> dict[str, Any]: - """Test that all keys that are datetime values have the same timezone.""" - values = [obj[k] for k in keys] - if all(isinstance(value, datetime.datetime) for value in values): - uniq_values = groupby(value.tzinfo for value in values) - if len(list(uniq_values)) > 1: - raise vol.Invalid( - f"Expected all values to have the same timezone: {values}" - ) - return obj - - return validate - - -def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: - """Verify that the specified values are sequential.""" - - def validate(obj: dict[str, Any]) -> dict[str, Any]: - """Test that all keys in the dict are in order.""" - values = [obj[k] for k in keys] - if values != sorted(values): - raise vol.Invalid(f"Values were not in order: {values}") - return obj - - return validate - - @websocket_api.websocket_command( { vol.Required("type"): "calendar/event/create", @@ -486,6 +509,7 @@ def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: }, _has_same_type(EVENT_START, EVENT_END), _has_consistent_timezone(EVENT_START, EVENT_END), + _as_local_timezone(EVENT_START, EVENT_END), _is_sorted(EVENT_START, EVENT_END), ) ), @@ -582,6 +606,7 @@ async def handle_calendar_event_delete( }, _has_same_type(EVENT_START, EVENT_END), _has_consistent_timezone(EVENT_START, EVENT_END), + _as_local_timezone(EVENT_START, EVENT_END), _is_sorted(EVENT_START, EVENT_END), ) ), diff --git a/homeassistant/components/local_calendar/calendar.py b/homeassistant/components/local_calendar/calendar.py index be6fb4a17b5..88737150c02 100644 --- a/homeassistant/components/local_calendar/calendar.py +++ b/homeassistant/components/local_calendar/calendar.py @@ -15,7 +15,9 @@ from pydantic import ValidationError import voluptuous as vol from homeassistant.components.calendar import ( + EVENT_END, EVENT_RRULE, + EVENT_START, CalendarEntity, CalendarEntityFeature, CalendarEvent, @@ -151,6 +153,21 @@ def _parse_event(event: dict[str, Any]) -> Event: """Parse an ical event from a home assistant event dictionary.""" if rrule := event.get(EVENT_RRULE): event[EVENT_RRULE] = Recur.from_rrule(rrule) + + # This function is called with new events created in the local timezone, + # however ical library does not properly return recurrence_ids for + # start dates with a timezone. For now, ensure any datetime is stored as a + # floating local time to ensure we still apply proper local timezone rules. + # This can be removed when ical is updated with a new recurrence_id format + # https://github.com/home-assistant/core/issues/87759 + for key in (EVENT_START, EVENT_END): + if ( + (value := event[key]) + and isinstance(value, datetime) + and value.tzinfo is not None + ): + event[key] = dt_util.as_local(value).replace(tzinfo=None) + try: return Event.parse_obj(event) except ValidationError as err: @@ -162,8 +179,12 @@ def _get_calendar_event(event: Event) -> CalendarEvent: """Return a CalendarEvent from an API event.""" return CalendarEvent( summary=event.summary, - start=event.start, - end=event.end, + start=dt_util.as_local(event.start) + if isinstance(event.start, datetime) + else event.start, + end=dt_util.as_local(event.end) + if isinstance(event.end, datetime) + else event.end, description=event.description, uid=event.uid, rrule=event.rrule.as_rrule_str() if event.rrule else None, diff --git a/tests/components/calendar/test_init.py b/tests/components/calendar/test_init.py index 806410c9834..5c90a1cfc2c 100644 --- a/tests/components/calendar/test_init.py +++ b/tests/components/calendar/test_init.py @@ -310,6 +310,30 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None: vol.error.MultipleInvalid, "must contain at most one of start_date, start_date_time, in.", ), + ( + { + "start_date_time": "2022-04-01T06:00:00+00:00", + "end_date_time": "2022-04-01T07:00:00+01:00", + }, + vol.error.MultipleInvalid, + "Expected all values to have the same timezone", + ), + ( + { + "start_date_time": "2022-04-01T07:00:00", + "end_date_time": "2022-04-01T06:00:00", + }, + vol.error.MultipleInvalid, + "Values were not in order", + ), + ( + { + "start_date": "2022-04-02", + "end_date": "2022-04-01", + }, + vol.error.MultipleInvalid, + "Values were not in order", + ), ], ids=[ "missing_all", @@ -324,6 +348,9 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None: "multiple_in", "unexpected_in_with_date", "unexpected_in_with_datetime", + "inconsistent_timezone", + "incorrect_date_order", + "incorrect_datetime_order", ], ) async def test_create_event_service_invalid_params( diff --git a/tests/components/local_calendar/test_calendar.py b/tests/components/local_calendar/test_calendar.py index c7eea20920f..f432fe3f977 100644 --- a/tests/components/local_calendar/test_calendar.py +++ b/tests/components/local_calendar/test_calendar.py @@ -48,8 +48,12 @@ class FakeStore(LocalCalendarStore): def mock_store() -> None: """Test cleanup, remove any media storage persisted during the test.""" + stores: dict[Path, FakeStore] = {} + def new_store(hass: HomeAssistant, path: Path) -> FakeStore: - return FakeStore(hass, path) + if path not in stores: + stores[path] = FakeStore(hass, path) + return stores[path] with patch( "homeassistant.components.local_calendar.LocalCalendarStore", new=new_store @@ -961,8 +965,20 @@ async def test_update_invalid_event_id( assert resp.get("error").get("code") == "failed" +@pytest.mark.parametrize( + ("start_date_time", "end_date_time"), + [ + ("1997-07-14T17:00:00+00:00", "1997-07-15T04:00:00+00:00"), + ("1997-07-14T11:00:00-06:00", "1997-07-14T22:00:00-06:00"), + ], +) async def test_create_event_service( - hass: HomeAssistant, setup_integration: None, get_events: GetEventsFn + hass: HomeAssistant, + setup_integration: None, + get_events: GetEventsFn, + start_date_time: str, + end_date_time: str, + config_entry: MockConfigEntry, ) -> None: """Test creating an event using the create_event service.""" @@ -970,13 +986,15 @@ async def test_create_event_service( "calendar", "create_event", { - "start_date_time": "1997-07-14T17:00:00+00:00", - "end_date_time": "1997-07-15T04:00:00+00:00", + "start_date_time": start_date_time, + "end_date_time": end_date_time, "summary": "Bastille Day Party", }, target={"entity_id": TEST_ENTITY}, blocking=True, ) + # Ensure data is written to disk + await hass.async_block_till_done() events = await get_events("1997-07-14T00:00:00Z", "1997-07-16T00:00:00Z") assert list(map(event_fields, events)) == [ @@ -995,3 +1013,17 @@ async def test_create_event_service( "end": {"dateTime": "1997-07-14T22:00:00-06:00"}, } ] + + # Reload the config entry, which reloads the content from the store and + # verifies that the persisted data can be parsed correctly. + await hass.config_entries.async_reload(config_entry.entry_id) + await hass.async_block_till_done() + + events = await get_events("1997-07-13T00:00:00Z", "1997-07-14T18:00:00Z") + assert list(map(event_fields, events)) == [ + { + "summary": "Bastille Day Party", + "start": {"dateTime": "1997-07-14T11:00:00-06:00"}, + "end": {"dateTime": "1997-07-14T22:00:00-06:00"}, + } + ]