Make Stream.stop() async (#73107)
* Make Stream.start() async * Stop streams concurrently on shutdown Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/73237/head
parent
c6b835dd91
commit
73f2bca377
|
@ -386,7 +386,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
continue
|
||||
stream.keepalive = True
|
||||
stream.add_provider("hls")
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream)
|
||||
|
||||
|
@ -996,7 +996,7 @@ async def _async_stream_endpoint_url(
|
|||
stream.keepalive = camera_prefs.preload_stream
|
||||
|
||||
stream.add_provider(fmt)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
return stream.endpoint_url(fmt)
|
||||
|
||||
|
||||
|
|
|
@ -175,7 +175,7 @@ class NestCamera(Camera):
|
|||
# Next attempt to catch a url will get a new one
|
||||
self._stream = None
|
||||
if self.stream:
|
||||
self.stream.stop()
|
||||
await self.stream.stop()
|
||||
self.stream = None
|
||||
return
|
||||
# Update the stream worker with the latest valid url
|
||||
|
|
|
@ -16,6 +16,7 @@ to always keep workers active.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Mapping
|
||||
import logging
|
||||
import re
|
||||
|
@ -206,13 +207,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
# Setup Recorder
|
||||
async_setup_recorder(hass)
|
||||
|
||||
@callback
|
||||
def shutdown(event: Event) -> None:
|
||||
async def shutdown(event: Event) -> None:
|
||||
"""Stop all stream workers."""
|
||||
for stream in hass.data[DOMAIN][ATTR_STREAMS]:
|
||||
stream.keepalive = False
|
||||
stream.stop()
|
||||
_LOGGER.info("Stopped stream workers")
|
||||
if awaitables := [
|
||||
asyncio.create_task(stream.stop())
|
||||
for stream in hass.data[DOMAIN][ATTR_STREAMS]
|
||||
]:
|
||||
await asyncio.wait(awaitables)
|
||||
_LOGGER.debug("Stopped stream workers")
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, shutdown)
|
||||
|
||||
|
@ -236,6 +240,7 @@ class Stream:
|
|||
self._stream_label = stream_label
|
||||
self.keepalive = False
|
||||
self.access_token: str | None = None
|
||||
self._start_stop_lock = asyncio.Lock()
|
||||
self._thread: threading.Thread | None = None
|
||||
self._thread_quit = threading.Event()
|
||||
self._outputs: dict[str, StreamOutput] = {}
|
||||
|
@ -271,12 +276,11 @@ class Stream:
|
|||
"""Add provider output stream."""
|
||||
if not (provider := self._outputs.get(fmt)):
|
||||
|
||||
@callback
|
||||
def idle_callback() -> None:
|
||||
async def idle_callback() -> None:
|
||||
if (
|
||||
not self.keepalive or fmt == RECORDER_PROVIDER
|
||||
) and fmt in self._outputs:
|
||||
self.remove_provider(self._outputs[fmt])
|
||||
await self.remove_provider(self._outputs[fmt])
|
||||
self.check_idle()
|
||||
|
||||
provider = PROVIDERS[fmt](
|
||||
|
@ -286,14 +290,14 @@ class Stream:
|
|||
|
||||
return provider
|
||||
|
||||
def remove_provider(self, provider: StreamOutput) -> None:
|
||||
async def remove_provider(self, provider: StreamOutput) -> None:
|
||||
"""Remove provider output stream."""
|
||||
if provider.name in self._outputs:
|
||||
self._outputs[provider.name].cleanup()
|
||||
del self._outputs[provider.name]
|
||||
|
||||
if not self._outputs:
|
||||
self.stop()
|
||||
await self.stop()
|
||||
|
||||
def check_idle(self) -> None:
|
||||
"""Reset access token if all providers are idle."""
|
||||
|
@ -316,9 +320,14 @@ class Stream:
|
|||
if self._update_callback:
|
||||
self._update_callback()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start a stream."""
|
||||
if self._thread is None or not self._thread.is_alive():
|
||||
async def start(self) -> None:
|
||||
"""Start a stream.
|
||||
|
||||
Uses an asyncio.Lock to avoid conflicts with _stop().
|
||||
"""
|
||||
async with self._start_stop_lock:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
if self._thread is not None:
|
||||
# The thread must have crashed/exited. Join to clean up the
|
||||
# previous thread.
|
||||
|
@ -329,7 +338,7 @@ class Stream:
|
|||
target=self._run_worker,
|
||||
)
|
||||
self._thread.start()
|
||||
self._logger.info(
|
||||
self._logger.debug(
|
||||
"Started stream: %s", redact_credentials(str(self.source))
|
||||
)
|
||||
|
||||
|
@ -394,33 +403,39 @@ class Stream:
|
|||
redact_credentials(str(self.source)),
|
||||
)
|
||||
|
||||
@callback
|
||||
def worker_finished() -> None:
|
||||
async def worker_finished() -> None:
|
||||
# The worker is no checking availability of the stream and can no longer track
|
||||
# availability so mark it as available, otherwise the frontend may not be able to
|
||||
# interact with the stream.
|
||||
if not self.available:
|
||||
self._async_update_state(True)
|
||||
# We can call remove_provider() sequentially as the wrapped _stop() function
|
||||
# which blocks internally is only called when the last provider is removed.
|
||||
for provider in self.outputs().values():
|
||||
self.remove_provider(provider)
|
||||
await self.remove_provider(provider)
|
||||
|
||||
self.hass.loop.call_soon_threadsafe(worker_finished)
|
||||
self.hass.create_task(worker_finished())
|
||||
|
||||
def stop(self) -> None:
|
||||
async def stop(self) -> None:
|
||||
"""Remove outputs and access token."""
|
||||
self._outputs = {}
|
||||
self.access_token = None
|
||||
|
||||
if not self.keepalive:
|
||||
self._stop()
|
||||
await self._stop()
|
||||
|
||||
def _stop(self) -> None:
|
||||
"""Stop worker thread."""
|
||||
if self._thread is not None:
|
||||
async def _stop(self) -> None:
|
||||
"""Stop worker thread.
|
||||
|
||||
Uses an asyncio.Lock to avoid conflicts with start().
|
||||
"""
|
||||
async with self._start_stop_lock:
|
||||
if self._thread is None:
|
||||
return
|
||||
self._thread_quit.set()
|
||||
self._thread.join()
|
||||
await self.hass.async_add_executor_job(self._thread.join)
|
||||
self._thread = None
|
||||
self._logger.info(
|
||||
self._logger.debug(
|
||||
"Stopped stream: %s", redact_credentials(str(self.source))
|
||||
)
|
||||
|
||||
|
@ -448,7 +463,7 @@ class Stream:
|
|||
)
|
||||
recorder.video_path = video_path
|
||||
|
||||
self.start()
|
||||
await self.start()
|
||||
self._logger.debug("Started a stream recording of %s seconds", duration)
|
||||
|
||||
# Take advantage of lookback
|
||||
|
@ -473,7 +488,7 @@ class Stream:
|
|||
"""
|
||||
|
||||
self.add_provider(HLS_PROVIDER)
|
||||
self.start()
|
||||
await self.start()
|
||||
return await self._keyframe_converter.async_get_image(
|
||||
width=width, height=height
|
||||
)
|
||||
|
|
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiohttp import web
|
||||
import async_timeout
|
||||
|
@ -192,7 +192,10 @@ class IdleTimer:
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, timeout: int, idle_callback: CALLBACK_TYPE
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
timeout: int,
|
||||
idle_callback: Callable[[], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
"""Initialize IdleTimer."""
|
||||
self._hass = hass
|
||||
|
@ -219,11 +222,12 @@ class IdleTimer:
|
|||
if self._unsub is not None:
|
||||
self._unsub()
|
||||
|
||||
@callback
|
||||
def fire(self, _now: datetime.datetime) -> None:
|
||||
"""Invoke the idle timeout callback, called when the alarm fires."""
|
||||
self.idle = True
|
||||
self._unsub = None
|
||||
self._callback()
|
||||
self._hass.async_create_task(self._callback())
|
||||
|
||||
|
||||
class StreamOutput:
|
||||
|
@ -349,7 +353,7 @@ class StreamView(HomeAssistantView):
|
|||
raise web.HTTPNotFound()
|
||||
|
||||
# Start worker if not already started
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
return await self.handle(request, stream, sequence, part_num)
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ class HlsMasterPlaylistView(StreamView):
|
|||
) -> web.Response:
|
||||
"""Return m3u8 playlist."""
|
||||
track = stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
# Make sure at least two segments are ready (last one may not be complete)
|
||||
if not track.sequences and not await track.recv():
|
||||
return web.HTTPNotFound()
|
||||
|
@ -232,7 +232,7 @@ class HlsPlaylistView(StreamView):
|
|||
track: HlsStreamOutput = cast(
|
||||
HlsStreamOutput, stream.add_provider(HLS_PROVIDER)
|
||||
)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
hls_msn: str | int | None = request.query.get("_HLS_msn")
|
||||
hls_part: str | int | None = request.query.get("_HLS_part")
|
||||
|
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
import base64
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
from unittest.mock import Mock, PropertyMock, mock_open, patch
|
||||
from unittest.mock import AsyncMock, Mock, PropertyMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -410,6 +410,7 @@ async def test_preload_stream(hass, mock_stream):
|
|||
"homeassistant.components.demo.camera.DemoCamera.stream_source",
|
||||
return_value="http://example.com",
|
||||
):
|
||||
mock_create_stream.return_value.start = AsyncMock()
|
||||
assert await async_setup_component(
|
||||
hass, "camera", {DOMAIN: {"platform": "demo"}}
|
||||
)
|
||||
|
|
|
@ -158,6 +158,7 @@ async def mock_create_stream(hass) -> Mock:
|
|||
)
|
||||
mock_stream.return_value.async_get_image = AsyncMock()
|
||||
mock_stream.return_value.async_get_image.return_value = IMAGE_BYTES_FROM_STREAM
|
||||
mock_stream.return_value.start = AsyncMock()
|
||||
yield mock_stream
|
||||
|
||||
|
||||
|
@ -370,6 +371,7 @@ async def test_refresh_expired_stream_token(
|
|||
# Request a stream for the camera entity to exercise nest cam + camera interaction
|
||||
# and shutdown on url expiration
|
||||
with patch("homeassistant.components.camera.create_stream") as create_stream:
|
||||
create_stream.return_value.start = AsyncMock()
|
||||
hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls")
|
||||
assert hls_url.startswith("/api/hls/") # Includes access token
|
||||
assert create_stream.called
|
||||
|
@ -536,7 +538,8 @@ async def test_refresh_expired_stream_failure(
|
|||
|
||||
# Request an HLS stream
|
||||
with patch("homeassistant.components.camera.create_stream") as create_stream:
|
||||
|
||||
create_stream.return_value.start = AsyncMock()
|
||||
create_stream.return_value.stop = AsyncMock()
|
||||
hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls")
|
||||
assert hls_url.startswith("/api/hls/") # Includes access token
|
||||
assert create_stream.called
|
||||
|
@ -555,6 +558,7 @@ async def test_refresh_expired_stream_failure(
|
|||
|
||||
# Requesting an HLS stream will create an entirely new stream
|
||||
with patch("homeassistant.components.camera.create_stream") as create_stream:
|
||||
create_stream.return_value.start = AsyncMock()
|
||||
# The HLS stream endpoint was invalidated, with a new auth token
|
||||
hls_url2 = await camera.async_request_stream(
|
||||
hass, "camera.my_camera", fmt="hls"
|
||||
|
|
|
@ -144,7 +144,7 @@ async def test_hls_stream(
|
|||
|
||||
# Request stream
|
||||
stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
||||
|
@ -171,7 +171,7 @@ async def test_hls_stream(
|
|||
stream_worker_sync.resume()
|
||||
|
||||
# Stop stream, if it hasn't quit already
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
# Ensure playlist not accessible after stream ends
|
||||
fail_response = await hls_client.get()
|
||||
|
@ -205,7 +205,7 @@ async def test_stream_timeout(
|
|||
|
||||
# Request stream
|
||||
stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
url = stream.endpoint_url(HLS_PROVIDER)
|
||||
|
||||
http_client = await hass_client()
|
||||
|
@ -218,6 +218,7 @@ async def test_stream_timeout(
|
|||
# Wait a minute
|
||||
future = dt_util.utcnow() + timedelta(minutes=1)
|
||||
async_fire_time_changed(hass, future)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Fetch again to reset timer
|
||||
playlist_response = await http_client.get(parsed_url.path)
|
||||
|
@ -249,10 +250,10 @@ async def test_stream_timeout_after_stop(
|
|||
|
||||
# Request stream
|
||||
stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
# Wait 5 minutes and fire callback. Stream should already have been
|
||||
# stopped so this is a no-op.
|
||||
|
@ -297,14 +298,14 @@ async def test_stream_retries(hass, setup_component, should_retry):
|
|||
mock_time.time.side_effect = time_side_effect
|
||||
# Request stream. Enable retries which are disabled by default in tests.
|
||||
should_retry.return_value = True
|
||||
stream.start()
|
||||
await stream.start()
|
||||
stream._thread.join()
|
||||
stream._thread = None
|
||||
assert av_open.call_count == 2
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Stop stream, if it hasn't quit already
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
# Stream marked initially available, then marked as failed, then marked available
|
||||
# before the final failure that exits the stream.
|
||||
|
@ -351,7 +352,7 @@ async def test_hls_playlist_view(hass, setup_component, hls_stream, stream_worke
|
|||
)
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker_sync):
|
||||
|
@ -400,7 +401,7 @@ async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker
|
|||
assert segment_response.status == HTTPStatus.OK
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_hls_playlist_view_discontinuity(
|
||||
|
@ -438,7 +439,7 @@ async def test_hls_playlist_view_discontinuity(
|
|||
)
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_hls_max_segments_discontinuity(
|
||||
|
@ -481,7 +482,7 @@ async def test_hls_max_segments_discontinuity(
|
|||
)
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_remove_incomplete_segment_on_exit(
|
||||
|
@ -490,7 +491,7 @@ async def test_remove_incomplete_segment_on_exit(
|
|||
"""Test that the incomplete segment gets removed when the worker thread quits."""
|
||||
stream = create_stream(hass, STREAM_SOURCE, {})
|
||||
stream_worker_sync.pause()
|
||||
stream.start()
|
||||
await stream.start()
|
||||
hls = stream.add_provider(HLS_PROVIDER)
|
||||
|
||||
segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION)
|
||||
|
@ -511,4 +512,4 @@ async def test_remove_incomplete_segment_on_exit(
|
|||
await hass.async_block_till_done()
|
||||
assert segments[-1].complete
|
||||
assert len(segments) == 2
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
|
|
@ -144,7 +144,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync):
|
|||
|
||||
# Request stream
|
||||
stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
hls_client = await hls_stream(stream)
|
||||
|
||||
|
@ -243,7 +243,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync):
|
|||
stream_worker_sync.resume()
|
||||
|
||||
# Stop stream, if it hasn't quit already
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
# Ensure playlist not accessible after stream ends
|
||||
fail_response = await hls_client.get()
|
||||
|
@ -316,7 +316,7 @@ async def test_ll_hls_playlist_view(hass, hls_stream, stream_worker_sync):
|
|||
)
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_ll_hls_msn(hass, hls_stream, stream_worker_sync, hls_sync):
|
||||
|
|
|
@ -46,7 +46,7 @@ async def test_record_stream(hass, hass_client, record_worker_sync, h264_video):
|
|||
# thread completes and is shutdown completely to avoid thread leaks.
|
||||
await record_worker_sync.join()
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_record_lookback(
|
||||
|
@ -59,14 +59,14 @@ async def test_record_lookback(
|
|||
|
||||
# Start an HLS feed to enable lookback
|
||||
stream.add_provider(HLS_PROVIDER)
|
||||
stream.start()
|
||||
await stream.start()
|
||||
|
||||
with patch.object(hass.config, "is_allowed_path", return_value=True):
|
||||
await stream.async_record("/example/path", lookback=4)
|
||||
|
||||
# This test does not need recorder cleanup since it is not fully exercised
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_video):
|
||||
|
@ -97,7 +97,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_vide
|
|||
assert mock_timeout.called
|
||||
|
||||
stream_worker_sync.resume()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
@ -229,7 +229,7 @@ async def test_record_stream_audio(
|
|||
|
||||
assert len(result.streams.audio) == expected_audio_streams
|
||||
result.close()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Verify that the save worker was invoked, then block until its
|
||||
|
|
|
@ -651,12 +651,12 @@ async def test_stream_stopped_while_decoding(hass):
|
|||
return py_av.open(stream_source, args, kwargs)
|
||||
|
||||
with patch("av.open", new=blocking_open):
|
||||
stream.start()
|
||||
await stream.start()
|
||||
assert worker_open.wait(TIMEOUT)
|
||||
# Note: There is a race here where the worker could start as soon
|
||||
# as the wake event is sent, completing all decode work.
|
||||
worker_wake.set()
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
# Stream is still considered available when the worker was still active and asked to stop
|
||||
assert stream.available
|
||||
|
@ -688,7 +688,7 @@ async def test_update_stream_source(hass):
|
|||
return py_av.open(stream_source, args, kwargs)
|
||||
|
||||
with patch("av.open", new=blocking_open):
|
||||
stream.start()
|
||||
await stream.start()
|
||||
assert worker_open.wait(TIMEOUT)
|
||||
assert last_stream_source == STREAM_SOURCE
|
||||
assert stream.available
|
||||
|
@ -704,7 +704,7 @@ async def test_update_stream_source(hass):
|
|||
assert stream.available
|
||||
|
||||
# Cleanup
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_worker_log(hass, caplog):
|
||||
|
@ -796,7 +796,7 @@ async def test_durations(hass, record_worker_sync):
|
|||
|
||||
await record_worker_sync.join()
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_has_keyframe(hass, record_worker_sync, h264_video):
|
||||
|
@ -836,7 +836,7 @@ async def test_has_keyframe(hass, record_worker_sync, h264_video):
|
|||
|
||||
await record_worker_sync.join()
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
|
||||
async def test_h265_video_is_hvc1(hass, record_worker_sync):
|
||||
|
@ -871,7 +871,7 @@ async def test_h265_video_is_hvc1(hass, record_worker_sync):
|
|||
|
||||
await record_worker_sync.join()
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
||||
assert stream.get_diagnostics() == {
|
||||
"container_format": "mov,mp4,m4a,3gp,3g2,mj2",
|
||||
|
@ -905,4 +905,4 @@ async def test_get_image(hass, record_worker_sync):
|
|||
|
||||
assert await stream.async_get_image() == EMPTY_8_6_JPEG
|
||||
|
||||
stream.stop()
|
||||
await stream.stop()
|
||||
|
|
Loading…
Reference in New Issue