Refactor stream to create partial segments (#51282)

pull/51821/head
uvjustin 2021-06-14 00:41:21 +08:00 committed by GitHub
parent 1adeb82930
commit 123e8f01a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 499 additions and 305 deletions

View File

@ -18,8 +18,9 @@ FORMAT_CONTENT_TYPE = {HLS_PROVIDER: "application/vnd.apple.mpegurl"}
OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity
NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist
MAX_SEGMENTS = 4 # Max number of segments to keep around
MAX_SEGMENTS = 5 # Max number of segments to keep around
TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds
TARGET_PART_DURATION = 1.0
SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries
# Each segment is at least this many seconds
MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER

View File

@ -19,20 +19,37 @@ from .const import ATTR_STREAMS, DOMAIN
PROVIDERS = Registry()
@attr.s(slots=True)
class Part:
"""Represent a segment part."""
duration: float = attr.ib()
has_keyframe: bool = attr.ib()
data: bytes = attr.ib()
@attr.s(slots=True)
class Segment:
"""Represent a segment."""
sequence: int = attr.ib()
# the init of the mp4
init: bytes = attr.ib()
# the video data (moof + mddat)s of the mp4
moof_data: bytes = attr.ib()
duration: float = attr.ib()
sequence: int = attr.ib(default=0)
# the init of the mp4 the segment is based on
init: bytes = attr.ib(default=None)
duration: float = attr.ib(default=0)
# For detecting discontinuities across stream restarts
stream_id: int = attr.ib(default=0)
parts: list[Part] = attr.ib(factory=list)
start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow)
@property
def complete(self) -> bool:
"""Return whether the Segment is complete."""
return self.duration > 0
def get_bytes_without_init(self) -> bytes:
"""Return reconstructed data for entire segment as bytes."""
return b"".join([part.data for part in self.parts])
class IdleTimer:
"""Invoke a callback after an inactivity timeout.

View File

@ -25,16 +25,6 @@ def find_box(
index += int.from_bytes(box_header[0:4], byteorder="big")
def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]:
"""Get the init and moof data from a segment."""
moof_location = next(find_box(segment, b"moof"), 0)
mfra_location = next(find_box(segment, b"mfra"), len(segment))
return (
segment[:moof_location].tobytes(),
segment[moof_location:mfra_location].tobytes(),
)
def get_codec_string(mp4_bytes: bytes) -> str:
"""Get RFC 6381 codec string."""
codecs = []

View File

@ -37,9 +37,12 @@ class HlsMasterPlaylistView(StreamView):
# Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work
# Calculate file size / duration and use a small multiplier to account for variation
# hls spec already allows for 25% variation
segment = track.get_segment(track.sequences[-1])
segment = track.get_segment(track.sequences[-2])
bandwidth = round(
(len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2
(len(segment.init) + sum(len(part.data) for part in segment.parts))
* 8
/ segment.duration
* 1.2
)
codecs = get_codec_string(segment.init)
lines = [
@ -53,9 +56,11 @@ class HlsMasterPlaylistView(StreamView):
"""Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER)
stream.start()
# Wait for a segment to be ready
# Make sure at least two segments are ready (last one may not be complete)
if not track.sequences and not await track.recv():
return web.HTTPNotFound()
if len(track.sequences) == 1 and not await track.recv():
return web.HTTPNotFound()
headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
return web.Response(body=self.render(track).encode("utf-8"), headers=headers)
@ -68,69 +73,72 @@ class HlsPlaylistView(StreamView):
cors_allowed = True
@staticmethod
def render_preamble(track):
"""Render preamble."""
return [
"#EXT-X-VERSION:6",
f"#EXT-X-TARGETDURATION:{track.target_duration}",
'#EXT-X-MAP:URI="init.mp4"',
]
@staticmethod
def render_playlist(track):
def render(track):
"""Render playlist."""
segments = list(track.get_segments())[-NUM_PLAYLIST_SEGMENTS:]
# NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete
segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :]
if not segments:
return []
# To cap the number of complete segments at NUM_PLAYLIST_SEGMENTS,
# remove the first segment if the last segment is actually complete
if segments[-1].complete:
segments = segments[-NUM_PLAYLIST_SEGMENTS:]
first_segment = segments[0]
playlist = [
"#EXTM3U",
"#EXT-X-VERSION:6",
"#EXT-X-INDEPENDENT-SEGMENTS",
'#EXT-X-MAP:URI="init.mp4"',
f"#EXT-X-TARGETDURATION:{track.target_duration:.0f}",
f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}",
f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}",
"#EXT-X-PROGRAM-DATE-TIME:"
+ first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z",
# Since our window doesn't have many segments, we don't want to start
# at the beginning or we risk a behind live window exception in exoplayer.
# at the beginning or we risk a behind live window exception in Exoplayer.
# EXT-X-START is not supposed to be within 3 target durations of the end,
# but this seems ok
f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f},PRECISE=YES",
# but a value as low as 1.5 doesn't seem to hurt.
# A value below 3 may not be as useful for hls.js as many hls.js clients
# don't autoplay. Also, hls.js uses the player parameter liveSyncDuration
# which seems to take precedence for setting target delay. Yet it also
# doesn't seem to hurt, so we can stick with it for now.
f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f}",
]
last_stream_id = first_segment.stream_id
# Add playlist sections
for segment in segments:
if last_stream_id != segment.stream_id:
# Skip last segment if it is not complete
if segment.complete:
if last_stream_id != segment.stream_id:
playlist.extend(
[
"#EXT-X-DISCONTINUITY",
"#EXT-X-PROGRAM-DATE-TIME:"
+ segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z",
]
)
playlist.extend(
[
"#EXT-X-DISCONTINUITY",
"#EXT-X-PROGRAM-DATE-TIME:"
+ segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z",
f"#EXTINF:{segment.duration:.3f},",
f"./segment/{segment.sequence}.m4s",
]
)
playlist.extend(
[
f"#EXTINF:{float(segment.duration):.04f},",
f"./segment/{segment.sequence}.m4s",
]
)
last_stream_id = segment.stream_id
last_stream_id = segment.stream_id
return playlist
def render(self, track):
"""Render M3U8 file."""
lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track)
return "\n".join(lines) + "\n"
return "\n".join(playlist) + "\n"
async def handle(self, request, stream, sequence):
"""Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER)
stream.start()
# Wait for a segment to be ready
# Make sure at least two segments are ready (last one may not be complete)
if not track.sequences and not await track.recv():
return web.HTTPNotFound()
if len(track.sequences) == 1 and not await track.recv():
return web.HTTPNotFound()
headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
response = web.Response(
body=self.render(track).encode("utf-8"), headers=headers
@ -170,7 +178,7 @@ class HlsSegmentView(StreamView):
return web.HTTPNotFound()
headers = {"Content-Type": "video/iso.segment"}
return web.Response(
body=segment.moof_data,
body=segment.get_bytes_without_init(),
headers=headers,
)

View File

@ -57,7 +57,7 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]):
# Open segment
source = av.open(
BytesIO(segment.init + segment.moof_data),
BytesIO(segment.init + segment.get_bytes_without_init()),
"r",
format=SEGMENT_CONTAINER_FORMAT,
)

View File

@ -2,9 +2,12 @@
from __future__ import annotations
from collections import deque
from collections.abc import Iterator, Mapping
from fractions import Fraction
from io import BytesIO
import logging
from typing import cast
from threading import Event
from typing import Callable, cast
import av
@ -17,9 +20,9 @@ from .const import (
PACKETS_TO_WAIT_FOR_AUDIO,
SEGMENT_CONTAINER_FORMAT,
SOURCE_TIMEOUT,
TARGET_PART_DURATION,
)
from .core import Segment, StreamOutput
from .fmp4utils import get_init_and_moof_data
from .core import Part, Segment, StreamOutput
_LOGGER = logging.getLogger(__name__)
@ -27,22 +30,28 @@ _LOGGER = logging.getLogger(__name__)
class SegmentBuffer:
"""Buffer for writing a sequence of packets to the output as a segment."""
def __init__(self, outputs_callback) -> None:
def __init__(
self, outputs_callback: Callable[[], Mapping[str, StreamOutput]]
) -> None:
"""Initialize SegmentBuffer."""
self._stream_id = 0
self._outputs_callback = outputs_callback
self._outputs: list[StreamOutput] = []
self._stream_id: int = 0
self._outputs_callback: Callable[
[], Mapping[str, StreamOutput]
] = outputs_callback
# sequence gets incremented before the first segment so the first segment
# has a sequence number of 0.
self._sequence = -1
self._segment_start_pts = None
self._segment_start_dts: int = cast(int, None)
self._memory_file: BytesIO = cast(BytesIO, None)
self._av_output: av.container.OutputContainer = None
self._input_video_stream: av.video.VideoStream = None
self._input_audio_stream = None # av.audio.AudioStream | None
self._output_video_stream: av.video.VideoStream = None
self._output_audio_stream = None # av.audio.AudioStream | None
self._segment: Segment = cast(Segment, None)
self._segment: Segment | None = None
self._segment_last_write_pos: int = cast(int, None)
self._part_start_dts: int = cast(int, None)
self._part_has_keyframe = False
@staticmethod
def make_new_av(
@ -56,10 +65,17 @@ class SegmentBuffer:
container_options={
# Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970
# "cmaf" flag replaces several of the movflags used, but too recent to use for now
"movflags": "frag_custom+empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
"avoid_negative_ts": "disabled",
"movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
# Sometimes the first segment begins with negative timestamps, and this setting just
# adjusts the timestamps in the output from that segment to start from 0. Helps from
# having to make some adjustments in test_durations
"avoid_negative_ts": "make_non_negative",
"fragment_index": str(sequence + 1),
"video_track_timescale": str(int(1 / input_vstream.time_base)),
# Create a fragments every TARGET_PART_DURATION. The data from each fragment is stored in
# a "Part" that can be combined with the data from all the other "Part"s, plus an init
# section, to reconstitute the data in a "Segment".
"frag_duration": str(int(TARGET_PART_DURATION * 1e6)),
},
)
@ -73,15 +89,13 @@ class SegmentBuffer:
self._input_video_stream = video_stream
self._input_audio_stream = audio_stream
def reset(self, video_pts):
def reset(self, video_dts: int) -> None:
"""Initialize a new stream segment."""
# Keep track of the number of segments we've processed
self._sequence += 1
self._segment_start_pts = video_pts
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
self._outputs = self._outputs_callback().values()
self._segment_start_dts = self._part_start_dts = video_dts
self._segment = None
self._segment_last_write_pos = 0
self._memory_file = BytesIO()
self._av_output = self.make_new_av(
memory_file=self._memory_file,
@ -98,54 +112,102 @@ class SegmentBuffer:
template=self._input_audio_stream
)
def mux_packet(self, packet):
def mux_packet(self, packet: av.Packet) -> None:
"""Mux a packet to the appropriate output stream."""
# Check for end of segment
if packet.stream == self._input_video_stream and packet.is_keyframe:
duration = (packet.pts - self._segment_start_pts) * packet.time_base
if duration >= MIN_SEGMENT_DURATION:
# Save segment to outputs
self.flush(duration)
# Reinitialize
self.reset(packet.pts)
# Mux the packet
if packet.stream == self._input_video_stream:
if (
packet.is_keyframe
and (
segment_duration := (packet.dts - self._segment_start_dts)
* packet.time_base
)
>= MIN_SEGMENT_DURATION
):
# Flush segment (also flushes the stub part segment)
self.flush(segment_duration, packet)
# Reinitialize
self.reset(packet.dts)
# Mux the packet
packet.stream = self._output_video_stream
self._av_output.mux(packet)
self.check_flush_part(packet)
self._part_has_keyframe |= packet.is_keyframe
elif packet.stream == self._input_audio_stream:
packet.stream = self._output_audio_stream
self._av_output.mux(packet)
def flush(self, duration):
def check_flush_part(self, packet: av.Packet) -> None:
"""Check for and mark a part segment boundary and record its duration."""
byte_position = self._memory_file.tell()
if self._segment_last_write_pos == byte_position:
return
if self._segment is None:
# We have our first non-zero byte position. This means the init has just
# been written. Create a Segment and put it to the queue of each output.
self._segment = Segment(
sequence=self._sequence,
stream_id=self._stream_id,
init=self._memory_file.getvalue(),
)
self._segment_last_write_pos = byte_position
# Fetch the latest StreamOutputs, which may have changed since the
# worker started.
for stream_output in self._outputs_callback().values():
stream_output.put(self._segment)
else: # These are the ends of the part segments
self._segment.parts.append(
Part(
duration=float(
(packet.dts - self._part_start_dts) * packet.time_base
),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos : byte_position
].tobytes(),
)
)
self._segment_last_write_pos = byte_position
self._part_start_dts = packet.dts
self._part_has_keyframe = False
def flush(self, duration: Fraction, packet: av.Packet) -> None:
"""Create a segment from the buffered packets and write to output."""
self._av_output.close()
segment = Segment(
self._sequence,
*get_init_and_moof_data(self._memory_file.getbuffer()),
duration,
self._stream_id,
assert self._segment
self._segment.duration = float(duration)
# Also flush the part segment (need to close the output above before this)
self._segment.parts.append(
Part(
duration=float((packet.dts - self._part_start_dts) * packet.time_base),
has_keyframe=self._part_has_keyframe,
data=self._memory_file.getbuffer()[
self._segment_last_write_pos :
].tobytes(),
)
)
self._memory_file.close()
for stream_output in self._outputs:
stream_output.put(segment)
self._memory_file.close() # We don't need the BytesIO object anymore
def discontinuity(self):
def discontinuity(self) -> None:
"""Mark the stream as having been restarted."""
# Preserving sequence and stream_id here keep the HLS playlist logic
# simple to check for discontinuity at output time, and to determine
# the discontinuity sequence number.
self._stream_id += 1
def close(self):
def close(self) -> None:
"""Close stream buffer."""
self._av_output.close()
self._memory_file.close()
def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
def stream_worker( # noqa: C901
source: str, options: dict, segment_buffer: SegmentBuffer, quit_event: Event
) -> None:
"""Handle consuming streams."""
try:
@ -172,27 +234,27 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
audio_stream = None
# Iterator for demuxing
container_packets = None
container_packets: Iterator[av.Packet]
# The decoder timestamps of the latest packet in each stream we processed
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
# Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0
# The video pts at the beginning of the segment
segment_start_pts = None
# The video dts at the beginning of the segment
segment_start_dts: int | None = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
initial_packets = deque()
initial_packets: deque[av.Packet] = deque()
# Have to work around two problems with RTSP feeds in ffmpeg
# 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018
# 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815
def peek_first_pts():
def peek_first_dts() -> bool:
"""Initialize by peeking into the first few packets of the stream.
Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet.
Also load the first video keyframe pts into segment_start_pts and check if the audio stream really exists.
Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
"""
nonlocal segment_start_pts, audio_stream, container_packets
nonlocal segment_start_dts, audio_stream, container_packets
missing_dts = 0
found_audio = False
try:
@ -215,8 +277,8 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
elif packet.is_keyframe: # video_keyframe
first_packet = packet
initial_packets.append(packet)
# Get first_pts from subsequent frame to first keyframe
while segment_start_pts is None or (
# Get first_dts from subsequent frame to first keyframe
while segment_start_dts is None or (
audio_stream
and not found_audio
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
@ -244,11 +306,10 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
continue
found_audio = True
elif (
segment_start_pts is None
): # This is the second video frame to calculate first_pts from
segment_start_pts = packet.dts - packet.duration
first_packet.pts = segment_start_pts
first_packet.dts = segment_start_pts
segment_start_dts is None
): # This is the second video frame to calculate first_dts from
segment_start_dts = packet.dts - packet.duration
first_packet.pts = first_packet.dts = segment_start_dts
initial_packets.append(packet)
if audio_stream and not found_audio:
_LOGGER.warning(
@ -263,12 +324,13 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
return False
return True
if not peek_first_pts():
if not peek_first_dts():
container.close()
return
segment_buffer.set_streams(video_stream, audio_stream)
segment_buffer.reset(segment_start_pts)
assert isinstance(segment_start_dts, int)
segment_buffer.reset(segment_start_dts)
while not quit_event.is_set():
try:

View File

@ -9,13 +9,21 @@ nothing for the test to verify. The solution is the WorkerSync class that
allows the tests to pause the worker thread before finalizing the stream
so that it can inspect the output.
"""
from __future__ import annotations
import asyncio
from collections import deque
import logging
import threading
from unittest.mock import patch
import async_timeout
import pytest
from homeassistant.components.stream import Stream
from homeassistant.components.stream.core import Segment
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
class WorkerSync:
@ -58,3 +66,57 @@ def stream_worker_sync(hass):
autospec=True,
):
yield sync
class SaveRecordWorkerSync:
"""
Test fixture to manage RecordOutput thread for recorder_save_worker.
This is used to assert that the worker is started and stopped cleanly
to avoid thread leaks in tests.
"""
def __init__(self):
"""Initialize SaveRecordWorkerSync."""
self._save_event = None
self._segments = None
self._save_thread = None
self.reset()
def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
self._segments = segments
self._save_thread = threading.current_thread()
self._save_event.set()
async def get_segments(self):
"""Return the recorded video segments."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
return self._segments
async def join(self):
"""Verify save worker was invoked and block on shutdown."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
self._save_thread.join(timeout=TEST_TIMEOUT)
assert not self._save_thread.is_alive()
def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()
@pytest.fixture()
def record_worker_sync(hass):
"""Patch recorder_save_worker for clean thread shutdown for test."""
sync = SaveRecordWorkerSync()
with patch(
"homeassistant.components.stream.recorder.recorder_save_worker",
side_effect=sync.recorder_save_worker,
autospec=True,
):
yield sync

View File

@ -12,7 +12,7 @@ from homeassistant.components.stream.const import (
MAX_SEGMENTS,
NUM_PLAYLIST_SEGMENTS,
)
from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.core import Part, Segment
from homeassistant.const import HTTP_NOT_FOUND
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@ -22,7 +22,7 @@ from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source"
INIT_BYTES = b"init"
MOOF_BYTES = b"some-bytes"
FAKE_PAYLOAD = b"fake-payload"
SEGMENT_DURATION = 10
TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
@ -70,23 +70,24 @@ def make_segment(segment, discontinuity=False):
+ "Z",
]
)
response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]),
response.extend([f"#EXTINF:{SEGMENT_DURATION:.3f},", f"./segment/{segment}.m4s"])
return "\n".join(response)
def make_playlist(sequence, discontinuity_sequence=0, segments=[]):
def make_playlist(sequence, segments, discontinuity_sequence=0):
"""Create a an hls playlist response for tests to assert on."""
response = [
"#EXTM3U",
"#EXT-X-VERSION:6",
"#EXT-X-TARGETDURATION:10",
"#EXT-X-INDEPENDENT-SEGMENTS",
'#EXT-X-MAP:URI="init.mp4"',
"#EXT-X-TARGETDURATION:10",
f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}",
"#EXT-X-PROGRAM-DATE-TIME:"
+ FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z",
f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f},PRECISE=YES",
f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f}",
]
response.extend(segments)
response.append("")
@ -264,21 +265,26 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER)
hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME))
for i in range(2):
segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment)
await hass.async_block_till_done()
hls_client = await hls_stream(stream)
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)])
assert await resp.text() == make_playlist(
sequence=0, segments=[make_segment(0), make_segment(1)]
)
hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME))
segment = Segment(sequence=2, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
hls.put(segment)
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == make_playlist(
sequence=1, segments=[make_segment(1), make_segment(2)]
sequence=0, segments=[make_segment(0), make_segment(1), make_segment(2)]
)
stream_worker_sync.resume()
@ -296,37 +302,40 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
hls_client = await hls_stream(stream)
# Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(
Segment(
sequence,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
start_time=FAKE_TIME,
)
for sequence in range(MAX_SEGMENTS + 1):
segment = Segment(
sequence=sequence, duration=SEGMENT_DURATION, start_time=FAKE_TIME
)
hls.put(segment)
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
segments = []
for sequence in range(start, MAX_SEGMENTS + 2):
for sequence in range(start, MAX_SEGMENTS + 1):
segments.append(make_segment(sequence))
assert await resp.text() == make_playlist(
sequence=start,
segments=segments,
)
assert await resp.text() == make_playlist(sequence=start, segments=segments)
# Fetch the actual segments with a fake byte payload
for segment in hls.get_segments():
segment.init = INIT_BYTES
segment.parts = [
Part(
duration=SEGMENT_DURATION,
has_keyframe=True,
data=FAKE_PAYLOAD,
)
]
# The segment that fell off the buffer is not accessible
segment_response = await hls_client.get("/segment/1.m4s")
segment_response = await hls_client.get("/segment/0.m4s")
assert segment_response.status == 404
# However all segments in the buffer are accessible, even those that were not in the playlist.
for sequence in range(2, MAX_SEGMENTS + 2):
for sequence in range(1, MAX_SEGMENTS + 1):
segment_response = await hls_client.get(f"/segment/{sequence}.m4s")
assert segment_response.status == 200
@ -342,36 +351,21 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
stream_worker_sync.pause()
hls = stream.add_provider(HLS_PROVIDER)
hls.put(
Segment(
1,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
)
segment = Segment(
sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
)
hls.put(
Segment(
2,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
)
hls.put(segment)
segment = Segment(
sequence=1, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
)
hls.put(
Segment(
3,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=1,
start_time=FAKE_TIME,
)
hls.put(segment)
segment = Segment(
sequence=2,
stream_id=1,
duration=SEGMENT_DURATION,
start_time=FAKE_TIME,
)
hls.put(segment)
await hass.async_block_till_done()
hls_client = await hls_stream(stream)
@ -379,11 +373,11 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
resp = await hls_client.get("/playlist.m3u8")
assert resp.status == 200
assert await resp.text() == make_playlist(
sequence=1,
sequence=0,
segments=[
make_segment(0),
make_segment(1),
make_segment(2),
make_segment(3, discontinuity=True),
make_segment(2, discontinuity=True),
],
)
@ -401,29 +395,20 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
hls_client = await hls_stream(stream)
hls.put(
Segment(
1,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=0,
start_time=FAKE_TIME,
)
segment = Segment(
sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
)
hls.put(segment)
# Produce enough segments to overfill the output buffer by one
for sequence in range(1, MAX_SEGMENTS + 2):
hls.put(
Segment(
sequence,
INIT_BYTES,
MOOF_BYTES,
SEGMENT_DURATION,
stream_id=1,
start_time=FAKE_TIME,
)
for sequence in range(MAX_SEGMENTS + 1):
segment = Segment(
sequence=sequence,
stream_id=1,
duration=SEGMENT_DURATION,
start_time=FAKE_TIME,
)
hls.put(segment)
await hass.async_block_till_done()
resp = await hls_client.get("/playlist.m3u8")
@ -432,9 +417,9 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
# Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the
# EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE
# returned instead.
start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
segments = []
for sequence in range(start, MAX_SEGMENTS + 2):
for sequence in range(start, MAX_SEGMENTS + 1):
segments.append(make_segment(sequence))
assert await resp.text() == make_playlist(
sequence=start,

View File

@ -1,23 +1,16 @@
"""The tests for hls streams."""
from __future__ import annotations
import asyncio
from collections import deque
from datetime import timedelta
from io import BytesIO
import logging
import os
import threading
from unittest.mock import patch
import async_timeout
import av
import pytest
from homeassistant.components.stream import create_stream
from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER
from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.fmp4utils import get_init_and_moof_data
from homeassistant.components.stream.core import Part, Segment
from homeassistant.components.stream.fmp4utils import find_box
from homeassistant.components.stream.recorder import recorder_save_worker
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@ -26,63 +19,9 @@ import homeassistant.util.dt as dt_util
from tests.common import async_fire_time_changed
from tests.components.stream.common import generate_h264_video
TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
class SaveRecordWorkerSync:
"""
Test fixture to manage RecordOutput thread for recorder_save_worker.
This is used to assert that the worker is started and stopped cleanly
to avoid thread leaks in tests.
"""
def __init__(self):
"""Initialize SaveRecordWorkerSync."""
self.reset()
self._segments = None
self._save_thread = None
def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
self._segments = segments
self._save_thread = threading.current_thread()
self._save_event.set()
async def get_segments(self):
"""Return the recorded video segments."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
return self._segments
async def join(self):
"""Verify save worker was invoked and block on shutdown."""
with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
self._save_thread.join(timeout=TEST_TIMEOUT)
assert not self._save_thread.is_alive()
def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()
@pytest.fixture()
def record_worker_sync(hass):
"""Patch recorder_save_worker for clean thread shutdown for test."""
sync = SaveRecordWorkerSync()
with patch(
"homeassistant.components.stream.recorder.recorder_save_worker",
side_effect=sync.recorder_save_worker,
autospec=True,
):
yield sync
async def test_record_stream(hass, hass_client, record_worker_sync):
"""
Test record stream.
@ -179,6 +118,21 @@ async def test_record_path_not_allowed(hass, hass_client):
await stream.async_record("/example/path")
def add_parts_to_segment(segment, source):
"""Add relevant part data to segment for testing recorder."""
moof_locs = list(find_box(source.getbuffer(), b"moof")) + [len(source.getbuffer())]
segment.init = source.getbuffer()[: moof_locs[0]].tobytes()
segment.parts = [
Part(
duration=None,
has_keyframe=None,
http_range_start=None,
data=source.getbuffer()[moof_locs[i] : moof_locs[i + 1]],
)
for i in range(1, len(moof_locs) - 1)
]
async def test_recorder_save(tmpdir):
"""Test recorder save."""
# Setup
@ -186,9 +140,10 @@ async def test_recorder_save(tmpdir):
filename = f"{tmpdir}/test.mp4"
# Run
recorder_save_worker(
filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)]
)
segment = Segment(sequence=1)
add_parts_to_segment(segment, source)
segment.duration = 4
recorder_save_worker(filename, [segment])
# Assert
assert os.path.exists(filename)
@ -201,15 +156,13 @@ async def test_recorder_discontinuity(tmpdir):
filename = f"{tmpdir}/test.mp4"
# Run
init, moof_data = get_init_and_moof_data(source.getbuffer())
recorder_save_worker(
filename,
[
Segment(1, init, moof_data, 4, 0),
Segment(2, init, moof_data, 4, 1),
],
)
segment_1 = Segment(sequence=1, stream_id=0)
add_parts_to_segment(segment_1, source)
segment_1.duration = 4
segment_2 = Segment(sequence=2, stream_id=1)
add_parts_to_segment(segment_2, source)
segment_2.duration = 4
recorder_save_worker(filename, [segment_1, segment_2])
# Assert
assert os.path.exists(filename)
@ -263,7 +216,9 @@ async def test_record_stream_audio(
stream_worker_sync.resume()
result = av.open(
BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4"
BytesIO(last_segment.init + last_segment.get_bytes_without_init()),
"r",
format="mp4",
)
assert len(result.streams.audio) == expected_audio_streams

View File

@ -21,7 +21,7 @@ from unittest.mock import patch
import av
from homeassistant.components.stream import Stream
from homeassistant.components.stream import Stream, create_stream
from homeassistant.components.stream.const import (
HLS_PROVIDER,
MAX_MISSING_DTS,
@ -29,6 +29,9 @@ from homeassistant.components.stream.const import (
TARGET_SEGMENT_DURATION,
)
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
from homeassistant.setup import async_setup_component
from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source"
# Formats here are arbitrary, not exercised by tests
@ -99,9 +102,9 @@ class PacketSequence:
super().__init__(3)
time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
dts = self.packet * PACKET_DURATION / time_base
pts = self.packet * PACKET_DURATION / time_base
duration = PACKET_DURATION / time_base
dts = int(self.packet * PACKET_DURATION / time_base)
pts = int(self.packet * PACKET_DURATION / time_base)
duration = int(PACKET_DURATION / time_base)
stream = VIDEO_STREAM
# Pretend we get 1 keyframe every second
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
@ -177,6 +180,11 @@ class FakePyAvBuffer:
"""Capture the output segment for tests to inspect."""
self.segments.append(segment)
@property
def complete_segments(self):
"""Return only the complete segments."""
return [segment for segment in self.segments if segment.complete]
class MockPyAv:
"""Mocks out av.open."""
@ -197,6 +205,19 @@ class MockPyAv:
return self.container
class MockFlushPart:
"""Class to hold a wrapper function for check_flush_part."""
# Wrap this method with a preceding write so the BytesIO pointer moves
check_flush_part = SegmentBuffer.check_flush_part
@classmethod
def wrapped_check_flush_part(cls, segment_buffer, packet):
"""Wrap check_flush_part to also advance the memory_file pointer."""
segment_buffer._memory_file.write(b"0")
return cls.check_flush_part(segment_buffer, packet)
async def async_decode_stream(hass, packets, py_av=None):
"""Start a stream worker that decodes incoming stream packets into output segments."""
stream = Stream(hass, STREAM_SOURCE)
@ -209,6 +230,10 @@ async def async_decode_stream(hass, packets, py_av=None):
with patch("av.open", new=py_av.open), patch(
"homeassistant.components.stream.core.StreamOutput.put",
side_effect=py_av.capture_buffer.capture_output_segment,
), patch(
"homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
side_effect=MockFlushPart.wrapped_check_flush_part,
autospec=True,
):
segment_buffer = SegmentBuffer(stream.outputs)
stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
@ -235,13 +260,16 @@ async def test_stream_worker_success(hass):
hass, PacketSequence(TEST_SEQUENCE_LENGTH)
)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments. A segment is only formed when a packet from the next
# segment arrives, hence the subtraction of one from the sequence length.
assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int(
(TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == TEST_SEQUENCE_LENGTH
assert len(decoded_stream.audio_packets) == 0
@ -259,6 +287,7 @@ async def test_skip_out_of_order_packet(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# If skipped packet would have been the first packet of a segment, the previous
@ -273,12 +302,14 @@ async def test_skip_out_of_order_packet(hass):
)
del segments[longer_segment_index]
# Check number of segments
assert len(segments) == int((len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1)
assert len(complete_segments) == int(
(len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1
)
else: # Otherwise segment durations and number of segments are unaffected
# Check number of segments
assert len(segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET)
# Check remaining segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == len(packets) - 1
assert len(decoded_stream.audio_packets) == 0
@ -292,12 +323,15 @@ async def test_discard_old_packets(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int(
(OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
assert len(decoded_stream.audio_packets) == 0
@ -311,12 +345,15 @@ async def test_packet_overflow(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments
assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int(
(OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
assert len(decoded_stream.audio_packets) == 0
@ -332,10 +369,11 @@ async def test_skip_initial_bad_packets(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert (
len(decoded_stream.video_packets)
== num_packets
@ -344,7 +382,7 @@ async def test_skip_initial_bad_packets(hass):
* KEYFRAME_INTERVAL
)
# Check number of segments
assert len(segments) == int(
assert len(complete_segments) == int(
(len(decoded_stream.video_packets) - 1) * SEGMENTS_PER_PACKET
)
assert len(decoded_stream.audio_packets) == 0
@ -381,13 +419,11 @@ async def test_skip_missing_dts(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations (not counting the last segment)
assert (
sum([segments[i].duration == SEGMENT_DURATION for i in range(len(segments))])
>= len(segments) - 1
)
assert sum(segment.duration for segment in complete_segments) >= len(segments) - 1
assert len(decoded_stream.video_packets) == num_packets - num_bad_packets
assert len(decoded_stream.audio_packets) == 0
@ -403,8 +439,8 @@ async def test_too_many_bad_packets(hass):
packets[i].dts = None
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == bad_packet_start
assert len(decoded_stream.audio_packets) == 0
@ -431,8 +467,8 @@ async def test_audio_packets_not_found(hass):
packets = PacketSequence(num_packets) # Contains only video packets
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments
assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
complete_segments = decoded_stream.complete_segments
assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets
assert len(decoded_stream.audio_packets) == 0
@ -444,8 +480,8 @@ async def test_adts_aac_audio(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = list(PacketSequence(num_packets))
packets[1].stream = AUDIO_STREAM
packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
# The following is packet data is a sign of ADTS AAC
packets[1][0] = 255
packets[1][1] = 241
@ -462,17 +498,17 @@ async def test_audio_is_first_packet(hass):
packets = list(PacketSequence(num_packets))
# Pair up an audio packet for each video packet
packets[0].stream = AUDIO_STREAM
packets[0].dts = packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[0].pts = packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
packets[2].stream = AUDIO_STREAM
packets[2].dts = packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# The audio packets are segmented with the video packets
assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets - 2
assert len(decoded_stream.audio_packets) == 1
@ -484,13 +520,13 @@ async def test_audio_packets_found(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = list(PacketSequence(num_packets))
packets[1].stream = AUDIO_STREAM
packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# The audio packet above is buffered with the video packet
assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
assert len(decoded_stream.video_packets) == num_packets - 1
assert len(decoded_stream.audio_packets) == 1
@ -507,12 +543,15 @@ async def test_pts_out_of_order(hass):
decoded_stream = await async_decode_stream(hass, iter(packets))
segments = decoded_stream.segments
complete_segments = decoded_stream.complete_segments
# Check number of segments
assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)
assert len(complete_segments) == int(
(TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
)
# Check sequence numbers
assert all(segments[i].sequence == i for i in range(len(segments)))
# Check segment durations
assert all(s.duration == SEGMENT_DURATION for s in segments)
assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
assert len(decoded_stream.video_packets) == len(packets)
assert len(decoded_stream.audio_packets) == 0
@ -573,7 +612,11 @@ async def test_update_stream_source(hass):
worker_wake.wait()
return py_av.open(stream_source, args, kwargs)
with patch("av.open", new=blocking_open):
with patch("av.open", new=blocking_open), patch(
"homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
side_effect=MockFlushPart.wrapped_check_flush_part,
autospec=True,
):
stream.start()
assert worker_open.wait(TIMEOUT)
assert last_stream_source == STREAM_SOURCE
@ -604,3 +647,74 @@ async def test_worker_log(hass, caplog):
await hass.async_block_till_done()
assert "https://abcd:efgh@foo.bar" not in caplog.text
assert "https://****:****@foo.bar" in caplog.text
async def test_durations(hass, record_worker_sync):
"""Test that the duration metadata matches the media."""
await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video()
stream = create_stream(hass, source)
# use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
complete_segments = list(await record_worker_sync.get_segments())[:-1]
assert len(complete_segments) >= 1
# check that the Part duration metadata matches the durations in the media
running_metadata_duration = 0
for segment in complete_segments:
for part in segment.parts:
av_part = av.open(io.BytesIO(segment.init + part.data))
running_metadata_duration += part.duration
# av_part.duration will just return the largest dts in av_part.
# When we normalize by av.time_base this should equal the running duration
assert math.isclose(
running_metadata_duration,
av_part.duration / av.time_base,
abs_tol=1e-6,
)
av_part.close()
# check that the Part durations are consistent with the Segment durations
for segment in complete_segments:
assert math.isclose(
sum(part.duration for part in segment.parts), segment.duration, abs_tol=1e-6
)
await record_worker_sync.join()
stream.stop()
async def test_has_keyframe(hass, record_worker_sync):
"""Test that the has_keyframe metadata matches the media."""
await async_setup_component(hass, "stream", {"stream": {}})
source = generate_h264_video()
stream = create_stream(hass, source)
# use record_worker_sync to grab output segments
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path")
# Our test video has keyframes every second. Use smaller parts so we have more
# part boundaries to better test keyframe logic.
with patch("homeassistant.components.stream.worker.TARGET_PART_DURATION", 0.25):
complete_segments = list(await record_worker_sync.get_segments())[:-1]
assert len(complete_segments) >= 1
# check that the Part has_keyframe metadata matches the keyframes in the media
for segment in complete_segments:
for part in segment.parts:
av_part = av.open(io.BytesIO(segment.init + part.data))
media_has_keyframe = any(
packet.is_keyframe for packet in av_part.demux(av_part.streams.video[0])
)
av_part.close()
assert part.has_keyframe == media_has_keyframe
await record_worker_sync.join()
stream.stop()