From 123e8f01a19ec075b597f1052245a24c89388bf6 Mon Sep 17 00:00:00 2001 From: uvjustin <46082645+uvjustin@users.noreply.github.com> Date: Mon, 14 Jun 2021 00:41:21 +0800 Subject: [PATCH] Refactor stream to create partial segments (#51282) --- homeassistant/components/stream/const.py | 3 +- homeassistant/components/stream/core.py | 29 ++- homeassistant/components/stream/fmp4utils.py | 10 - homeassistant/components/stream/hls.py | 86 +++++---- homeassistant/components/stream/recorder.py | 2 +- homeassistant/components/stream/worker.py | 174 +++++++++++------ tests/components/stream/conftest.py | 62 ++++++ tests/components/stream/test_hls.py | 141 ++++++-------- tests/components/stream/test_recorder.py | 107 +++-------- tests/components/stream/test_worker.py | 190 +++++++++++++++---- 10 files changed, 499 insertions(+), 305 deletions(-) diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index 62d13321f91..cf4a80d9705 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -18,8 +18,9 @@ FORMAT_CONTENT_TYPE = {HLS_PROVIDER: "application/vnd.apple.mpegurl"} OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist -MAX_SEGMENTS = 4 # Max number of segments to keep around +MAX_SEGMENTS = 5 # Max number of segments to keep around TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds +TARGET_PART_DURATION = 1.0 SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries # Each segment is at least this many seconds MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index f3d30fa6e1b..136c3c1dbfa 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -19,20 +19,37 @@ from .const import ATTR_STREAMS, DOMAIN PROVIDERS = Registry() +@attr.s(slots=True) +class Part: + """Represent a segment part.""" + + duration: float = attr.ib() + has_keyframe: bool = attr.ib() + data: bytes = attr.ib() + + @attr.s(slots=True) class Segment: """Represent a segment.""" - sequence: int = attr.ib() - # the init of the mp4 - init: bytes = attr.ib() - # the video data (moof + mddat)s of the mp4 - moof_data: bytes = attr.ib() - duration: float = attr.ib() + sequence: int = attr.ib(default=0) + # the init of the mp4 the segment is based on + init: bytes = attr.ib(default=None) + duration: float = attr.ib(default=0) # For detecting discontinuities across stream restarts stream_id: int = attr.ib(default=0) + parts: list[Part] = attr.ib(factory=list) start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow) + @property + def complete(self) -> bool: + """Return whether the Segment is complete.""" + return self.duration > 0 + + def get_bytes_without_init(self) -> bytes: + """Return reconstructed data for entire segment as bytes.""" + return b"".join([part.data for part in self.parts]) + class IdleTimer: """Invoke a callback after an inactivity timeout. diff --git a/homeassistant/components/stream/fmp4utils.py b/homeassistant/components/stream/fmp4utils.py index 511bbc0939a..ef01158be62 100644 --- a/homeassistant/components/stream/fmp4utils.py +++ b/homeassistant/components/stream/fmp4utils.py @@ -25,16 +25,6 @@ def find_box( index += int.from_bytes(box_header[0:4], byteorder="big") -def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]: - """Get the init and moof data from a segment.""" - moof_location = next(find_box(segment, b"moof"), 0) - mfra_location = next(find_box(segment, b"mfra"), len(segment)) - return ( - segment[:moof_location].tobytes(), - segment[moof_location:mfra_location].tobytes(), - ) - - def get_codec_string(mp4_bytes: bytes) -> str: """Get RFC 6381 codec string.""" codecs = [] diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 1d2921df192..0b0cd4ac3b2 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -37,9 +37,12 @@ class HlsMasterPlaylistView(StreamView): # Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work # Calculate file size / duration and use a small multiplier to account for variation # hls spec already allows for 25% variation - segment = track.get_segment(track.sequences[-1]) + segment = track.get_segment(track.sequences[-2]) bandwidth = round( - (len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2 + (len(segment.init) + sum(len(part.data) for part in segment.parts)) + * 8 + / segment.duration + * 1.2 ) codecs = get_codec_string(segment.init) lines = [ @@ -53,9 +56,11 @@ class HlsMasterPlaylistView(StreamView): """Return m3u8 playlist.""" track = stream.add_provider(HLS_PROVIDER) stream.start() - # Wait for a segment to be ready + # Make sure at least two segments are ready (last one may not be complete) if not track.sequences and not await track.recv(): return web.HTTPNotFound() + if len(track.sequences) == 1 and not await track.recv(): + return web.HTTPNotFound() headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} return web.Response(body=self.render(track).encode("utf-8"), headers=headers) @@ -68,69 +73,72 @@ class HlsPlaylistView(StreamView): cors_allowed = True @staticmethod - def render_preamble(track): - """Render preamble.""" - return [ - "#EXT-X-VERSION:6", - f"#EXT-X-TARGETDURATION:{track.target_duration}", - '#EXT-X-MAP:URI="init.mp4"', - ] - - @staticmethod - def render_playlist(track): + def render(track): """Render playlist.""" - segments = list(track.get_segments())[-NUM_PLAYLIST_SEGMENTS:] + # NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete + segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :] - if not segments: - return [] + # To cap the number of complete segments at NUM_PLAYLIST_SEGMENTS, + # remove the first segment if the last segment is actually complete + if segments[-1].complete: + segments = segments[-NUM_PLAYLIST_SEGMENTS:] first_segment = segments[0] playlist = [ + "#EXTM3U", + "#EXT-X-VERSION:6", + "#EXT-X-INDEPENDENT-SEGMENTS", + '#EXT-X-MAP:URI="init.mp4"', + f"#EXT-X-TARGETDURATION:{track.target_duration:.0f}", f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}", "#EXT-X-PROGRAM-DATE-TIME:" + first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", # Since our window doesn't have many segments, we don't want to start - # at the beginning or we risk a behind live window exception in exoplayer. + # at the beginning or we risk a behind live window exception in Exoplayer. # EXT-X-START is not supposed to be within 3 target durations of the end, - # but this seems ok - f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f},PRECISE=YES", + # but a value as low as 1.5 doesn't seem to hurt. + # A value below 3 may not be as useful for hls.js as many hls.js clients + # don't autoplay. Also, hls.js uses the player parameter liveSyncDuration + # which seems to take precedence for setting target delay. Yet it also + # doesn't seem to hurt, so we can stick with it for now. + f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f}", ] last_stream_id = first_segment.stream_id + # Add playlist sections for segment in segments: - if last_stream_id != segment.stream_id: + # Skip last segment if it is not complete + if segment.complete: + if last_stream_id != segment.stream_id: + playlist.extend( + [ + "#EXT-X-DISCONTINUITY", + "#EXT-X-PROGRAM-DATE-TIME:" + + segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + + "Z", + ] + ) playlist.extend( [ - "#EXT-X-DISCONTINUITY", - "#EXT-X-PROGRAM-DATE-TIME:" - + segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] - + "Z", + f"#EXTINF:{segment.duration:.3f},", + f"./segment/{segment.sequence}.m4s", ] ) - playlist.extend( - [ - f"#EXTINF:{float(segment.duration):.04f},", - f"./segment/{segment.sequence}.m4s", - ] - ) - last_stream_id = segment.stream_id + last_stream_id = segment.stream_id - return playlist - - def render(self, track): - """Render M3U8 file.""" - lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track) - return "\n".join(lines) + "\n" + return "\n".join(playlist) + "\n" async def handle(self, request, stream, sequence): """Return m3u8 playlist.""" track = stream.add_provider(HLS_PROVIDER) stream.start() - # Wait for a segment to be ready + # Make sure at least two segments are ready (last one may not be complete) if not track.sequences and not await track.recv(): return web.HTTPNotFound() + if len(track.sequences) == 1 and not await track.recv(): + return web.HTTPNotFound() headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]} response = web.Response( body=self.render(track).encode("utf-8"), headers=headers @@ -170,7 +178,7 @@ class HlsSegmentView(StreamView): return web.HTTPNotFound() headers = {"Content-Type": "video/iso.segment"} return web.Response( - body=segment.moof_data, + body=segment.get_bytes_without_init(), headers=headers, ) diff --git a/homeassistant/components/stream/recorder.py b/homeassistant/components/stream/recorder.py index ac5f102e625..8e21777fa0b 100644 --- a/homeassistant/components/stream/recorder.py +++ b/homeassistant/components/stream/recorder.py @@ -57,7 +57,7 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]): # Open segment source = av.open( - BytesIO(segment.init + segment.moof_data), + BytesIO(segment.init + segment.get_bytes_without_init()), "r", format=SEGMENT_CONTAINER_FORMAT, ) diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index c606d1ad0dc..cca981e5db3 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -2,9 +2,12 @@ from __future__ import annotations from collections import deque +from collections.abc import Iterator, Mapping +from fractions import Fraction from io import BytesIO import logging -from typing import cast +from threading import Event +from typing import Callable, cast import av @@ -17,9 +20,9 @@ from .const import ( PACKETS_TO_WAIT_FOR_AUDIO, SEGMENT_CONTAINER_FORMAT, SOURCE_TIMEOUT, + TARGET_PART_DURATION, ) -from .core import Segment, StreamOutput -from .fmp4utils import get_init_and_moof_data +from .core import Part, Segment, StreamOutput _LOGGER = logging.getLogger(__name__) @@ -27,22 +30,28 @@ _LOGGER = logging.getLogger(__name__) class SegmentBuffer: """Buffer for writing a sequence of packets to the output as a segment.""" - def __init__(self, outputs_callback) -> None: + def __init__( + self, outputs_callback: Callable[[], Mapping[str, StreamOutput]] + ) -> None: """Initialize SegmentBuffer.""" - self._stream_id = 0 - self._outputs_callback = outputs_callback - self._outputs: list[StreamOutput] = [] + self._stream_id: int = 0 + self._outputs_callback: Callable[ + [], Mapping[str, StreamOutput] + ] = outputs_callback # sequence gets incremented before the first segment so the first segment # has a sequence number of 0. self._sequence = -1 - self._segment_start_pts = None + self._segment_start_dts: int = cast(int, None) self._memory_file: BytesIO = cast(BytesIO, None) self._av_output: av.container.OutputContainer = None self._input_video_stream: av.video.VideoStream = None self._input_audio_stream = None # av.audio.AudioStream | None self._output_video_stream: av.video.VideoStream = None self._output_audio_stream = None # av.audio.AudioStream | None - self._segment: Segment = cast(Segment, None) + self._segment: Segment | None = None + self._segment_last_write_pos: int = cast(int, None) + self._part_start_dts: int = cast(int, None) + self._part_has_keyframe = False @staticmethod def make_new_av( @@ -56,10 +65,17 @@ class SegmentBuffer: container_options={ # Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970 # "cmaf" flag replaces several of the movflags used, but too recent to use for now - "movflags": "frag_custom+empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", - "avoid_negative_ts": "disabled", + "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer", + # Sometimes the first segment begins with negative timestamps, and this setting just + # adjusts the timestamps in the output from that segment to start from 0. Helps from + # having to make some adjustments in test_durations + "avoid_negative_ts": "make_non_negative", "fragment_index": str(sequence + 1), "video_track_timescale": str(int(1 / input_vstream.time_base)), + # Create a fragments every TARGET_PART_DURATION. The data from each fragment is stored in + # a "Part" that can be combined with the data from all the other "Part"s, plus an init + # section, to reconstitute the data in a "Segment". + "frag_duration": str(int(TARGET_PART_DURATION * 1e6)), }, ) @@ -73,15 +89,13 @@ class SegmentBuffer: self._input_video_stream = video_stream self._input_audio_stream = audio_stream - def reset(self, video_pts): + def reset(self, video_dts: int) -> None: """Initialize a new stream segment.""" # Keep track of the number of segments we've processed self._sequence += 1 - self._segment_start_pts = video_pts - - # Fetch the latest StreamOutputs, which may have changed since the - # worker started. - self._outputs = self._outputs_callback().values() + self._segment_start_dts = self._part_start_dts = video_dts + self._segment = None + self._segment_last_write_pos = 0 self._memory_file = BytesIO() self._av_output = self.make_new_av( memory_file=self._memory_file, @@ -98,54 +112,102 @@ class SegmentBuffer: template=self._input_audio_stream ) - def mux_packet(self, packet): + def mux_packet(self, packet: av.Packet) -> None: """Mux a packet to the appropriate output stream.""" # Check for end of segment - if packet.stream == self._input_video_stream and packet.is_keyframe: - duration = (packet.pts - self._segment_start_pts) * packet.time_base - if duration >= MIN_SEGMENT_DURATION: - # Save segment to outputs - self.flush(duration) - - # Reinitialize - self.reset(packet.pts) - - # Mux the packet if packet.stream == self._input_video_stream: + + if ( + packet.is_keyframe + and ( + segment_duration := (packet.dts - self._segment_start_dts) + * packet.time_base + ) + >= MIN_SEGMENT_DURATION + ): + # Flush segment (also flushes the stub part segment) + self.flush(segment_duration, packet) + # Reinitialize + self.reset(packet.dts) + + # Mux the packet packet.stream = self._output_video_stream self._av_output.mux(packet) + self.check_flush_part(packet) + self._part_has_keyframe |= packet.is_keyframe + elif packet.stream == self._input_audio_stream: packet.stream = self._output_audio_stream self._av_output.mux(packet) - def flush(self, duration): + def check_flush_part(self, packet: av.Packet) -> None: + """Check for and mark a part segment boundary and record its duration.""" + byte_position = self._memory_file.tell() + if self._segment_last_write_pos == byte_position: + return + if self._segment is None: + # We have our first non-zero byte position. This means the init has just + # been written. Create a Segment and put it to the queue of each output. + self._segment = Segment( + sequence=self._sequence, + stream_id=self._stream_id, + init=self._memory_file.getvalue(), + ) + self._segment_last_write_pos = byte_position + # Fetch the latest StreamOutputs, which may have changed since the + # worker started. + for stream_output in self._outputs_callback().values(): + stream_output.put(self._segment) + else: # These are the ends of the part segments + self._segment.parts.append( + Part( + duration=float( + (packet.dts - self._part_start_dts) * packet.time_base + ), + has_keyframe=self._part_has_keyframe, + data=self._memory_file.getbuffer()[ + self._segment_last_write_pos : byte_position + ].tobytes(), + ) + ) + self._segment_last_write_pos = byte_position + self._part_start_dts = packet.dts + self._part_has_keyframe = False + + def flush(self, duration: Fraction, packet: av.Packet) -> None: """Create a segment from the buffered packets and write to output.""" self._av_output.close() - segment = Segment( - self._sequence, - *get_init_and_moof_data(self._memory_file.getbuffer()), - duration, - self._stream_id, + assert self._segment + self._segment.duration = float(duration) + # Also flush the part segment (need to close the output above before this) + self._segment.parts.append( + Part( + duration=float((packet.dts - self._part_start_dts) * packet.time_base), + has_keyframe=self._part_has_keyframe, + data=self._memory_file.getbuffer()[ + self._segment_last_write_pos : + ].tobytes(), + ) ) - self._memory_file.close() - for stream_output in self._outputs: - stream_output.put(segment) + self._memory_file.close() # We don't need the BytesIO object anymore - def discontinuity(self): + def discontinuity(self) -> None: """Mark the stream as having been restarted.""" # Preserving sequence and stream_id here keep the HLS playlist logic # simple to check for discontinuity at output time, and to determine # the discontinuity sequence number. self._stream_id += 1 - def close(self): + def close(self) -> None: """Close stream buffer.""" self._av_output.close() self._memory_file.close() -def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 +def stream_worker( # noqa: C901 + source: str, options: dict, segment_buffer: SegmentBuffer, quit_event: Event +) -> None: """Handle consuming streams.""" try: @@ -172,27 +234,27 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 audio_stream = None # Iterator for demuxing - container_packets = None + container_packets: Iterator[av.Packet] # The decoder timestamps of the latest packet in each stream we processed last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")} # Keep track of consecutive packets without a dts to detect end of stream. missing_dts = 0 - # The video pts at the beginning of the segment - segment_start_pts = None + # The video dts at the beginning of the segment + segment_start_dts: int | None = None # Because of problems 1 and 2 below, we need to store the first few packets and replay them - initial_packets = deque() + initial_packets: deque[av.Packet] = deque() # Have to work around two problems with RTSP feeds in ffmpeg # 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018 # 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815 - def peek_first_pts(): + def peek_first_dts() -> bool: """Initialize by peeking into the first few packets of the stream. Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet. - Also load the first video keyframe pts into segment_start_pts and check if the audio stream really exists. + Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists. """ - nonlocal segment_start_pts, audio_stream, container_packets + nonlocal segment_start_dts, audio_stream, container_packets missing_dts = 0 found_audio = False try: @@ -215,8 +277,8 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 elif packet.is_keyframe: # video_keyframe first_packet = packet initial_packets.append(packet) - # Get first_pts from subsequent frame to first keyframe - while segment_start_pts is None or ( + # Get first_dts from subsequent frame to first keyframe + while segment_start_dts is None or ( audio_stream and not found_audio and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO @@ -244,11 +306,10 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 continue found_audio = True elif ( - segment_start_pts is None - ): # This is the second video frame to calculate first_pts from - segment_start_pts = packet.dts - packet.duration - first_packet.pts = segment_start_pts - first_packet.dts = segment_start_pts + segment_start_dts is None + ): # This is the second video frame to calculate first_dts from + segment_start_dts = packet.dts - packet.duration + first_packet.pts = first_packet.dts = segment_start_dts initial_packets.append(packet) if audio_stream and not found_audio: _LOGGER.warning( @@ -263,12 +324,13 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901 return False return True - if not peek_first_pts(): + if not peek_first_dts(): container.close() return segment_buffer.set_streams(video_stream, audio_stream) - segment_buffer.reset(segment_start_pts) + assert isinstance(segment_start_dts, int) + segment_buffer.reset(segment_start_dts) while not quit_event.is_set(): try: diff --git a/tests/components/stream/conftest.py b/tests/components/stream/conftest.py index ead2018b528..a73678d763f 100644 --- a/tests/components/stream/conftest.py +++ b/tests/components/stream/conftest.py @@ -9,13 +9,21 @@ nothing for the test to verify. The solution is the WorkerSync class that allows the tests to pause the worker thread before finalizing the stream so that it can inspect the output. """ +from __future__ import annotations + +import asyncio +from collections import deque import logging import threading from unittest.mock import patch +import async_timeout import pytest from homeassistant.components.stream import Stream +from homeassistant.components.stream.core import Segment + +TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout class WorkerSync: @@ -58,3 +66,57 @@ def stream_worker_sync(hass): autospec=True, ): yield sync + + +class SaveRecordWorkerSync: + """ + Test fixture to manage RecordOutput thread for recorder_save_worker. + + This is used to assert that the worker is started and stopped cleanly + to avoid thread leaks in tests. + """ + + def __init__(self): + """Initialize SaveRecordWorkerSync.""" + self._save_event = None + self._segments = None + self._save_thread = None + self.reset() + + def recorder_save_worker(self, file_out: str, segments: deque[Segment]): + """Mock method for patch.""" + logging.debug("recorder_save_worker thread started") + assert self._save_thread is None + self._segments = segments + self._save_thread = threading.current_thread() + self._save_event.set() + + async def get_segments(self): + """Return the recorded video segments.""" + with async_timeout.timeout(TEST_TIMEOUT): + await self._save_event.wait() + return self._segments + + async def join(self): + """Verify save worker was invoked and block on shutdown.""" + with async_timeout.timeout(TEST_TIMEOUT): + await self._save_event.wait() + self._save_thread.join(timeout=TEST_TIMEOUT) + assert not self._save_thread.is_alive() + + def reset(self): + """Reset callback state for reuse in tests.""" + self._save_thread = None + self._save_event = asyncio.Event() + + +@pytest.fixture() +def record_worker_sync(hass): + """Patch recorder_save_worker for clean thread shutdown for test.""" + sync = SaveRecordWorkerSync() + with patch( + "homeassistant.components.stream.recorder.recorder_save_worker", + side_effect=sync.recorder_save_worker, + autospec=True, + ): + yield sync diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index a31c686dcaf..37c499b6bd0 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -12,7 +12,7 @@ from homeassistant.components.stream.const import ( MAX_SEGMENTS, NUM_PLAYLIST_SEGMENTS, ) -from homeassistant.components.stream.core import Segment +from homeassistant.components.stream.core import Part, Segment from homeassistant.const import HTTP_NOT_FOUND from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -22,7 +22,7 @@ from tests.components.stream.common import generate_h264_video STREAM_SOURCE = "some-stream-source" INIT_BYTES = b"init" -MOOF_BYTES = b"some-bytes" +FAKE_PAYLOAD = b"fake-payload" SEGMENT_DURATION = 10 TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever @@ -70,23 +70,24 @@ def make_segment(segment, discontinuity=False): + "Z", ] ) - response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]), + response.extend([f"#EXTINF:{SEGMENT_DURATION:.3f},", f"./segment/{segment}.m4s"]) return "\n".join(response) -def make_playlist(sequence, discontinuity_sequence=0, segments=[]): +def make_playlist(sequence, segments, discontinuity_sequence=0): """Create a an hls playlist response for tests to assert on.""" response = [ "#EXTM3U", "#EXT-X-VERSION:6", - "#EXT-X-TARGETDURATION:10", + "#EXT-X-INDEPENDENT-SEGMENTS", '#EXT-X-MAP:URI="init.mp4"', + "#EXT-X-TARGETDURATION:10", f"#EXT-X-MEDIA-SEQUENCE:{sequence}", f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}", "#EXT-X-PROGRAM-DATE-TIME:" + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", - f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f},PRECISE=YES", + f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f}", ] response.extend(segments) response.append("") @@ -264,21 +265,26 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync): stream_worker_sync.pause() hls = stream.add_provider(HLS_PROVIDER) - hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME)) + for i in range(2): + segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME) + hls.put(segment) await hass.async_block_till_done() hls_client = await hls_stream(stream) resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 - assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)]) + assert await resp.text() == make_playlist( + sequence=0, segments=[make_segment(0), make_segment(1)] + ) - hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME)) + segment = Segment(sequence=2, duration=SEGMENT_DURATION, start_time=FAKE_TIME) + hls.put(segment) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 assert await resp.text() == make_playlist( - sequence=1, segments=[make_segment(1), make_segment(2)] + sequence=0, segments=[make_segment(0), make_segment(1), make_segment(2)] ) stream_worker_sync.resume() @@ -296,37 +302,40 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync): hls_client = await hls_stream(stream) # Produce enough segments to overfill the output buffer by one - for sequence in range(1, MAX_SEGMENTS + 2): - hls.put( - Segment( - sequence, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - start_time=FAKE_TIME, - ) + for sequence in range(MAX_SEGMENTS + 1): + segment = Segment( + sequence=sequence, duration=SEGMENT_DURATION, start_time=FAKE_TIME ) + hls.put(segment) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist. - start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS + start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS segments = [] - for sequence in range(start, MAX_SEGMENTS + 2): + for sequence in range(start, MAX_SEGMENTS + 1): segments.append(make_segment(sequence)) - assert await resp.text() == make_playlist( - sequence=start, - segments=segments, - ) + assert await resp.text() == make_playlist(sequence=start, segments=segments) + + # Fetch the actual segments with a fake byte payload + for segment in hls.get_segments(): + segment.init = INIT_BYTES + segment.parts = [ + Part( + duration=SEGMENT_DURATION, + has_keyframe=True, + data=FAKE_PAYLOAD, + ) + ] # The segment that fell off the buffer is not accessible - segment_response = await hls_client.get("/segment/1.m4s") + segment_response = await hls_client.get("/segment/0.m4s") assert segment_response.status == 404 # However all segments in the buffer are accessible, even those that were not in the playlist. - for sequence in range(2, MAX_SEGMENTS + 2): + for sequence in range(1, MAX_SEGMENTS + 1): segment_response = await hls_client.get(f"/segment/{sequence}.m4s") assert segment_response.status == 200 @@ -342,36 +351,21 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s stream_worker_sync.pause() hls = stream.add_provider(HLS_PROVIDER) - hls.put( - Segment( - 1, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - stream_id=0, - start_time=FAKE_TIME, - ) + segment = Segment( + sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME ) - hls.put( - Segment( - 2, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - stream_id=0, - start_time=FAKE_TIME, - ) + hls.put(segment) + segment = Segment( + sequence=1, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME ) - hls.put( - Segment( - 3, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - stream_id=1, - start_time=FAKE_TIME, - ) + hls.put(segment) + segment = Segment( + sequence=2, + stream_id=1, + duration=SEGMENT_DURATION, + start_time=FAKE_TIME, ) + hls.put(segment) await hass.async_block_till_done() hls_client = await hls_stream(stream) @@ -379,11 +373,11 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s resp = await hls_client.get("/playlist.m3u8") assert resp.status == 200 assert await resp.text() == make_playlist( - sequence=1, + sequence=0, segments=[ + make_segment(0), make_segment(1), - make_segment(2), - make_segment(3, discontinuity=True), + make_segment(2, discontinuity=True), ], ) @@ -401,29 +395,20 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy hls_client = await hls_stream(stream) - hls.put( - Segment( - 1, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - stream_id=0, - start_time=FAKE_TIME, - ) + segment = Segment( + sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME ) + hls.put(segment) # Produce enough segments to overfill the output buffer by one - for sequence in range(1, MAX_SEGMENTS + 2): - hls.put( - Segment( - sequence, - INIT_BYTES, - MOOF_BYTES, - SEGMENT_DURATION, - stream_id=1, - start_time=FAKE_TIME, - ) + for sequence in range(MAX_SEGMENTS + 1): + segment = Segment( + sequence=sequence, + stream_id=1, + duration=SEGMENT_DURATION, + start_time=FAKE_TIME, ) + hls.put(segment) await hass.async_block_till_done() resp = await hls_client.get("/playlist.m3u8") @@ -432,9 +417,9 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the # EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE # returned instead. - start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS + start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS segments = [] - for sequence in range(start, MAX_SEGMENTS + 2): + for sequence in range(start, MAX_SEGMENTS + 1): segments.append(make_segment(sequence)) assert await resp.text() == make_playlist( sequence=start, diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index d45dd0cbca7..07e1464f31a 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -1,23 +1,16 @@ """The tests for hls streams.""" -from __future__ import annotations - -import asyncio -from collections import deque from datetime import timedelta from io import BytesIO -import logging import os -import threading from unittest.mock import patch -import async_timeout import av import pytest from homeassistant.components.stream import create_stream from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER -from homeassistant.components.stream.core import Segment -from homeassistant.components.stream.fmp4utils import get_init_and_moof_data +from homeassistant.components.stream.core import Part, Segment +from homeassistant.components.stream.fmp4utils import find_box from homeassistant.components.stream.recorder import recorder_save_worker from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component @@ -26,63 +19,9 @@ import homeassistant.util.dt as dt_util from tests.common import async_fire_time_changed from tests.components.stream.common import generate_h264_video -TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever -class SaveRecordWorkerSync: - """ - Test fixture to manage RecordOutput thread for recorder_save_worker. - - This is used to assert that the worker is started and stopped cleanly - to avoid thread leaks in tests. - """ - - def __init__(self): - """Initialize SaveRecordWorkerSync.""" - self.reset() - self._segments = None - self._save_thread = None - - def recorder_save_worker(self, file_out: str, segments: deque[Segment]): - """Mock method for patch.""" - logging.debug("recorder_save_worker thread started") - assert self._save_thread is None - self._segments = segments - self._save_thread = threading.current_thread() - self._save_event.set() - - async def get_segments(self): - """Return the recorded video segments.""" - with async_timeout.timeout(TEST_TIMEOUT): - await self._save_event.wait() - return self._segments - - async def join(self): - """Verify save worker was invoked and block on shutdown.""" - with async_timeout.timeout(TEST_TIMEOUT): - await self._save_event.wait() - self._save_thread.join(timeout=TEST_TIMEOUT) - assert not self._save_thread.is_alive() - - def reset(self): - """Reset callback state for reuse in tests.""" - self._save_thread = None - self._save_event = asyncio.Event() - - -@pytest.fixture() -def record_worker_sync(hass): - """Patch recorder_save_worker for clean thread shutdown for test.""" - sync = SaveRecordWorkerSync() - with patch( - "homeassistant.components.stream.recorder.recorder_save_worker", - side_effect=sync.recorder_save_worker, - autospec=True, - ): - yield sync - - async def test_record_stream(hass, hass_client, record_worker_sync): """ Test record stream. @@ -179,6 +118,21 @@ async def test_record_path_not_allowed(hass, hass_client): await stream.async_record("/example/path") +def add_parts_to_segment(segment, source): + """Add relevant part data to segment for testing recorder.""" + moof_locs = list(find_box(source.getbuffer(), b"moof")) + [len(source.getbuffer())] + segment.init = source.getbuffer()[: moof_locs[0]].tobytes() + segment.parts = [ + Part( + duration=None, + has_keyframe=None, + http_range_start=None, + data=source.getbuffer()[moof_locs[i] : moof_locs[i + 1]], + ) + for i in range(1, len(moof_locs) - 1) + ] + + async def test_recorder_save(tmpdir): """Test recorder save.""" # Setup @@ -186,9 +140,10 @@ async def test_recorder_save(tmpdir): filename = f"{tmpdir}/test.mp4" # Run - recorder_save_worker( - filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)] - ) + segment = Segment(sequence=1) + add_parts_to_segment(segment, source) + segment.duration = 4 + recorder_save_worker(filename, [segment]) # Assert assert os.path.exists(filename) @@ -201,15 +156,13 @@ async def test_recorder_discontinuity(tmpdir): filename = f"{tmpdir}/test.mp4" # Run - init, moof_data = get_init_and_moof_data(source.getbuffer()) - recorder_save_worker( - filename, - [ - Segment(1, init, moof_data, 4, 0), - Segment(2, init, moof_data, 4, 1), - ], - ) - + segment_1 = Segment(sequence=1, stream_id=0) + add_parts_to_segment(segment_1, source) + segment_1.duration = 4 + segment_2 = Segment(sequence=2, stream_id=1) + add_parts_to_segment(segment_2, source) + segment_2.duration = 4 + recorder_save_worker(filename, [segment_1, segment_2]) # Assert assert os.path.exists(filename) @@ -263,7 +216,9 @@ async def test_record_stream_audio( stream_worker_sync.resume() result = av.open( - BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4" + BytesIO(last_segment.init + last_segment.get_bytes_without_init()), + "r", + format="mp4", ) assert len(result.streams.audio) == expected_audio_streams diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index aa354ef41cb..74a4fa0e553 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -21,7 +21,7 @@ from unittest.mock import patch import av -from homeassistant.components.stream import Stream +from homeassistant.components.stream import Stream, create_stream from homeassistant.components.stream.const import ( HLS_PROVIDER, MAX_MISSING_DTS, @@ -29,6 +29,9 @@ from homeassistant.components.stream.const import ( TARGET_SEGMENT_DURATION, ) from homeassistant.components.stream.worker import SegmentBuffer, stream_worker +from homeassistant.setup import async_setup_component + +from tests.components.stream.common import generate_h264_video STREAM_SOURCE = "some-stream-source" # Formats here are arbitrary, not exercised by tests @@ -99,9 +102,9 @@ class PacketSequence: super().__init__(3) time_base = fractions.Fraction(1, VIDEO_FRAME_RATE) - dts = self.packet * PACKET_DURATION / time_base - pts = self.packet * PACKET_DURATION / time_base - duration = PACKET_DURATION / time_base + dts = int(self.packet * PACKET_DURATION / time_base) + pts = int(self.packet * PACKET_DURATION / time_base) + duration = int(PACKET_DURATION / time_base) stream = VIDEO_STREAM # Pretend we get 1 keyframe every second is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL) @@ -177,6 +180,11 @@ class FakePyAvBuffer: """Capture the output segment for tests to inspect.""" self.segments.append(segment) + @property + def complete_segments(self): + """Return only the complete segments.""" + return [segment for segment in self.segments if segment.complete] + class MockPyAv: """Mocks out av.open.""" @@ -197,6 +205,19 @@ class MockPyAv: return self.container +class MockFlushPart: + """Class to hold a wrapper function for check_flush_part.""" + + # Wrap this method with a preceding write so the BytesIO pointer moves + check_flush_part = SegmentBuffer.check_flush_part + + @classmethod + def wrapped_check_flush_part(cls, segment_buffer, packet): + """Wrap check_flush_part to also advance the memory_file pointer.""" + segment_buffer._memory_file.write(b"0") + return cls.check_flush_part(segment_buffer, packet) + + async def async_decode_stream(hass, packets, py_av=None): """Start a stream worker that decodes incoming stream packets into output segments.""" stream = Stream(hass, STREAM_SOURCE) @@ -209,6 +230,10 @@ async def async_decode_stream(hass, packets, py_av=None): with patch("av.open", new=py_av.open), patch( "homeassistant.components.stream.core.StreamOutput.put", side_effect=py_av.capture_buffer.capture_output_segment, + ), patch( + "homeassistant.components.stream.worker.SegmentBuffer.check_flush_part", + side_effect=MockFlushPart.wrapped_check_flush_part, + autospec=True, ): segment_buffer = SegmentBuffer(stream.outputs) stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event()) @@ -235,13 +260,16 @@ async def test_stream_worker_success(hass): hass, PacketSequence(TEST_SEQUENCE_LENGTH) ) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check number of segments. A segment is only formed when a packet from the next # segment arrives, hence the subtraction of one from the sequence length. - assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int( + (TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET + ) # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert len(decoded_stream.video_packets) == TEST_SEQUENCE_LENGTH assert len(decoded_stream.audio_packets) == 0 @@ -259,6 +287,7 @@ async def test_skip_out_of_order_packet(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # If skipped packet would have been the first packet of a segment, the previous @@ -273,12 +302,14 @@ async def test_skip_out_of_order_packet(hass): ) del segments[longer_segment_index] # Check number of segments - assert len(segments) == int((len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1) + assert len(complete_segments) == int( + (len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1 + ) else: # Otherwise segment durations and number of segments are unaffected # Check number of segments - assert len(segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET) # Check remaining segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert len(decoded_stream.video_packets) == len(packets) - 1 assert len(decoded_stream.audio_packets) == 0 @@ -292,12 +323,15 @@ async def test_discard_old_packets(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check number of segments - assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int( + (OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET + ) # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX assert len(decoded_stream.audio_packets) == 0 @@ -311,12 +345,15 @@ async def test_packet_overflow(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check number of segments - assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int( + (OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET + ) # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX assert len(decoded_stream.audio_packets) == 0 @@ -332,10 +369,11 @@ async def test_skip_initial_bad_packets(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert ( len(decoded_stream.video_packets) == num_packets @@ -344,7 +382,7 @@ async def test_skip_initial_bad_packets(hass): * KEYFRAME_INTERVAL ) # Check number of segments - assert len(segments) == int( + assert len(complete_segments) == int( (len(decoded_stream.video_packets) - 1) * SEGMENTS_PER_PACKET ) assert len(decoded_stream.audio_packets) == 0 @@ -381,13 +419,11 @@ async def test_skip_missing_dts(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations (not counting the last segment) - assert ( - sum([segments[i].duration == SEGMENT_DURATION for i in range(len(segments))]) - >= len(segments) - 1 - ) + assert sum(segment.duration for segment in complete_segments) >= len(segments) - 1 assert len(decoded_stream.video_packets) == num_packets - num_bad_packets assert len(decoded_stream.audio_packets) == 0 @@ -403,8 +439,8 @@ async def test_too_many_bad_packets(hass): packets[i].dts = None decoded_stream = await async_decode_stream(hass, iter(packets)) - segments = decoded_stream.segments - assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) + complete_segments = decoded_stream.complete_segments + assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == bad_packet_start assert len(decoded_stream.audio_packets) == 0 @@ -431,8 +467,8 @@ async def test_audio_packets_not_found(hass): packets = PacketSequence(num_packets) # Contains only video packets decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - segments = decoded_stream.segments - assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) + complete_segments = decoded_stream.complete_segments + assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == num_packets assert len(decoded_stream.audio_packets) == 0 @@ -444,8 +480,8 @@ async def test_adts_aac_audio(hass): num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 packets = list(PacketSequence(num_packets)) packets[1].stream = AUDIO_STREAM - packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE - packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE + packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) # The following is packet data is a sign of ADTS AAC packets[1][0] = 255 packets[1][1] = 241 @@ -462,17 +498,17 @@ async def test_audio_is_first_packet(hass): packets = list(PacketSequence(num_packets)) # Pair up an audio packet for each video packet packets[0].stream = AUDIO_STREAM - packets[0].dts = packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE - packets[0].pts = packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE + packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1 packets[2].stream = AUDIO_STREAM - packets[2].dts = packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE - packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE + packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # The audio packets are segmented with the video packets - assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == num_packets - 2 assert len(decoded_stream.audio_packets) == 1 @@ -484,13 +520,13 @@ async def test_audio_packets_found(hass): num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 packets = list(PacketSequence(num_packets)) packets[1].stream = AUDIO_STREAM - packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE - packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE + packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av) - segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # The audio packet above is buffered with the video packet - assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET) assert len(decoded_stream.video_packets) == num_packets - 1 assert len(decoded_stream.audio_packets) == 1 @@ -507,12 +543,15 @@ async def test_pts_out_of_order(hass): decoded_stream = await async_decode_stream(hass, iter(packets)) segments = decoded_stream.segments + complete_segments = decoded_stream.complete_segments # Check number of segments - assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET) + assert len(complete_segments) == int( + (TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET + ) # Check sequence numbers assert all(segments[i].sequence == i for i in range(len(segments))) # Check segment durations - assert all(s.duration == SEGMENT_DURATION for s in segments) + assert all(s.duration == SEGMENT_DURATION for s in complete_segments) assert len(decoded_stream.video_packets) == len(packets) assert len(decoded_stream.audio_packets) == 0 @@ -573,7 +612,11 @@ async def test_update_stream_source(hass): worker_wake.wait() return py_av.open(stream_source, args, kwargs) - with patch("av.open", new=blocking_open): + with patch("av.open", new=blocking_open), patch( + "homeassistant.components.stream.worker.SegmentBuffer.check_flush_part", + side_effect=MockFlushPart.wrapped_check_flush_part, + autospec=True, + ): stream.start() assert worker_open.wait(TIMEOUT) assert last_stream_source == STREAM_SOURCE @@ -604,3 +647,74 @@ async def test_worker_log(hass, caplog): await hass.async_block_till_done() assert "https://abcd:efgh@foo.bar" not in caplog.text assert "https://****:****@foo.bar" in caplog.text + + +async def test_durations(hass, record_worker_sync): + """Test that the duration metadata matches the media.""" + await async_setup_component(hass, "stream", {"stream": {}}) + + source = generate_h264_video() + stream = create_stream(hass, source) + + # use record_worker_sync to grab output segments + with patch.object(hass.config, "is_allowed_path", return_value=True): + await stream.async_record("/example/path") + + complete_segments = list(await record_worker_sync.get_segments())[:-1] + assert len(complete_segments) >= 1 + + # check that the Part duration metadata matches the durations in the media + running_metadata_duration = 0 + for segment in complete_segments: + for part in segment.parts: + av_part = av.open(io.BytesIO(segment.init + part.data)) + running_metadata_duration += part.duration + # av_part.duration will just return the largest dts in av_part. + # When we normalize by av.time_base this should equal the running duration + assert math.isclose( + running_metadata_duration, + av_part.duration / av.time_base, + abs_tol=1e-6, + ) + av_part.close() + # check that the Part durations are consistent with the Segment durations + for segment in complete_segments: + assert math.isclose( + sum(part.duration for part in segment.parts), segment.duration, abs_tol=1e-6 + ) + + await record_worker_sync.join() + + stream.stop() + + +async def test_has_keyframe(hass, record_worker_sync): + """Test that the has_keyframe metadata matches the media.""" + await async_setup_component(hass, "stream", {"stream": {}}) + + source = generate_h264_video() + stream = create_stream(hass, source) + + # use record_worker_sync to grab output segments + with patch.object(hass.config, "is_allowed_path", return_value=True): + await stream.async_record("/example/path") + + # Our test video has keyframes every second. Use smaller parts so we have more + # part boundaries to better test keyframe logic. + with patch("homeassistant.components.stream.worker.TARGET_PART_DURATION", 0.25): + complete_segments = list(await record_worker_sync.get_segments())[:-1] + assert len(complete_segments) >= 1 + + # check that the Part has_keyframe metadata matches the keyframes in the media + for segment in complete_segments: + for part in segment.parts: + av_part = av.open(io.BytesIO(segment.init + part.data)) + media_has_keyframe = any( + packet.is_keyframe for packet in av_part.demux(av_part.streams.video[0]) + ) + av_part.close() + assert part.has_keyframe == media_has_keyframe + + await record_worker_sync.join() + + stream.stop()