Refresh Nest WebRTC streams before expiration (#129478)
parent
405a480cae
commit
6c047e2678
|
@ -14,6 +14,7 @@ from google_nest_sdm.camera_traits import (
|
|||
CameraImageTrait,
|
||||
CameraLiveStreamTrait,
|
||||
RtspStream,
|
||||
Stream,
|
||||
StreamingProtocol,
|
||||
WebRtcStream,
|
||||
)
|
||||
|
@ -78,7 +79,8 @@ class NestCamera(Camera):
|
|||
self._attr_device_info = nest_device_info.device_info
|
||||
self._attr_brand = nest_device_info.device_brand
|
||||
self._attr_model = nest_device_info.device_model
|
||||
self._stream: RtspStream | None = None
|
||||
self._rtsp_stream: RtspStream | None = None
|
||||
self._webrtc_sessions: dict[str, WebRtcStream] = {}
|
||||
self._create_stream_url_lock = asyncio.Lock()
|
||||
self._stream_refresh_unsub: Callable[[], None] | None = None
|
||||
self._attr_is_streaming = False
|
||||
|
@ -95,7 +97,6 @@ class NestCamera(Camera):
|
|||
self.stream_options[CONF_EXTRA_PART_WAIT_TIME] = 3
|
||||
# The API "name" field is a unique device identifier.
|
||||
self._attr_unique_id = f"{self._device.name}-camera"
|
||||
self._webrtc_sessions: dict[str, WebRtcStream] = {}
|
||||
|
||||
@property
|
||||
def use_stream_for_stills(self) -> bool:
|
||||
|
@ -127,65 +128,107 @@ class NestCamera(Camera):
|
|||
if not self._rtsp_live_stream_trait:
|
||||
return None
|
||||
async with self._create_stream_url_lock:
|
||||
if not self._stream:
|
||||
if not self._rtsp_stream:
|
||||
_LOGGER.debug("Fetching stream url")
|
||||
try:
|
||||
self._stream = (
|
||||
self._rtsp_stream = (
|
||||
await self._rtsp_live_stream_trait.generate_rtsp_stream()
|
||||
)
|
||||
except ApiException as err:
|
||||
raise HomeAssistantError(f"Nest API error: {err}") from err
|
||||
self._schedule_stream_refresh()
|
||||
assert self._stream
|
||||
if self._stream.expires_at < utcnow():
|
||||
assert self._rtsp_stream
|
||||
if self._rtsp_stream.expires_at < utcnow():
|
||||
_LOGGER.warning("Stream already expired")
|
||||
return self._stream.rtsp_stream_url
|
||||
return self._rtsp_stream.rtsp_stream_url
|
||||
|
||||
def _all_streams(self) -> list[Stream]:
|
||||
"""Return the current list of active streams."""
|
||||
streams: list[Stream] = []
|
||||
if self._rtsp_stream:
|
||||
streams.append(self._rtsp_stream)
|
||||
streams.extend(list(self._webrtc_sessions.values()))
|
||||
return streams
|
||||
|
||||
def _schedule_stream_refresh(self) -> None:
|
||||
"""Schedules an alarm to refresh the stream url before expiration."""
|
||||
assert self._stream
|
||||
_LOGGER.debug("New stream url expires at %s", self._stream.expires_at)
|
||||
refresh_time = self._stream.expires_at - STREAM_EXPIRATION_BUFFER
|
||||
"""Schedules an alarm to refresh any streams before expiration."""
|
||||
# Schedule an alarm to extend the stream
|
||||
if self._stream_refresh_unsub is not None:
|
||||
self._stream_refresh_unsub()
|
||||
|
||||
_LOGGER.debug("Scheduling next stream refresh")
|
||||
expiration_times = [stream.expires_at for stream in self._all_streams()]
|
||||
if not expiration_times:
|
||||
_LOGGER.debug("No streams to refresh")
|
||||
return
|
||||
|
||||
refresh_time = min(expiration_times) - STREAM_EXPIRATION_BUFFER
|
||||
_LOGGER.debug("Scheduled next stream refresh for %s", refresh_time)
|
||||
|
||||
self._stream_refresh_unsub = async_track_point_in_utc_time(
|
||||
self.hass,
|
||||
self._handle_stream_refresh,
|
||||
refresh_time,
|
||||
)
|
||||
|
||||
async def _handle_stream_refresh(self, now: datetime.datetime) -> None:
|
||||
async def _handle_stream_refresh(self, _: datetime.datetime) -> None:
|
||||
"""Alarm that fires to check if the stream should be refreshed."""
|
||||
if not self._stream:
|
||||
_LOGGER.debug("Examining streams to refresh")
|
||||
await self._handle_rtsp_stream_refresh()
|
||||
await self._handle_webrtc_stream_refresh()
|
||||
self._schedule_stream_refresh()
|
||||
|
||||
async def _handle_rtsp_stream_refresh(self) -> None:
|
||||
"""Alarm that fires to check if the stream should be refreshed."""
|
||||
if not self._rtsp_stream:
|
||||
return
|
||||
_LOGGER.debug("Extending stream url")
|
||||
now = utcnow()
|
||||
refresh_time = self._rtsp_stream.expires_at - STREAM_EXPIRATION_BUFFER
|
||||
if now < refresh_time:
|
||||
return
|
||||
_LOGGER.debug("Extending RTSP stream")
|
||||
try:
|
||||
self._stream = await self._stream.extend_rtsp_stream()
|
||||
self._rtsp_stream = await self._rtsp_stream.extend_rtsp_stream()
|
||||
except ApiException as err:
|
||||
_LOGGER.debug("Failed to extend stream: %s", err)
|
||||
# Next attempt to catch a url will get a new one
|
||||
self._stream = None
|
||||
self._rtsp_stream = None
|
||||
if self.stream:
|
||||
await self.stream.stop()
|
||||
self.stream = None
|
||||
return
|
||||
# Update the stream worker with the latest valid url
|
||||
if self.stream:
|
||||
self.stream.update_source(self._stream.rtsp_stream_url)
|
||||
self._schedule_stream_refresh()
|
||||
self.stream.update_source(self._rtsp_stream.rtsp_stream_url)
|
||||
|
||||
async def _handle_webrtc_stream_refresh(self) -> None:
|
||||
"""Alarm that fires to check if the stream should be refreshed."""
|
||||
now = utcnow()
|
||||
for webrtc_stream in list(self._webrtc_sessions.values()):
|
||||
if now < (webrtc_stream.expires_at - STREAM_EXPIRATION_BUFFER):
|
||||
_LOGGER.debug(
|
||||
"Stream does not yet expire: %s", webrtc_stream.expires_at
|
||||
)
|
||||
continue
|
||||
_LOGGER.debug("Extending WebRTC stream %s", webrtc_stream.media_session_id)
|
||||
try:
|
||||
webrtc_stream = await webrtc_stream.extend_stream()
|
||||
except ApiException as err:
|
||||
_LOGGER.debug("Failed to extend stream: %s", err)
|
||||
else:
|
||||
self._webrtc_sessions[webrtc_stream.media_session_id] = webrtc_stream
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Invalidates the RTSP token when unloaded."""
|
||||
if self._stream:
|
||||
for stream in self._all_streams():
|
||||
_LOGGER.debug("Invalidating stream")
|
||||
try:
|
||||
await self._stream.stop_rtsp_stream()
|
||||
await stream.stop_stream()
|
||||
except ApiException as err:
|
||||
_LOGGER.debug(
|
||||
"Failed to revoke stream token, will rely on ttl: %s", err
|
||||
)
|
||||
_LOGGER.debug("Error stopping stream: %s", err)
|
||||
self._rtsp_stream = None
|
||||
self._webrtc_sessions.clear()
|
||||
|
||||
if self._stream_refresh_unsub:
|
||||
self._stream_refresh_unsub()
|
||||
|
||||
|
@ -223,14 +266,28 @@ class NestCamera(Camera):
|
|||
stream = await trait.generate_web_rtc_stream(offer_sdp)
|
||||
except ApiException as err:
|
||||
raise HomeAssistantError(f"Nest API error: {err}") from err
|
||||
_LOGGER.debug(
|
||||
"Started WebRTC session %s, %s", session_id, stream.media_session_id
|
||||
)
|
||||
self._webrtc_sessions[session_id] = stream
|
||||
send_message(WebRTCAnswer(stream.answer_sdp))
|
||||
self._schedule_stream_refresh()
|
||||
|
||||
@callback
|
||||
def close_webrtc_session(self, session_id: str) -> None:
|
||||
"""Close a WebRTC session."""
|
||||
if (stream := self._webrtc_sessions.pop(session_id, None)) is not None:
|
||||
self.hass.async_create_task(stream.stop_stream())
|
||||
_LOGGER.debug(
|
||||
"Closing WebRTC session %s, %s", session_id, stream.media_session_id
|
||||
)
|
||||
|
||||
async def stop_stream() -> None:
|
||||
try:
|
||||
await stream.stop_stream()
|
||||
except ApiException as err:
|
||||
_LOGGER.debug("Error stopping stream: %s", err)
|
||||
|
||||
self.hass.async_create_task(stop_stream())
|
||||
super().close_webrtc_session(session_id)
|
||||
|
||||
@callback
|
||||
|
|
|
@ -803,3 +803,94 @@ async def test_camera_multiple_streams(
|
|||
"type": "answer",
|
||||
"answer": "v=0\r\ns=-\r\n",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("webrtc_camera_device")
|
||||
async def test_webrtc_refresh_expired_stream(
|
||||
hass: HomeAssistant,
|
||||
setup_platform: PlatformSetup,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
auth: FakeAuth,
|
||||
) -> None:
|
||||
"""Test a camera webrtc expiration and refresh."""
|
||||
now = utcnow()
|
||||
|
||||
stream_1_expiration = now + datetime.timedelta(seconds=90)
|
||||
stream_2_expiration = now + datetime.timedelta(seconds=180)
|
||||
auth.responses = [
|
||||
aiohttp.web.json_response(
|
||||
{
|
||||
"results": {
|
||||
"answerSdp": "v=0\r\ns=-\r\n",
|
||||
"mediaSessionId": "yP2grqz0Y1V_wgiX9KEbMWHoLd...",
|
||||
"expiresAt": stream_1_expiration.isoformat(timespec="seconds"),
|
||||
},
|
||||
}
|
||||
),
|
||||
aiohttp.web.json_response(
|
||||
{
|
||||
"results": {
|
||||
"mediaSessionId": "yP2grqz0Y1V_wgiX9KEbMWHoLd...",
|
||||
"expiresAt": stream_2_expiration.isoformat(timespec="seconds"),
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
await setup_platform()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass.states.async_all()) == 1
|
||||
cam = hass.states.get("camera.my_camera")
|
||||
assert cam is not None
|
||||
assert cam.state == CameraState.STREAMING
|
||||
assert cam.attributes["frontend_stream_type"] == StreamType.WEB_RTC
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "camera/webrtc/offer",
|
||||
"entity_id": "camera.my_camera",
|
||||
"offer": "a=recvonly",
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.receive_json()
|
||||
assert response["type"] == TYPE_RESULT
|
||||
assert response["success"]
|
||||
subscription_id = response["id"]
|
||||
|
||||
# Session id
|
||||
response = await client.receive_json()
|
||||
assert response["id"] == subscription_id
|
||||
assert response["type"] == "event"
|
||||
assert response["event"]["type"] == "session"
|
||||
|
||||
# Answer
|
||||
response = await client.receive_json()
|
||||
assert response["id"] == subscription_id
|
||||
assert response["type"] == "event"
|
||||
assert response["event"] == {
|
||||
"type": "answer",
|
||||
"answer": "v=0\r\ns=-\r\n",
|
||||
}
|
||||
|
||||
assert len(auth.captured_requests) == 1
|
||||
assert (
|
||||
auth.captured_requests[0][2].get("command")
|
||||
== "sdm.devices.commands.CameraLiveStream.GenerateWebRtcStream"
|
||||
)
|
||||
|
||||
# Fire alarm before stream_1_expiration. The stream url is not refreshed
|
||||
next_update = now + datetime.timedelta(seconds=25)
|
||||
await fire_alarm(hass, next_update)
|
||||
assert len(auth.captured_requests) == 1
|
||||
|
||||
# Alarm is near stream_1_expiration which causes the stream extension
|
||||
next_update = now + datetime.timedelta(seconds=60)
|
||||
await fire_alarm(hass, next_update)
|
||||
|
||||
assert len(auth.captured_requests) >= 2
|
||||
assert (
|
||||
auth.captured_requests[1][2].get("command")
|
||||
== "sdm.devices.commands.CameraLiveStream.ExtendWebRtcStream"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue