Refresh Nest WebRTC streams before expiration (#129478)

pull/129502/head
Allen Porter 2024-10-30 06:25:43 -07:00 committed by GitHub
parent 405a480cae
commit 6c047e2678
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 172 additions and 24 deletions

View File

@ -14,6 +14,7 @@ from google_nest_sdm.camera_traits import (
CameraImageTrait,
CameraLiveStreamTrait,
RtspStream,
Stream,
StreamingProtocol,
WebRtcStream,
)
@ -78,7 +79,8 @@ class NestCamera(Camera):
self._attr_device_info = nest_device_info.device_info
self._attr_brand = nest_device_info.device_brand
self._attr_model = nest_device_info.device_model
self._stream: RtspStream | None = None
self._rtsp_stream: RtspStream | None = None
self._webrtc_sessions: dict[str, WebRtcStream] = {}
self._create_stream_url_lock = asyncio.Lock()
self._stream_refresh_unsub: Callable[[], None] | None = None
self._attr_is_streaming = False
@ -95,7 +97,6 @@ class NestCamera(Camera):
self.stream_options[CONF_EXTRA_PART_WAIT_TIME] = 3
# The API "name" field is a unique device identifier.
self._attr_unique_id = f"{self._device.name}-camera"
self._webrtc_sessions: dict[str, WebRtcStream] = {}
@property
def use_stream_for_stills(self) -> bool:
@ -127,65 +128,107 @@ class NestCamera(Camera):
if not self._rtsp_live_stream_trait:
return None
async with self._create_stream_url_lock:
if not self._stream:
if not self._rtsp_stream:
_LOGGER.debug("Fetching stream url")
try:
self._stream = (
self._rtsp_stream = (
await self._rtsp_live_stream_trait.generate_rtsp_stream()
)
except ApiException as err:
raise HomeAssistantError(f"Nest API error: {err}") from err
self._schedule_stream_refresh()
assert self._stream
if self._stream.expires_at < utcnow():
assert self._rtsp_stream
if self._rtsp_stream.expires_at < utcnow():
_LOGGER.warning("Stream already expired")
return self._stream.rtsp_stream_url
return self._rtsp_stream.rtsp_stream_url
def _all_streams(self) -> list[Stream]:
"""Return the current list of active streams."""
streams: list[Stream] = []
if self._rtsp_stream:
streams.append(self._rtsp_stream)
streams.extend(list(self._webrtc_sessions.values()))
return streams
def _schedule_stream_refresh(self) -> None:
"""Schedules an alarm to refresh the stream url before expiration."""
assert self._stream
_LOGGER.debug("New stream url expires at %s", self._stream.expires_at)
refresh_time = self._stream.expires_at - STREAM_EXPIRATION_BUFFER
"""Schedules an alarm to refresh any streams before expiration."""
# Schedule an alarm to extend the stream
if self._stream_refresh_unsub is not None:
self._stream_refresh_unsub()
_LOGGER.debug("Scheduling next stream refresh")
expiration_times = [stream.expires_at for stream in self._all_streams()]
if not expiration_times:
_LOGGER.debug("No streams to refresh")
return
refresh_time = min(expiration_times) - STREAM_EXPIRATION_BUFFER
_LOGGER.debug("Scheduled next stream refresh for %s", refresh_time)
self._stream_refresh_unsub = async_track_point_in_utc_time(
self.hass,
self._handle_stream_refresh,
refresh_time,
)
async def _handle_stream_refresh(self, now: datetime.datetime) -> None:
async def _handle_stream_refresh(self, _: datetime.datetime) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
if not self._stream:
_LOGGER.debug("Examining streams to refresh")
await self._handle_rtsp_stream_refresh()
await self._handle_webrtc_stream_refresh()
self._schedule_stream_refresh()
async def _handle_rtsp_stream_refresh(self) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
if not self._rtsp_stream:
return
_LOGGER.debug("Extending stream url")
now = utcnow()
refresh_time = self._rtsp_stream.expires_at - STREAM_EXPIRATION_BUFFER
if now < refresh_time:
return
_LOGGER.debug("Extending RTSP stream")
try:
self._stream = await self._stream.extend_rtsp_stream()
self._rtsp_stream = await self._rtsp_stream.extend_rtsp_stream()
except ApiException as err:
_LOGGER.debug("Failed to extend stream: %s", err)
# Next attempt to catch a url will get a new one
self._stream = None
self._rtsp_stream = None
if self.stream:
await self.stream.stop()
self.stream = None
return
# Update the stream worker with the latest valid url
if self.stream:
self.stream.update_source(self._stream.rtsp_stream_url)
self._schedule_stream_refresh()
self.stream.update_source(self._rtsp_stream.rtsp_stream_url)
async def _handle_webrtc_stream_refresh(self) -> None:
"""Alarm that fires to check if the stream should be refreshed."""
now = utcnow()
for webrtc_stream in list(self._webrtc_sessions.values()):
if now < (webrtc_stream.expires_at - STREAM_EXPIRATION_BUFFER):
_LOGGER.debug(
"Stream does not yet expire: %s", webrtc_stream.expires_at
)
continue
_LOGGER.debug("Extending WebRTC stream %s", webrtc_stream.media_session_id)
try:
webrtc_stream = await webrtc_stream.extend_stream()
except ApiException as err:
_LOGGER.debug("Failed to extend stream: %s", err)
else:
self._webrtc_sessions[webrtc_stream.media_session_id] = webrtc_stream
async def async_will_remove_from_hass(self) -> None:
"""Invalidates the RTSP token when unloaded."""
if self._stream:
for stream in self._all_streams():
_LOGGER.debug("Invalidating stream")
try:
await self._stream.stop_rtsp_stream()
await stream.stop_stream()
except ApiException as err:
_LOGGER.debug(
"Failed to revoke stream token, will rely on ttl: %s", err
)
_LOGGER.debug("Error stopping stream: %s", err)
self._rtsp_stream = None
self._webrtc_sessions.clear()
if self._stream_refresh_unsub:
self._stream_refresh_unsub()
@ -223,14 +266,28 @@ class NestCamera(Camera):
stream = await trait.generate_web_rtc_stream(offer_sdp)
except ApiException as err:
raise HomeAssistantError(f"Nest API error: {err}") from err
_LOGGER.debug(
"Started WebRTC session %s, %s", session_id, stream.media_session_id
)
self._webrtc_sessions[session_id] = stream
send_message(WebRTCAnswer(stream.answer_sdp))
self._schedule_stream_refresh()
@callback
def close_webrtc_session(self, session_id: str) -> None:
"""Close a WebRTC session."""
if (stream := self._webrtc_sessions.pop(session_id, None)) is not None:
self.hass.async_create_task(stream.stop_stream())
_LOGGER.debug(
"Closing WebRTC session %s, %s", session_id, stream.media_session_id
)
async def stop_stream() -> None:
try:
await stream.stop_stream()
except ApiException as err:
_LOGGER.debug("Error stopping stream: %s", err)
self.hass.async_create_task(stop_stream())
super().close_webrtc_session(session_id)
@callback

View File

@ -803,3 +803,94 @@ async def test_camera_multiple_streams(
"type": "answer",
"answer": "v=0\r\ns=-\r\n",
}
@pytest.mark.usefixtures("webrtc_camera_device")
async def test_webrtc_refresh_expired_stream(
hass: HomeAssistant,
setup_platform: PlatformSetup,
hass_ws_client: WebSocketGenerator,
auth: FakeAuth,
) -> None:
"""Test a camera webrtc expiration and refresh."""
now = utcnow()
stream_1_expiration = now + datetime.timedelta(seconds=90)
stream_2_expiration = now + datetime.timedelta(seconds=180)
auth.responses = [
aiohttp.web.json_response(
{
"results": {
"answerSdp": "v=0\r\ns=-\r\n",
"mediaSessionId": "yP2grqz0Y1V_wgiX9KEbMWHoLd...",
"expiresAt": stream_1_expiration.isoformat(timespec="seconds"),
},
}
),
aiohttp.web.json_response(
{
"results": {
"mediaSessionId": "yP2grqz0Y1V_wgiX9KEbMWHoLd...",
"expiresAt": stream_2_expiration.isoformat(timespec="seconds"),
},
}
),
]
await setup_platform()
await hass.async_block_till_done()
assert len(hass.states.async_all()) == 1
cam = hass.states.get("camera.my_camera")
assert cam is not None
assert cam.state == CameraState.STREAMING
assert cam.attributes["frontend_stream_type"] == StreamType.WEB_RTC
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "camera/webrtc/offer",
"entity_id": "camera.my_camera",
"offer": "a=recvonly",
}
)
response = await client.receive_json()
assert response["type"] == TYPE_RESULT
assert response["success"]
subscription_id = response["id"]
# Session id
response = await client.receive_json()
assert response["id"] == subscription_id
assert response["type"] == "event"
assert response["event"]["type"] == "session"
# Answer
response = await client.receive_json()
assert response["id"] == subscription_id
assert response["type"] == "event"
assert response["event"] == {
"type": "answer",
"answer": "v=0\r\ns=-\r\n",
}
assert len(auth.captured_requests) == 1
assert (
auth.captured_requests[0][2].get("command")
== "sdm.devices.commands.CameraLiveStream.GenerateWebRtcStream"
)
# Fire alarm before stream_1_expiration. The stream url is not refreshed
next_update = now + datetime.timedelta(seconds=25)
await fire_alarm(hass, next_update)
assert len(auth.captured_requests) == 1
# Alarm is near stream_1_expiration which causes the stream extension
next_update = now + datetime.timedelta(seconds=60)
await fire_alarm(hass, next_update)
assert len(auth.captured_requests) >= 2
assert (
auth.captured_requests[1][2].get("command")
== "sdm.devices.commands.CameraLiveStream.ExtendWebRtcStream"
)