Fix local calendar issue with events created with fixed UTC offsets (#88650)

Fix issue with events created with UTC offsets
pull/88679/head
Allen Porter 2023-02-23 10:37:15 -08:00 committed by GitHub
parent 5739782877
commit e1e0400b16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 141 additions and 36 deletions

View File

@ -66,6 +66,55 @@ SCAN_INTERVAL = datetime.timedelta(seconds=60)
# Don't support rrules more often than daily
VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"}
def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all datetime values have a consistent timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone."""
tzinfos = []
for key in keys:
if not (value := obj.get(key)) or not isinstance(value, datetime.datetime):
return obj
tzinfos.append(value.tzinfo)
uniq_values = groupby(tzinfos)
if len(list(uniq_values)) > 1:
raise vol.Invalid("Expected all values to have the same timezone")
return obj
return validate
def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Convert all datetime values to the local timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone."""
for k in keys:
if (value := obj.get(k)) and isinstance(value, datetime.datetime):
obj[k] = dt.as_local(value)
return obj
return validate
def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the specified values are sequential."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict are in order."""
values = []
for k in keys:
if not (value := obj.get(k)):
return obj
values.append(value)
if all(values) and values != sorted(values):
raise vol.Invalid(f"Values were not in order: {values}")
return obj
return validate
CREATE_EVENT_SERVICE = "create_event"
CREATE_EVENT_SCHEMA = vol.All(
cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN),
@ -98,6 +147,10 @@ CREATE_EVENT_SCHEMA = vol.All(
),
},
),
_has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_is_sorted(EVENT_START_DATE, EVENT_END_DATE),
_is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME),
)
@ -441,36 +494,6 @@ def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
return validate
def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all datetime values have a consistent timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone."""
values = [obj[k] for k in keys]
if all(isinstance(value, datetime.datetime) for value in values):
uniq_values = groupby(value.tzinfo for value in values)
if len(list(uniq_values)) > 1:
raise vol.Invalid(
f"Expected all values to have the same timezone: {values}"
)
return obj
return validate
def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the specified values are sequential."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict are in order."""
values = [obj[k] for k in keys]
if values != sorted(values):
raise vol.Invalid(f"Values were not in order: {values}")
return obj
return validate
@websocket_api.websocket_command(
{
vol.Required("type"): "calendar/event/create",
@ -486,6 +509,7 @@ def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
@ -582,6 +606,7 @@ async def handle_calendar_event_delete(
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),

View File

@ -15,7 +15,9 @@ from pydantic import ValidationError
import voluptuous as vol
from homeassistant.components.calendar import (
EVENT_END,
EVENT_RRULE,
EVENT_START,
CalendarEntity,
CalendarEntityFeature,
CalendarEvent,
@ -151,6 +153,21 @@ def _parse_event(event: dict[str, Any]) -> Event:
"""Parse an ical event from a home assistant event dictionary."""
if rrule := event.get(EVENT_RRULE):
event[EVENT_RRULE] = Recur.from_rrule(rrule)
# This function is called with new events created in the local timezone,
# however ical library does not properly return recurrence_ids for
# start dates with a timezone. For now, ensure any datetime is stored as a
# floating local time to ensure we still apply proper local timezone rules.
# This can be removed when ical is updated with a new recurrence_id format
# https://github.com/home-assistant/core/issues/87759
for key in (EVENT_START, EVENT_END):
if (
(value := event[key])
and isinstance(value, datetime)
and value.tzinfo is not None
):
event[key] = dt_util.as_local(value).replace(tzinfo=None)
try:
return Event.parse_obj(event)
except ValidationError as err:
@ -162,8 +179,12 @@ def _get_calendar_event(event: Event) -> CalendarEvent:
"""Return a CalendarEvent from an API event."""
return CalendarEvent(
summary=event.summary,
start=event.start,
end=event.end,
start=dt_util.as_local(event.start)
if isinstance(event.start, datetime)
else event.start,
end=dt_util.as_local(event.end)
if isinstance(event.end, datetime)
else event.end,
description=event.description,
uid=event.uid,
rrule=event.rrule.as_rrule_str() if event.rrule else None,

View File

@ -310,6 +310,30 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
vol.error.MultipleInvalid,
"must contain at most one of start_date, start_date_time, in.",
),
(
{
"start_date_time": "2022-04-01T06:00:00+00:00",
"end_date_time": "2022-04-01T07:00:00+01:00",
},
vol.error.MultipleInvalid,
"Expected all values to have the same timezone",
),
(
{
"start_date_time": "2022-04-01T07:00:00",
"end_date_time": "2022-04-01T06:00:00",
},
vol.error.MultipleInvalid,
"Values were not in order",
),
(
{
"start_date": "2022-04-02",
"end_date": "2022-04-01",
},
vol.error.MultipleInvalid,
"Values were not in order",
),
],
ids=[
"missing_all",
@ -324,6 +348,9 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"multiple_in",
"unexpected_in_with_date",
"unexpected_in_with_datetime",
"inconsistent_timezone",
"incorrect_date_order",
"incorrect_datetime_order",
],
)
async def test_create_event_service_invalid_params(

View File

@ -48,8 +48,12 @@ class FakeStore(LocalCalendarStore):
def mock_store() -> None:
"""Test cleanup, remove any media storage persisted during the test."""
stores: dict[Path, FakeStore] = {}
def new_store(hass: HomeAssistant, path: Path) -> FakeStore:
return FakeStore(hass, path)
if path not in stores:
stores[path] = FakeStore(hass, path)
return stores[path]
with patch(
"homeassistant.components.local_calendar.LocalCalendarStore", new=new_store
@ -961,8 +965,20 @@ async def test_update_invalid_event_id(
assert resp.get("error").get("code") == "failed"
@pytest.mark.parametrize(
("start_date_time", "end_date_time"),
[
("1997-07-14T17:00:00+00:00", "1997-07-15T04:00:00+00:00"),
("1997-07-14T11:00:00-06:00", "1997-07-14T22:00:00-06:00"),
],
)
async def test_create_event_service(
hass: HomeAssistant, setup_integration: None, get_events: GetEventsFn
hass: HomeAssistant,
setup_integration: None,
get_events: GetEventsFn,
start_date_time: str,
end_date_time: str,
config_entry: MockConfigEntry,
) -> None:
"""Test creating an event using the create_event service."""
@ -970,13 +986,15 @@ async def test_create_event_service(
"calendar",
"create_event",
{
"start_date_time": "1997-07-14T17:00:00+00:00",
"end_date_time": "1997-07-15T04:00:00+00:00",
"start_date_time": start_date_time,
"end_date_time": end_date_time,
"summary": "Bastille Day Party",
},
target={"entity_id": TEST_ENTITY},
blocking=True,
)
# Ensure data is written to disk
await hass.async_block_till_done()
events = await get_events("1997-07-14T00:00:00Z", "1997-07-16T00:00:00Z")
assert list(map(event_fields, events)) == [
@ -995,3 +1013,17 @@ async def test_create_event_service(
"end": {"dateTime": "1997-07-14T22:00:00-06:00"},
}
]
# Reload the config entry, which reloads the content from the store and
# verifies that the persisted data can be parsed correctly.
await hass.config_entries.async_reload(config_entry.entry_id)
await hass.async_block_till_done()
events = await get_events("1997-07-13T00:00:00Z", "1997-07-14T18:00:00Z")
assert list(map(event_fields, events)) == [
{
"summary": "Bastille Day Party",
"start": {"dateTime": "1997-07-14T11:00:00-06:00"},
"end": {"dateTime": "1997-07-14T22:00:00-06:00"},
}
]