Split StreamState class out of SegmentBuffer (#60423)

This refactoring was pulled out of https://github.com/home-assistant/core/pull/53676 as an
initial step towards reverting the addition of the SegmentBuffer class, which will be
unrolled back into a for loop.

The StreamState class holds the persistent state in stream that is used across stream worker
instantiations, e.g. state across a retry or url expiration, which primarily handles
discontinuities. By itself, this PR is not a large win until follow up PRs further simplify
the SegmentBuffer class.
pull/60573/head
Allen Porter 2021-11-29 22:25:28 -08:00 committed by GitHub
parent 890790a659
commit 8ca89b10eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 60 deletions

View File

@ -286,9 +286,9 @@ class Stream:
"""Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
from .worker import SegmentBuffer, StreamWorkerError, stream_worker
from .worker import StreamState, StreamWorkerError, stream_worker
segment_buffer = SegmentBuffer(self.hass, self.outputs)
stream_state = StreamState(self.hass, self.outputs)
wait_timeout = 0
while not self._thread_quit.wait(timeout=wait_timeout):
start_time = time.time()
@ -298,14 +298,14 @@ class Stream:
stream_worker(
self.source,
self.options,
segment_buffer,
stream_state,
self._thread_quit,
)
except StreamWorkerError as err:
_LOGGER.error("Error from stream worker: %s", str(err))
self._available = False
segment_buffer.discontinuity()
stream_state.discontinuity()
if not self.keepalive or self._thread_quit.is_set():
if self._fast_restart_once:
# The stream source is updated, restart without any delay.

View File

@ -40,28 +40,77 @@ class StreamEndedError(StreamWorkerError):
"""Raised when the stream is complete, exposed for facilitating testing."""
class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment."""
class StreamState:
"""Responsible for trakcing output and playback state for a stream.
Holds state used for playback to interpret a decoded stream. A source stream
may be reset (e.g. reconnecting to an rtsp stream) and this object tracks
the state to inform the player.
"""
def __init__(
self,
hass: HomeAssistant,
outputs_callback: Callable[[], Mapping[str, StreamOutput]],
) -> None:
"""Initialize SegmentBuffer."""
"""Initialize StreamState."""
self._stream_id: int = 0
self._hass = hass
self.hass = hass
self._outputs_callback: Callable[
[], Mapping[str, StreamOutput]
] = outputs_callback
# sequence gets incremented before the first segment so the first segment
# has a sequence number of 0.
self._sequence = -1
@property
def sequence(self) -> int:
"""Return the current sequence for the latest segment."""
return self._sequence
def next_sequence(self) -> int:
"""Increment the sequence number."""
self._sequence += 1
return self._sequence
@property
def stream_id(self) -> int:
"""Return the readonly stream_id attribute."""
return self._stream_id
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
@property
def outputs(self) -> list[StreamOutput]:
"""Return the active stream outputs."""
return list(self._outputs_callback().values())
class StreamMuxer:
"""StreamMuxer re-packages video/audio packets for output."""
def __init__(
self,
hass: HomeAssistant,
video_stream: av.video.VideoStream,
audio_stream: av.audio.stream.AudioStream | None,
stream_state: StreamState,
) -> None:
"""Initialize StreamMuxer."""
self._hass = hass
self._segment_start_dts: int = cast(int, None)
self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None
self._input_audio_stream: av.audio.stream.AudioStream | None = None
self._input_video_stream: av.video.VideoStream = video_stream
self._input_audio_stream: av.audio.stream.AudioStream | None = audio_stream
self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream: av.audio.stream.AudioStream | None = None
self._segment: Segment | None = None
@ -70,6 +119,7 @@ class SegmentBuffer:
self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False
self._stream_settings: StreamSettings = hass.data[DOMAIN][ATTR_SETTINGS]
self._stream_state = stream_state
self._start_time = datetime.datetime.utcnow()
def make_new_av(
@ -77,14 +127,13 @@ class SegmentBuffer:
memory_file: BytesIO,
sequence: int,
input_vstream: av.video.VideoStream,
input_astream: av.audio.stream.AudioStream,
input_astream: av.audio.stream.AudioStream | None,
) -> tuple[
av.container.OutputContainer,
av.video.VideoStream,
av.audio.stream.AudioStream | None,
]:
"""Make a new av OutputContainer and add output streams."""
add_audio = input_astream and input_astream.name in AUDIO_CODECS
container = av.open(
memory_file,
mode="w",
@ -135,24 +184,12 @@ class SegmentBuffer:
output_vstream = container.add_stream(template=input_vstream)
# Check if audio is requested
output_astream = None
if add_audio:
if input_astream:
output_astream = container.add_stream(template=input_astream)
return container, output_vstream, output_astream
def set_streams(
self,
video_stream: av.video.VideoStream,
audio_stream: Any,
# no type hint for audio_stream until https://github.com/PyAV-Org/PyAV/pull/775 is merged
) -> None:
"""Initialize output buffer with streams from container."""
self._input_video_stream = video_stream
self._input_audio_stream = audio_stream
def reset(self, video_dts: int) -> None:
"""Initialize a new stream segment."""
# Keep track of the number of segments we've processed
self._sequence += 1
self._part_start_dts = self._segment_start_dts = video_dts
self._segment = None
self._memory_file = BytesIO()
@ -163,7 +200,7 @@ class SegmentBuffer:
self._output_audio_stream,
) = self.make_new_av(
memory_file=self._memory_file,
sequence=self._sequence,
sequence=self._stream_state.next_sequence(),
input_vstream=self._input_video_stream,
input_astream=self._input_audio_stream,
)
@ -201,12 +238,12 @@ class SegmentBuffer:
# We have our first non-zero byte position. This means the init has just
# been written. Create a Segment and put it to the queue of each output.
self._segment = Segment(
sequence=self._sequence,
stream_id=self._stream_id,
sequence=self._stream_state.sequence,
stream_id=self._stream_state.stream_id,
init=self._memory_file.getvalue(),
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
stream_outputs=self._outputs_callback().values(),
stream_outputs=self._stream_state.outputs,
start_time=self._start_time,
)
self._memory_file_pos = self._memory_file.tell()
@ -283,17 +320,6 @@ class SegmentBuffer:
self._part_start_dts = adjusted_dts
self._part_has_keyframe = False
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
self._start_time = datetime.datetime.utcnow()
# Call discontinuity to remove incomplete segment from the HLS output
if hls_output := self._outputs_callback().get(HLS_PROVIDER):
cast(HlsStreamOutput, hls_output).discontinuity()
def close(self) -> None:
"""Close stream buffer."""
self._av_output.close()
@ -412,7 +438,7 @@ def unsupported_audio(packets: Iterator[av.Packet], audio_stream: Any) -> bool:
def stream_worker(
source: str,
options: dict[str, str],
segment_buffer: SegmentBuffer,
stream_state: StreamState,
quit_event: Event,
) -> None:
"""Handle consuming streams."""
@ -431,6 +457,8 @@ def stream_worker(
audio_stream = container.streams.audio[0]
except (KeyError, IndexError):
audio_stream = None
if audio_stream and audio_stream.name not in AUDIO_CODECS:
audio_stream = None
# These formats need aac_adtstoasc bitstream filter, but auto_bsf not
# compatible with empty_moov and manual bitstream filters not in PyAV
if container.format.name in {"hls", "mpegts"}:
@ -489,13 +517,13 @@ def stream_worker(
"Error demuxing stream while finding first packet: %s" % str(ex)
) from ex
segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(start_dts)
muxer = StreamMuxer(stream_state.hass, video_stream, audio_stream, stream_state)
muxer.reset(start_dts)
# Mux the first keyframe, then proceed through the rest of the packets
segment_buffer.mux_packet(first_keyframe)
muxer.mux_packet(first_keyframe)
with contextlib.closing(container), contextlib.closing(segment_buffer):
with contextlib.closing(container), contextlib.closing(muxer):
while not quit_event.is_set():
try:
packet = next(container_packets)
@ -506,4 +534,4 @@ def stream_worker(
except av.AVError as ex:
raise StreamWorkerError("Error demuxing stream: %s" % str(ex)) from ex
segment_buffer.mux_packet(packet)
muxer.mux_packet(packet)

View File

@ -23,7 +23,7 @@ import async_timeout
import pytest
from homeassistant.components.stream.core import Segment, StreamOutput
from homeassistant.components.stream.worker import SegmentBuffer
from homeassistant.components.stream.worker import StreamState
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
@ -34,7 +34,7 @@ class WorkerSync:
def __init__(self):
"""Initialize WorkerSync."""
self._event = None
self._original = SegmentBuffer.discontinuity
self._original = StreamState.discontinuity
def pause(self):
"""Pause the worker before it finalizes the stream."""
@ -45,7 +45,7 @@ class WorkerSync:
logging.debug("waking blocked worker")
self._event.set()
def blocking_discontinuity(self, stream: SegmentBuffer):
def blocking_discontinuity(self, stream_state: StreamState):
"""Intercept call to pause stream worker."""
# Worker is ending the stream, which clears all output buffers.
# Block the worker thread until the test has a chance to verify
@ -55,7 +55,7 @@ class WorkerSync:
self._event.wait()
# Forward to actual implementation
self._original(stream)
self._original(stream_state)
@pytest.fixture()
@ -63,7 +63,7 @@ def stream_worker_sync(hass):
"""Patch StreamOutput to allow test to synchronize worker stream end."""
sync = WorkerSync()
with patch(
"homeassistant.components.stream.worker.SegmentBuffer.discontinuity",
"homeassistant.components.stream.worker.StreamState.discontinuity",
side_effect=sync.blocking_discontinuity,
autospec=True,
):

View File

@ -38,8 +38,8 @@ from homeassistant.components.stream.const import (
)
from homeassistant.components.stream.core import StreamSettings
from homeassistant.components.stream.worker import (
SegmentBuffer,
StreamEndedError,
StreamState,
StreamWorkerError,
stream_worker,
)
@ -255,6 +255,12 @@ class MockPyAv:
return self.container
def run_worker(hass, stream, stream_source):
"""Run the stream worker under test."""
stream_state = StreamState(hass, stream.outputs)
stream_worker(stream_source, {}, stream_state, threading.Event())
async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE, {})
@ -268,9 +274,8 @@ async def async_decode_stream(hass, packets, py_av=None):
"homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment,
):
segment_buffer = SegmentBuffer(hass, stream.outputs)
try:
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
run_worker(hass, stream, STREAM_SOURCE)
except StreamEndedError:
# Tests only use a limited number of packets, then the worker exits as expected. In
# production, stream ending would be unexpected.
@ -288,8 +293,7 @@ async def test_stream_open_fails(hass):
stream.add_provider(HLS_PROVIDER)
with patch("av.open") as av_open, pytest.raises(StreamWorkerError):
av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
run_worker(hass, stream, STREAM_SOURCE)
await hass.async_block_till_done()
av_open.assert_called_once()
@ -695,10 +699,7 @@ async def test_worker_log(hass, caplog):
with patch("av.open") as av_open, pytest.raises(StreamWorkerError) as err:
av_open.side_effect = av.error.InvalidDataError(-2, "error")
segment_buffer = SegmentBuffer(hass, stream.outputs)
stream_worker(
"https://abcd:efgh@foo.bar", {}, segment_buffer, threading.Event()
)
run_worker(hass, stream, "https://abcd:efgh@foo.bar")
await hass.async_block_till_done()
assert str(err.value) == "Error opening stream https://****:****@foo.bar"
assert "https://abcd:efgh@foo.bar" not in caplog.text