Split StreamState class out of SegmentBuffer (#60423)
This refactoring was pulled out of https://github.com/home-assistant/core/pull/53676 as an initial step towards reverting the addition of the SegmentBuffer class, which will be unrolled back into a for loop. The StreamState class holds the persistent state in Stream that is used across stream worker instantiations (e.g. state that survives a retry or URL expiration), and primarily handles discontinuities. By itself, this PR is not a large win until follow-up PRs further simplify the SegmentBuffer class.
parent 890790a659
commit 8ca89b10eb
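To make the split concrete, here is a condensed sketch of the retry loop in the Stream class after this change, assembled from the hunks below. The backoff handling and the final `break` path are abbreviated, so treat this as an illustration of the lifecycle rather than the exact implementation:

    # One StreamState instance outlives every stream_worker run, so sequence
    # numbers and stream_id carry across retries and URL refreshes.
    stream_state = StreamState(self.hass, self.outputs)
    wait_timeout = 0
    while not self._thread_quit.wait(timeout=wait_timeout):
        try:
            stream_worker(self.source, self.options, stream_state, self._thread_quit)
        except StreamWorkerError as err:
            _LOGGER.error("Error from stream worker: %s", str(err))
        # Record the restart so the HLS playlist can emit a discontinuity tag.
        stream_state.discontinuity()
        if not self.keepalive or self._thread_quit.is_set():
            break  # abbreviated; the real loop also adjusts wait_timeout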
@@ -286,9 +286,9 @@ class Stream:
         """Handle consuming streams and restart keepalive streams."""
         # Keep import here so that we can import stream integration without installing reqs
         # pylint: disable=import-outside-toplevel
-        from .worker import SegmentBuffer, StreamWorkerError, stream_worker
+        from .worker import StreamState, StreamWorkerError, stream_worker

-        segment_buffer = SegmentBuffer(self.hass, self.outputs)
+        stream_state = StreamState(self.hass, self.outputs)
         wait_timeout = 0
         while not self._thread_quit.wait(timeout=wait_timeout):
             start_time = time.time()
@@ -298,14 +298,14 @@ class Stream:
                 stream_worker(
                     self.source,
                     self.options,
-                    segment_buffer,
+                    stream_state,
                     self._thread_quit,
                 )
             except StreamWorkerError as err:
                 _LOGGER.error("Error from stream worker: %s", str(err))
                 self._available = False

-            segment_buffer.discontinuity()
+            stream_state.discontinuity()
             if not self.keepalive or self._thread_quit.is_set():
                 if self._fast_restart_once:
                     # The stream source is updated, restart without any delay.

@@ -40,28 +40,77 @@ class StreamEndedError(StreamWorkerError):
     """Raised when the stream is complete, exposed for facilitating testing."""


-class SegmentBuffer:
-    """Buffer for writing a sequence of packets to the output as a segment."""
+class StreamState:
+    """Responsible for tracking output and playback state for a stream.
+
+    Holds state used for playback to interpret a decoded stream. A source stream
+    may be reset (e.g. reconnecting to an rtsp stream) and this object tracks
+    the state to inform the player.
+    """

     def __init__(
         self,
         hass: HomeAssistant,
         outputs_callback: Callable[[], Mapping[str, StreamOutput]],
     ) -> None:
-        """Initialize SegmentBuffer."""
+        """Initialize StreamState."""
         self._stream_id: int = 0
-        self._hass = hass
+        self.hass = hass
         self._outputs_callback: Callable[
             [], Mapping[str, StreamOutput]
         ] = outputs_callback
         # sequence gets incremented before the first segment so the first segment
         # has a sequence number of 0.
         self._sequence = -1
+
+    @property
+    def sequence(self) -> int:
+        """Return the current sequence for the latest segment."""
+        return self._sequence
+
+    def next_sequence(self) -> int:
+        """Increment the sequence number."""
+        self._sequence += 1
+        return self._sequence
+
+    @property
+    def stream_id(self) -> int:
+        """Return the readonly stream_id attribute."""
+        return self._stream_id
+
+    def discontinuity(self) -> None:
+        """Mark the stream as having been restarted."""
+        # Preserving sequence and stream_id here keep the HLS playlist logic
+        # simple to check for discontinuity at output time, and to determine
+        # the discontinuity sequence number.
+        self._stream_id += 1
+        # Call discontinuity to remove incomplete segment from the HLS output
+        if hls_output := self._outputs_callback().get(HLS_PROVIDER):
+            cast(HlsStreamOutput, hls_output).discontinuity()
+
+    @property
+    def outputs(self) -> list[StreamOutput]:
+        """Return the active stream outputs."""
+        return list(self._outputs_callback().values())
+
+
+class StreamMuxer:
+    """StreamMuxer re-packages video/audio packets for output."""
+
+    def __init__(
+        self,
+        hass: HomeAssistant,
+        video_stream: av.video.VideoStream,
+        audio_stream: av.audio.stream.AudioStream | None,
+        stream_state: StreamState,
+    ) -> None:
+        """Initialize StreamMuxer."""
+        self._hass = hass
         self._segment_start_dts: int = cast(int, None)
         self._memory_file: BytesIO = cast(BytesIO, None)
         self._av_output: av.container.OutputContainer = None
-        self._input_video_stream: av.video.VideoStream = None
-        self._input_audio_stream: av.audio.stream.AudioStream | None = None
+        self._input_video_stream: av.video.VideoStream = video_stream
+        self._input_audio_stream: av.audio.stream.AudioStream | None = audio_stream
         self._output_video_stream: av.video.VideoStream = None
         self._output_audio_stream: av.audio.stream.AudioStream | None = None
         self._segment: Segment | None = None
@@ -70,6 +119,7 @@ class SegmentBuffer:
         self._part_start_dts: int = cast(int, None)
         self._part_has_keyframe = False
         self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS]
+        self._stream_state = stream_state
         self._start_time = datetime.datetime.utcnow()

     def make_new_av(
@@ -77,14 +127,13 @@ class SegmentBuffer:
         memory_file: BytesIO,
         sequence: int,
         input_vstream: av.video.VideoStream,
-        input_astream: av.audio.stream.AudioStream,
+        input_astream: av.audio.stream.AudioStream | None,
     ) -> tuple[
         av.container.OutputContainer,
         av.video.VideoStream,
         av.audio.stream.AudioStream | None,
     ]:
         """Make a new av OutputContainer and add output streams."""
-        add_audio = input_astream and input_astream.name in AUDIO_CODECS
         container = av.open(
             memory_file,
             mode="w",
@@ -135,24 +184,12 @@ class SegmentBuffer:
         output_vstream = container.add_stream(template=input_vstream)
         # Check if audio is requested
         output_astream = None
-        if add_audio:
+        if input_astream:
             output_astream = container.add_stream(template=input_astream)
         return container, output_vstream, output_astream

-    def set_streams(
-        self,
-        video_stream: av.video.VideoStream,
-        audio_stream: Any,
-        # no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged
-    ) -> None:
-        """Initialize output buffer with streams from container."""
-        self._input_video_stream = video_stream
-        self._input_audio_stream = audio_stream
-
     def reset(self, video_dts: int) -> None:
         """Initialize a new stream segment."""
-        # Keep track of the number of segments we've processed
-        self._sequence += 1
         self._part_start_dts = self._segment_start_dts = video_dts
         self._segment = None
         self._memory_file = BytesIO()
@@ -163,7 +200,7 @@ class SegmentBuffer:
             self._output_audio_stream,
         ) = self.make_new_av(
             memory_file=self._memory_file,
-            sequence=self._sequence,
+            sequence=self._stream_state.next_sequence(),
             input_vstream=self._input_video_stream,
             input_astream=self._input_audio_stream,
         )
@@ -201,12 +238,12 @@ class SegmentBuffer:
             # We have our first non-zero byte position. This means the init has just
             # been written. Create a Segment and put it to the queue of each output.
             self._segment = Segment(
-                sequence=self._sequence,
-                stream_id=self._stream_id,
+                sequence=self._stream_state.sequence,
+                stream_id=self._stream_state.stream_id,
                 init=self._memory_file.getvalue(),
                 # Fetch the latest StreamOutputs, which may have changed since the
                 # worker started.
-                stream_outputs=self._outputs_callback().values(),
+                stream_outputs=self._stream_state.outputs,
                 start_time=self._start_time,
             )
             self._memory_file_pos = self._memory_file.tell()
@@ -283,17 +320,6 @@ class SegmentBuffer:
         self._part_start_dts = adjusted_dts
         self._part_has_keyframe = False

-    def discontinuity(self) -> None:
-        """Mark the stream as having been restarted."""
-        # Preserving sequence and stream_id here keep the HLS playlist logic
-        # simple to check for discontinuity at output time, and to determine
-        # the discontinuity sequence number.
-        self._stream_id += 1
-        self._start_time = datetime.datetime.utcnow()
-        # Call discontinuity to remove incomplete segment from the HLS output
-        if hls_output := self._outputs_callback().get(HLS_PROVIDER):
-            cast(HlsStreamOutput, hls_output).discontinuity()
-
     def close(self) -> None:
         """Close stream buffer."""
         self._av_output.close()
@@ -412,7 +438,7 @@ def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool:
 def stream_worker(
     source: str,
     options: dict[str, str],
-    segment_buffer: SegmentBuffer,
+    stream_state: StreamState,
     quit_event: Event,
 ) -> None:
     """Handle consuming streams."""
@@ -431,6 +457,8 @@ def stream_worker(
         audio_stream = container.streams.audio[0]
     except (KeyError, IndexError):
         audio_stream = None
+    if audio_stream and audio_stream.name not in AUDIO_CODECS:
+        audio_stream = None
     # These formats need aac_adtstoasc bitstream filter, but auto_bsf not
     # compatible with empty_moov and manual bitstream filters not in PyAV
     if container.format.name in {"hls", "mpegts"}:
@@ -489,13 +517,13 @@ def stream_worker(
             "Error demuxing stream while finding first packet: %s" % str(ex)
         ) from ex

-    segment_buffer.set_streams(video_stream, audio_stream)
-    segment_buffer.reset(start_dts)
+    muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state)
+    muxer.reset(start_dts)

     # Mux the first keyframe, then proceed through the rest of the packets
-    segment_buffer.mux_packet(first_keyframe)
+    muxer.mux_packet(first_keyframe)

-    with contextlib.closing(container), contextlib.closing(segment_buffer):
+    with contextlib.closing(container), contextlib.closing(muxer):
         while not quit_event.is_set():
             try:
                 packet = next(container_packets)
@@ -506,4 +534,4 @@
             except av.AVError as ex:
                 raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex

-            segment_buffer.mux_packet(packet)
+            muxer.mux_packet(packet)

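As a summary of the new wiring in the worker module above, the sketch below condenses how stream_worker now threads the persistent state through the muxer. Setup of hass, the demuxed streams, and start_dts is omitted, and the comments paraphrase the hunks rather than quote them:

    # stream_state is created once by Stream and passed into every worker run.
    muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state)
    muxer.reset(start_dts)            # reset() pulls stream_state.next_sequence()
    muxer.mux_packet(first_keyframe)  # Segments carry stream_state.sequence/stream_id
    # ... demux loop omitted ...
    # On error/retry, Stream calls stream_state.discontinuity(), which bumps
    # stream_id and drops the incomplete segment from the HLS output.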
@@ -23,7 +23,7 @@ import async_timeout
 import pytest

 from homeassistant.components.stream.core import Segment, StreamOutput
-from homeassistant.components.stream.worker import SegmentBuffer
+from homeassistant.components.stream.worker import StreamState

 TEST_TIMEOUT = 7.0  # Lower than 9s home assistant timeout

@@ -34,7 +34,7 @@ class WorkerSync:
     def __init__(self):
         """Initialize WorkerSync."""
         self._event = None
-        self._original = SegmentBuffer.discontinuity
+        self._original = StreamState.discontinuity

     def pause(self):
         """Pause the worker before it finalizes the stream."""
@@ -45,7 +45,7 @@ class WorkerSync:
         logging.debug("waking blocked worker")
         self._event.set()

-    def blocking_discontinuity(self, stream: SegmentBuffer):
+    def blocking_discontinuity(self, stream_state: StreamState):
         """Intercept call to pause stream worker."""
         # Worker is ending the stream, which clears all output buffers.
         # Block the worker thread until the test has a chance to verify
@@ -55,7 +55,7 @@ class WorkerSync:
         self._event.wait()

         # Forward to actual implementation
-        self._original(stream)
+        self._original(stream_state)


 @pytest.fixture()
@@ -63,7 +63,7 @@ def stream_worker_sync(hass):
     """Patch StreamOutput to allow test to synchronize worker stream end."""
     sync = WorkerSync()
     with patch(
-        "homeassistant.components.stream.worker.SegmentBuffer.discontinuity",
+        "homeassistant.components.stream.worker.StreamState.discontinuity",
         side_effect=sync.blocking_discontinuity,
         autospec=True,
     ):

@@ -38,8 +38,8 @@ from homeassistant.components.stream.const import (
 )
 from homeassistant.components.stream.core import StreamSettings
 from homeassistant.components.stream.worker import (
-    SegmentBuffer,
     StreamEndedError,
+    StreamState,
     StreamWorkerError,
     stream_worker,
 )
@@ -255,6 +255,12 @@ class MockPyAv:
         return self.container


+def run_worker(hass, stream, stream_source):
+    """Run the stream worker under test."""
+    stream_state = StreamState(hass, stream.outputs)
+    stream_worker(stream_source, {}, stream_state, threading.Event())
+
+
 async def async_decode_stream(hass, packets, py_av=None):
     """Start a stream worker that decodes incoming stream packets into output segments."""
     stream = Stream(hass, STREAM_SOURCE, {})
@@ -268,9 +274,8 @@ async def async_decode_stream(hass, packets, py_av=None):
         "homeassistant.components.stream.core.StreamOutput.put",
         side_effect=py_av.capture_buffer.capture_output_segment,
     ):
-        segment_buffer = SegmentBuffer(hass, stream.outputs)
         try:
-            stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
+            run_worker(hass, stream, STREAM_SOURCE)
         except StreamEndedError:
             # Tests only use a limited number of packets, then the worker exits as expected. In
             # production, stream ending would be unexpected.
@@ -288,8 +293,7 @@ async def test_stream_open_fails(hass):
     stream.add_provider(HLS_PROVIDER)
     with patch("av.open") as av_open, pytest.raises(StreamWorkerError):
         av_open.side_effect = av.error.InvalidDataError(-2, "error")
-        segment_buffer = SegmentBuffer(hass, stream.outputs)
-        stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
+        run_worker(hass, stream, STREAM_SOURCE)
     await hass.async_block_till_done()
     av_open.assert_called_once()

@@ -695,10 +699,7 @@ async def test_worker_log(hass, caplog):

     with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err:
         av_open.side_effect = av.error.InvalidDataError(-2, "error")
-        segment_buffer = SegmentBuffer(hass, stream.outputs)
-        stream_worker(
-            "https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
-        )
+        run_worker(hass, stream, "https://abcd:efgh@foo.bar")
     await hass.async_block_till_done()
     assert str(err.value) == "Error opening stream https://****:****@foo.bar"
     assert "https://abcd:efgh@foo.bar" not in caplog.text