Refactor stream to create partial segments (#51282)
parent 1adeb82930
commit 123e8f01a1
@@ -18,8 +18,9 @@ FORMAT_CONTENT_TYPE = {HLS_PROVIDER: "application/vnd.apple.mpegurl"}
OUTPUT_IDLE_TIMEOUT = 300 # Idle timeout due to inactivity

NUM_PLAYLIST_SEGMENTS = 3 # Number of segments to use in HLS playlist
MAX_SEGMENTS = 4 # Max number of segments to keep around
MAX_SEGMENTS = 5 # Max number of segments to keep around
TARGET_SEGMENT_DURATION = 2.0 # Each segment is about this many seconds
TARGET_PART_DURATION = 1.0
SEGMENT_DURATION_ADJUSTER = 0.1 # Used to avoid missing keyframe boundaries
# Each segment is at least this many seconds
MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER
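Read together, the new constants say that each ~2 s segment should be cut into roughly two ~1 s parts. A small illustrative calculation (the values are copied from the hunk above; the script itself is not part of the change):

# Illustrative only: relate the segment/part duration constants above.
TARGET_SEGMENT_DURATION = 2.0
TARGET_PART_DURATION = 1.0
SEGMENT_DURATION_ADJUSTER = 0.1  # avoid missing keyframe boundaries
MIN_SEGMENT_DURATION = TARGET_SEGMENT_DURATION - SEGMENT_DURATION_ADJUSTER

# A segment is flushed at the first keyframe after MIN_SEGMENT_DURATION,
# so it normally holds about two parts of ~1 s each.
print(MIN_SEGMENT_DURATION / TARGET_PART_DURATION)  # 1.9 -> ~2 parts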
@@ -19,20 +19,37 @@ from .const import ATTR_STREAMS, DOMAIN
PROVIDERS = Registry()


@attr.s(slots=True)
class Part:
    """Represent a segment part."""

    duration: float = attr.ib()
    has_keyframe: bool = attr.ib()
    data: bytes = attr.ib()


@attr.s(slots=True)
class Segment:
    """Represent a segment."""

    sequence: int = attr.ib()
    # the init of the mp4
    init: bytes = attr.ib()
    # the video data (moof + mddat)s of the mp4
    moof_data: bytes = attr.ib()
    duration: float = attr.ib()
    sequence: int = attr.ib(default=0)
    # the init of the mp4 the segment is based on
    init: bytes = attr.ib(default=None)
    duration: float = attr.ib(default=0)
    # For detecting discontinuities across stream restarts
    stream_id: int = attr.ib(default=0)
    parts: list[Part] = attr.ib(factory=list)
    start_time: datetime.datetime = attr.ib(factory=datetime.datetime.utcnow)

    @property
    def complete(self) -> bool:
        """Return whether the Segment is complete."""
        return self.duration > 0

    def get_bytes_without_init(self) -> bytes:
        """Return reconstructed data for entire segment as bytes."""
        return b"".join([part.data for part in self.parts])


class IdleTimer:
    """Invoke a callback after an inactivity timeout.
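The new data model replaces the single moof_data blob with a list of Parts. A hedged sketch (not part of the diff) of how consumers use these objects; the import path is the real one, the payload bytes are made up:

# Sketch: build a Segment from Parts and reconstruct the full fMP4 payload the
# way the HLS segment view and the recorder do below (init + concatenated part
# data). The byte strings are fake placeholders.
from homeassistant.components.stream.core import Part, Segment

segment = Segment(sequence=0, stream_id=0, init=b"ftyp+moov")
segment.parts.append(Part(duration=1.0, has_keyframe=True, data=b"moof+mdat-1"))
segment.parts.append(Part(duration=0.9, has_keyframe=False, data=b"moof+mdat-2"))
segment.duration = 1.9  # a non-zero duration marks the segment complete

full_mp4 = segment.init + segment.get_bytes_without_init()
assert segment.complete
assert full_mp4 == b"ftyp+moov" + b"moof+mdat-1" + b"moof+mdat-2"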
@@ -25,16 +25,6 @@ def find_box(
        index += int.from_bytes(box_header[0:4], byteorder="big")


def get_init_and_moof_data(segment: memoryview) -> tuple[bytes, bytes]:
    """Get the init and moof data from a segment."""
    moof_location = next(find_box(segment, b"moof"), 0)
    mfra_location = next(find_box(segment, b"mfra"), len(segment))
    return (
        segment[:moof_location].tobytes(),
        segment[moof_location:mfra_location].tobytes(),
    )


def get_codec_string(mp4_bytes: bytes) -> str:
    """Get RFC 6381 codec string."""
    codecs = []
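With get_init_and_moof_data removed, callers that still need the raw pieces use find_box directly, as the add_parts_to_segment test helper does further down. A hedged sketch of that pattern; split_fmp4 is a hypothetical helper and not part of the codebase:

# Hypothetical helper, loosely mirroring the add_parts_to_segment test helper:
# split a complete fragmented MP4 buffer into its init section and one chunk
# per moof/mdat fragment using find_box.
from homeassistant.components.stream.fmp4utils import find_box


def split_fmp4(buffer: memoryview) -> tuple[bytes, list[bytes]]:
    """Return (init, fragments) for a fragmented MP4 buffer."""
    moof_locs = list(find_box(buffer, b"moof")) + [len(buffer)]
    init = buffer[: moof_locs[0]].tobytes()
    fragments = [
        buffer[moof_locs[i] : moof_locs[i + 1]].tobytes()
        for i in range(len(moof_locs) - 1)
    ]
    return init, fragments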
@@ -37,9 +37,12 @@ class HlsMasterPlaylistView(StreamView):
        # Need to calculate max bandwidth as input_container.bit_rate doesn't seem to work
        # Calculate file size / duration and use a small multiplier to account for variation
        # hls spec already allows for 25% variation
        segment = track.get_segment(track.sequences[-1])
        segment = track.get_segment(track.sequences[-2])
        bandwidth = round(
            (len(segment.init) + len(segment.moof_data)) * 8 / segment.duration * 1.2
            (len(segment.init) + sum(len(part.data) for part in segment.parts))
            * 8
            / segment.duration
            * 1.2
        )
        codecs = get_codec_string(segment.init)
        lines = [
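A worked example of the estimate above (all numbers invented for illustration):

# Worked example of the bandwidth estimate; byte counts and duration are fake.
init_len = 800                   # len(segment.init)
part_lens = [250_000, 240_000]   # len(part.data) for each part
duration = 1.9                   # segment.duration in seconds

bandwidth = round((init_len + sum(part_lens)) * 8 / duration * 1.2)
print(bandwidth)  # about 2.5 Mbit/s advertised in the master playlist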
@@ -53,9 +56,11 @@ class HlsMasterPlaylistView(StreamView):
        """Return m3u8 playlist."""
        track = stream.add_provider(HLS_PROVIDER)
        stream.start()
        # Wait for a segment to be ready
        # Make sure at least two segments are ready (last one may not be complete)
        if not track.sequences and not await track.recv():
            return web.HTTPNotFound()
        if len(track.sequences) == 1 and not await track.recv():
            return web.HTTPNotFound()
        headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
        return web.Response(body=self.render(track).encode("utf-8"), headers=headers)
@@ -68,69 +73,72 @@ class HlsPlaylistView(StreamView):
    cors_allowed = True

    @staticmethod
    def render_preamble(track):
        """Render preamble."""
        return [
            "#EXT-X-VERSION:6",
            f"#EXT-X-TARGETDURATION:{track.target_duration}",
            '#EXT-X-MAP:URI="init.mp4"',
        ]

    @staticmethod
    def render_playlist(track):
    def render(track):
        """Render playlist."""
        segments = list(track.get_segments())[-NUM_PLAYLIST_SEGMENTS:]
        # NUM_PLAYLIST_SEGMENTS+1 because most recent is probably not yet complete
        segments = list(track.get_segments())[-(NUM_PLAYLIST_SEGMENTS + 1) :]

        if not segments:
            return []
        # To cap the number of complete segments at NUM_PLAYLIST_SEGMENTS,
        # remove the first segment if the last segment is actually complete
        if segments[-1].complete:
            segments = segments[-NUM_PLAYLIST_SEGMENTS:]

        first_segment = segments[0]
        playlist = [
            "#EXTM3U",
            "#EXT-X-VERSION:6",
            "#EXT-X-INDEPENDENT-SEGMENTS",
            '#EXT-X-MAP:URI="init.mp4"',
            f"#EXT-X-TARGETDURATION:{track.target_duration:.0f}",
            f"#EXT-X-MEDIA-SEQUENCE:{first_segment.sequence}",
            f"#EXT-X-DISCONTINUITY-SEQUENCE:{first_segment.stream_id}",
            "#EXT-X-PROGRAM-DATE-TIME:"
            + first_segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
            + "Z",
            # Since our window doesn't have many segments, we don't want to start
            # at the beginning or we risk a behind live window exception in exoplayer.
            # at the beginning or we risk a behind live window exception in Exoplayer.
            # EXT-X-START is not supposed to be within 3 target durations of the end,
            # but this seems ok
            f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f},PRECISE=YES",
            # but a value as low as 1.5 doesn't seem to hurt.
            # A value below 3 may not be as useful for hls.js as many hls.js clients
            # don't autoplay. Also, hls.js uses the player parameter liveSyncDuration
            # which seems to take precedence for setting target delay. Yet it also
            # doesn't seem to hurt, so we can stick with it for now.
            f"#EXT-X-START:TIME-OFFSET=-{EXT_X_START * track.target_duration:.3f}",
        ]

        last_stream_id = first_segment.stream_id
        # Add playlist sections
        for segment in segments:
            if last_stream_id != segment.stream_id:
            # Skip last segment if it is not complete
            if segment.complete:
                if last_stream_id != segment.stream_id:
                    playlist.extend(
                        [
                            "#EXT-X-DISCONTINUITY",
                            "#EXT-X-PROGRAM-DATE-TIME:"
                            + segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
                            + "Z",
                        ]
                    )
                playlist.extend(
                    [
                        "#EXT-X-DISCONTINUITY",
                        "#EXT-X-PROGRAM-DATE-TIME:"
                        + segment.start_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
                        + "Z",
                        f"#EXTINF:{segment.duration:.3f},",
                        f"./segment/{segment.sequence}.m4s",
                    ]
                )
                playlist.extend(
                    [
                        f"#EXTINF:{float(segment.duration):.04f},",
                        f"./segment/{segment.sequence}.m4s",
                    ]
                )
                last_stream_id = segment.stream_id
            last_stream_id = segment.stream_id

        return playlist

    def render(self, track):
        """Render M3U8 file."""
        lines = ["#EXTM3U"] + self.render_preamble(track) + self.render_playlist(track)
        return "\n".join(lines) + "\n"
        return "\n".join(playlist) + "\n"

    async def handle(self, request, stream, sequence):
        """Return m3u8 playlist."""
        track = stream.add_provider(HLS_PROVIDER)
        stream.start()
        # Wait for a segment to be ready
        # Make sure at least two segments are ready (last one may not be complete)
        if not track.sequences and not await track.recv():
            return web.HTTPNotFound()
        if len(track.sequences) == 1 and not await track.recv():
            return web.HTTPNotFound()
        headers = {"Content-Type": FORMAT_CONTENT_TYPE[HLS_PROVIDER]}
        response = web.Response(
            body=self.render(track).encode("utf-8"), headers=headers
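For orientation, a playlist produced by the new render() for two complete segments might look roughly like the following. Sequence numbers, timestamps, and durations are invented; only the tag order follows the code above:

#EXTM3U
#EXT-X-VERSION:6
#EXT-X-INDEPENDENT-SEGMENTS
#EXT-X-MAP:URI="init.mp4"
#EXT-X-TARGETDURATION:2
#EXT-X-MEDIA-SEQUENCE:5
#EXT-X-DISCONTINUITY-SEQUENCE:0
#EXT-X-PROGRAM-DATE-TIME:2021-05-26T20:00:00.000Z
#EXT-X-START:TIME-OFFSET=-3.000
#EXTINF:1.900,
./segment/5.m4s
#EXTINF:1.933,
./segment/6.m4s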
@@ -170,7 +178,7 @@ class HlsSegmentView(StreamView):
            return web.HTTPNotFound()
        headers = {"Content-Type": "video/iso.segment"}
        return web.Response(
            body=segment.moof_data,
            body=segment.get_bytes_without_init(),
            headers=headers,
        )
@@ -57,7 +57,7 @@ def recorder_save_worker(file_out: str, segments: deque[Segment]):

        # Open segment
        source = av.open(
            BytesIO(segment.init + segment.moof_data),
            BytesIO(segment.init + segment.get_bytes_without_init()),
            "r",
            format=SEGMENT_CONTAINER_FORMAT,
        )
@@ -2,9 +2,12 @@
from __future__ import annotations

from collections import deque
from collections.abc import Iterator, Mapping
from fractions import Fraction
from io import BytesIO
import logging
from typing import cast
from threading import Event
from typing import Callable, cast

import av
@@ -17,9 +20,9 @@ from .const import (
    PACKETS_TO_WAIT_FOR_AUDIO,
    SEGMENT_CONTAINER_FORMAT,
    SOURCE_TIMEOUT,
    TARGET_PART_DURATION,
)
from .core import Segment, StreamOutput
from .fmp4utils import get_init_and_moof_data
from .core import Part, Segment, StreamOutput

_LOGGER = logging.getLogger(__name__)
@@ -27,22 +30,28 @@ _LOGGER = logging.getLogger(__name__)
class SegmentBuffer:
    """Buffer for writing a sequence of packets to the output as a segment."""

    def __init__(self, outputs_callback) -> None:
    def __init__(
        self, outputs_callback: Callable[[], Mapping[str, StreamOutput]]
    ) -> None:
        """Initialize SegmentBuffer."""
        self._stream_id = 0
        self._outputs_callback = outputs_callback
        self._outputs: list[StreamOutput] = []
        self._stream_id: int = 0
        self._outputs_callback: Callable[
            [], Mapping[str, StreamOutput]
        ] = outputs_callback
        # sequence gets incremented before the first segment so the first segment
        # has a sequence number of 0.
        self._sequence = -1
        self._segment_start_pts = None
        self._segment_start_dts: int = cast(int, None)
        self._memory_file: BytesIO = cast(BytesIO, None)
        self._av_output: av.container.OutputContainer = None
        self._input_video_stream: av.video.VideoStream = None
        self._input_audio_stream = None # av.audio.AudioStream | None
        self._output_video_stream: av.video.VideoStream = None
        self._output_audio_stream = None # av.audio.AudioStream | None
        self._segment: Segment = cast(Segment, None)
        self._segment: Segment | None = None
        self._segment_last_write_pos: int = cast(int, None)
        self._part_start_dts: int = cast(int, None)
        self._part_has_keyframe = False

    @staticmethod
    def make_new_av(
@@ -56,10 +65,17 @@ class SegmentBuffer:
            container_options={
                # Removed skip_sidx - see https://github.com/home-assistant/core/pull/39970
                # "cmaf" flag replaces several of the movflags used, but too recent to use for now
                "movflags": "frag_custom+empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
                "avoid_negative_ts": "disabled",
                "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
                # Sometimes the first segment begins with negative timestamps, and this setting just
                # adjusts the timestamps in the output from that segment to start from 0. Helps from
                # having to make some adjustments in test_durations
                "avoid_negative_ts": "make_non_negative",
                "fragment_index": str(sequence + 1),
                "video_track_timescale": str(int(1 / input_vstream.time_base)),
                # Create a fragments every TARGET_PART_DURATION. The data from each fragment is stored in
                # a "Part" that can be combined with the data from all the other "Part"s, plus an init
                # section, to reconstitute the data in a "Segment".
                "frag_duration": str(int(TARGET_PART_DURATION * 1e6)),
            },
        )
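A minimal, hedged sketch of the muxer setup above: open a fragmented-MP4 muxer into an in-memory buffer and ask ffmpeg to start a new fragment (one future Part) roughly every TARGET_PART_DURATION. The option names come from the hunk; this is not the full make_new_av implementation, and stream creation/packet muxing are deliberately elided:

# Sketch only: the container options that make ffmpeg emit a moof/mdat pair
# about once per TARGET_PART_DURATION, which check_flush_part later turns into
# a Part. Stream setup and muxing are omitted.
from io import BytesIO

import av

TARGET_PART_DURATION = 1.0  # seconds, from const.py

memory_file = BytesIO()
output = av.open(
    memory_file,
    mode="w",
    format="mp4",  # SEGMENT_CONTAINER_FORMAT in the real code
    container_options={
        "movflags": "empty_moov+default_base_moof+frag_discont+negative_cts_offsets+skip_trailer",
        "avoid_negative_ts": "make_non_negative",
        "frag_duration": str(int(TARGET_PART_DURATION * 1e6)),  # microseconds
    },
)
# ... add_stream() / mux() packets here, then output.close() ...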
@@ -73,15 +89,13 @@ class SegmentBuffer:
        self._input_video_stream = video_stream
        self._input_audio_stream = audio_stream

    def reset(self, video_pts):
    def reset(self, video_dts: int) -> None:
        """Initialize a new stream segment."""
        # Keep track of the number of segments we've processed
        self._sequence += 1
        self._segment_start_pts = video_pts

        # Fetch the latest StreamOutputs, which may have changed since the
        # worker started.
        self._outputs = self._outputs_callback().values()
        self._segment_start_dts = self._part_start_dts = video_dts
        self._segment = None
        self._segment_last_write_pos = 0
        self._memory_file = BytesIO()
        self._av_output = self.make_new_av(
            memory_file=self._memory_file,
@@ -98,54 +112,102 @@ class SegmentBuffer:
                template=self._input_audio_stream
            )

    def mux_packet(self, packet):
    def mux_packet(self, packet: av.Packet) -> None:
        """Mux a packet to the appropriate output stream."""

        # Check for end of segment
        if packet.stream == self._input_video_stream and packet.is_keyframe:
            duration = (packet.pts - self._segment_start_pts) * packet.time_base
            if duration >= MIN_SEGMENT_DURATION:
                # Save segment to outputs
                self.flush(duration)

                # Reinitialize
                self.reset(packet.pts)

        # Mux the packet
        if packet.stream == self._input_video_stream:

            if (
                packet.is_keyframe
                and (
                    segment_duration := (packet.dts - self._segment_start_dts)
                    * packet.time_base
                )
                >= MIN_SEGMENT_DURATION
            ):
                # Flush segment (also flushes the stub part segment)
                self.flush(segment_duration, packet)
                # Reinitialize
                self.reset(packet.dts)

            # Mux the packet
            packet.stream = self._output_video_stream
            self._av_output.mux(packet)
            self.check_flush_part(packet)
            self._part_has_keyframe |= packet.is_keyframe

        elif packet.stream == self._input_audio_stream:
            packet.stream = self._output_audio_stream
            self._av_output.mux(packet)

    def flush(self, duration):
    def check_flush_part(self, packet: av.Packet) -> None:
        """Check for and mark a part segment boundary and record its duration."""
        byte_position = self._memory_file.tell()
        if self._segment_last_write_pos == byte_position:
            return
        if self._segment is None:
            # We have our first non-zero byte position. This means the init has just
            # been written. Create a Segment and put it to the queue of each output.
            self._segment = Segment(
                sequence=self._sequence,
                stream_id=self._stream_id,
                init=self._memory_file.getvalue(),
            )
            self._segment_last_write_pos = byte_position
            # Fetch the latest StreamOutputs, which may have changed since the
            # worker started.
            for stream_output in self._outputs_callback().values():
                stream_output.put(self._segment)
        else: # These are the ends of the part segments
            self._segment.parts.append(
                Part(
                    duration=float(
                        (packet.dts - self._part_start_dts) * packet.time_base
                    ),
                    has_keyframe=self._part_has_keyframe,
                    data=self._memory_file.getbuffer()[
                        self._segment_last_write_pos : byte_position
                    ].tobytes(),
                )
            )
            self._segment_last_write_pos = byte_position
            self._part_start_dts = packet.dts
            self._part_has_keyframe = False

    def flush(self, duration: Fraction, packet: av.Packet) -> None:
        """Create a segment from the buffered packets and write to output."""
        self._av_output.close()
        segment = Segment(
            self._sequence,
            *get_init_and_moof_data(self._memory_file.getbuffer()),
            duration,
            self._stream_id,
        assert self._segment
        self._segment.duration = float(duration)
        # Also flush the part segment (need to close the output above before this)
        self._segment.parts.append(
            Part(
                duration=float((packet.dts - self._part_start_dts) * packet.time_base),
                has_keyframe=self._part_has_keyframe,
                data=self._memory_file.getbuffer()[
                    self._segment_last_write_pos :
                ].tobytes(),
            )
        )
        self._memory_file.close()
        for stream_output in self._outputs:
            stream_output.put(segment)
        self._memory_file.close() # We don't need the BytesIO object anymore

    def discontinuity(self):
    def discontinuity(self) -> None:
        """Mark the stream as having been restarted."""
        # Preserving sequence and stream_id here keep the HLS playlist logic
        # simple to check for discontinuity at output time, and to determine
        # the discontinuity sequence number.
        self._stream_id += 1

    def close(self):
    def close(self) -> None:
        """Close stream buffer."""
        self._av_output.close()
        self._memory_file.close()


def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
def stream_worker( # noqa: C901
    source: str, options: dict, segment_buffer: SegmentBuffer, quit_event: Event
) -> None:
    """Handle consuming streams."""

    try:
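The byte-position bookkeeping in check_flush_part and flush can be pictured without ffmpeg at all. A toy sketch with fake payloads: everything written before the first fragment boundary is treated as the init section, and each later stretch of newly written bytes becomes one Part:

# Toy illustration only: slice an in-memory file by write positions the way
# check_flush_part does. The chunks are fake placeholders for ftyp/moov and
# moof/mdat boxes.
from io import BytesIO

memory_file = BytesIO()
last_write_pos = 0
init = b""
parts = []

for chunk, is_init in [
    (b"ftyp+moov", True),
    (b"moof+mdat-1", False),
    (b"moof+mdat-2", False),
]:
    memory_file.write(chunk)
    byte_position = memory_file.tell()
    if is_init:
        init = memory_file.getvalue()
    else:
        parts.append(
            memory_file.getbuffer()[last_write_pos:byte_position].tobytes()
        )
    last_write_pos = byte_position

# init + concatenated parts reconstitute the whole buffer, as in
# Segment.get_bytes_without_init()
assert init + b"".join(parts) == memory_file.getvalue()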
@@ -172,27 +234,27 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
        audio_stream = None

    # Iterator for demuxing
    container_packets = None
    container_packets: Iterator[av.Packet]
    # The decoder timestamps of the latest packet in each stream we processed
    last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
    # Keep track of consecutive packets without a dts to detect end of stream.
    missing_dts = 0
    # The video pts at the beginning of the segment
    segment_start_pts = None
    # The video dts at the beginning of the segment
    segment_start_dts: int | None = None
    # Because of problems 1 and 2 below, we need to store the first few packets and replay them
    initial_packets = deque()
    initial_packets: deque[av.Packet] = deque()

    # Have to work around two problems with RTSP feeds in ffmpeg
    # 1 - first frame has bad pts/dts https://trac.ffmpeg.org/ticket/5018
    # 2 - seeking can be problematic https://trac.ffmpeg.org/ticket/7815

    def peek_first_pts():
    def peek_first_dts() -> bool:
        """Initialize by peeking into the first few packets of the stream.

        Deal with problem #1 above (bad first packet pts/dts) by recalculating using pts/dts from second packet.
        Also load the first video keyframe pts into segment_start_pts and check if the audio stream really exists.
        Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
        """
        nonlocal segment_start_pts, audio_stream, container_packets
        nonlocal segment_start_dts, audio_stream, container_packets
        missing_dts = 0
        found_audio = False
        try:
@@ -215,8 +277,8 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
                elif packet.is_keyframe: # video_keyframe
                    first_packet = packet
                    initial_packets.append(packet)
            # Get first_pts from subsequent frame to first keyframe
            while segment_start_pts is None or (
            # Get first_dts from subsequent frame to first keyframe
            while segment_start_dts is None or (
                audio_stream
                and not found_audio
                and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
@@ -244,11 +306,10 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
                    continue
                    found_audio = True
                elif (
                    segment_start_pts is None
                ): # This is the second video frame to calculate first_pts from
                    segment_start_pts = packet.dts - packet.duration
                    first_packet.pts = segment_start_pts
                    first_packet.dts = segment_start_pts
                    segment_start_dts is None
                ): # This is the second video frame to calculate first_dts from
                    segment_start_dts = packet.dts - packet.duration
                    first_packet.pts = first_packet.dts = segment_start_dts
                initial_packets.append(packet)
            if audio_stream and not found_audio:
                _LOGGER.warning(
@@ -263,12 +324,13 @@ def stream_worker(source, options, segment_buffer, quit_event): # noqa: C901
            return False
        return True

    if not peek_first_pts():
    if not peek_first_dts():
        container.close()
        return

    segment_buffer.set_streams(video_stream, audio_stream)
    segment_buffer.reset(segment_start_pts)
    assert isinstance(segment_start_dts, int)
    segment_buffer.reset(segment_start_dts)

    while not quit_event.is_set():
        try:
@@ -9,13 +9,21 @@ nothing for the test to verify. The solution is the WorkerSync class that
allows the tests to pause the worker thread before finalizing the stream
so that it can inspect the output.
"""
from __future__ import annotations

import asyncio
from collections import deque
import logging
import threading
from unittest.mock import patch

import async_timeout
import pytest

from homeassistant.components.stream import Stream
from homeassistant.components.stream.core import Segment

TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout


class WorkerSync:
@@ -58,3 +66,57 @@ def stream_worker_sync(hass):
        autospec=True,
    ):
        yield sync


class SaveRecordWorkerSync:
    """
    Test fixture to manage RecordOutput thread for recorder_save_worker.

    This is used to assert that the worker is started and stopped cleanly
    to avoid thread leaks in tests.
    """

    def __init__(self):
        """Initialize SaveRecordWorkerSync."""
        self._save_event = None
        self._segments = None
        self._save_thread = None
        self.reset()

    def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
        """Mock method for patch."""
        logging.debug("recorder_save_worker thread started")
        assert self._save_thread is None
        self._segments = segments
        self._save_thread = threading.current_thread()
        self._save_event.set()

    async def get_segments(self):
        """Return the recorded video segments."""
        with async_timeout.timeout(TEST_TIMEOUT):
            await self._save_event.wait()
        return self._segments

    async def join(self):
        """Verify save worker was invoked and block on shutdown."""
        with async_timeout.timeout(TEST_TIMEOUT):
            await self._save_event.wait()
        self._save_thread.join(timeout=TEST_TIMEOUT)
        assert not self._save_thread.is_alive()

    def reset(self):
        """Reset callback state for reuse in tests."""
        self._save_thread = None
        self._save_event = asyncio.Event()


@pytest.fixture()
def record_worker_sync(hass):
    """Patch recorder_save_worker for clean thread shutdown for test."""
    sync = SaveRecordWorkerSync()
    with patch(
        "homeassistant.components.stream.recorder.recorder_save_worker",
        side_effect=sync.recorder_save_worker,
        autospec=True,
    ):
        yield sync
@@ -12,7 +12,7 @@ from homeassistant.components.stream.const import (
    MAX_SEGMENTS,
    NUM_PLAYLIST_SEGMENTS,
)
from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.core import Part, Segment
from homeassistant.const import HTTP_NOT_FOUND
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@@ -22,7 +22,7 @@ from tests.components.stream.common import generate_h264_video
STREAM_SOURCE = "some-stream-source"
INIT_BYTES = b"init"
MOOF_BYTES = b"some-bytes"
FAKE_PAYLOAD = b"fake-payload"
SEGMENT_DURATION = 10
TEST_TIMEOUT = 5.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever
@@ -70,23 +70,24 @@ def make_segment(segment, discontinuity=False):
            + "Z",
        ]
    )
    response.extend(["#EXTINF:10.0000,", f"./segment/{segment}.m4s"]),
    response.extend([f"#EXTINF:{SEGMENT_DURATION:.3f},", f"./segment/{segment}.m4s"])
    return "\n".join(response)


def make_playlist(sequence, discontinuity_sequence=0, segments=[]):
def make_playlist(sequence, segments, discontinuity_sequence=0):
    """Create a an hls playlist response for tests to assert on."""
    response = [
        "#EXTM3U",
        "#EXT-X-VERSION:6",
        "#EXT-X-TARGETDURATION:10",
        "#EXT-X-INDEPENDENT-SEGMENTS",
        '#EXT-X-MAP:URI="init.mp4"',
        "#EXT-X-TARGETDURATION:10",
        f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
        f"#EXT-X-DISCONTINUITY-SEQUENCE:{discontinuity_sequence}",
        "#EXT-X-PROGRAM-DATE-TIME:"
        + FAKE_TIME.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
        + "Z",
        f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f},PRECISE=YES",
        f"#EXT-X-START:TIME-OFFSET=-{1.5*SEGMENT_DURATION:.3f}",
    ]
    response.extend(segments)
    response.append("")
@@ -264,21 +265,26 @@ async def test_hls_playlist_view(hass, hls_stream, stream_worker_sync):
    stream_worker_sync.pause()
    hls = stream.add_provider(HLS_PROVIDER)

    hls.put(Segment(1, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME))
    for i in range(2):
        segment = Segment(sequence=i, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
        hls.put(segment)
    await hass.async_block_till_done()

    hls_client = await hls_stream(stream)

    resp = await hls_client.get("/playlist.m3u8")
    assert resp.status == 200
    assert await resp.text() == make_playlist(sequence=1, segments=[make_segment(1)])
    assert await resp.text() == make_playlist(
        sequence=0, segments=[make_segment(0), make_segment(1)]
    )

    hls.put(Segment(2, INIT_BYTES, MOOF_BYTES, SEGMENT_DURATION, start_time=FAKE_TIME))
    segment = Segment(sequence=2, duration=SEGMENT_DURATION, start_time=FAKE_TIME)
    hls.put(segment)
    await hass.async_block_till_done()
    resp = await hls_client.get("/playlist.m3u8")
    assert resp.status == 200
    assert await resp.text() == make_playlist(
        sequence=1, segments=[make_segment(1), make_segment(2)]
        sequence=0, segments=[make_segment(0), make_segment(1), make_segment(2)]
    )

    stream_worker_sync.resume()
@@ -296,37 +302,40 @@ async def test_hls_max_segments(hass, hls_stream, stream_worker_sync):
    hls_client = await hls_stream(stream)

    # Produce enough segments to overfill the output buffer by one
    for sequence in range(1, MAX_SEGMENTS + 2):
        hls.put(
            Segment(
                sequence,
                INIT_BYTES,
                MOOF_BYTES,
                SEGMENT_DURATION,
                start_time=FAKE_TIME,
            )
    for sequence in range(MAX_SEGMENTS + 1):
        segment = Segment(
            sequence=sequence, duration=SEGMENT_DURATION, start_time=FAKE_TIME
        )
        hls.put(segment)
        await hass.async_block_till_done()

    resp = await hls_client.get("/playlist.m3u8")
    assert resp.status == 200

    # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist.
    start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
    start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
    segments = []
    for sequence in range(start, MAX_SEGMENTS + 2):
    for sequence in range(start, MAX_SEGMENTS + 1):
        segments.append(make_segment(sequence))
    assert await resp.text() == make_playlist(
        sequence=start,
        segments=segments,
    )
    assert await resp.text() == make_playlist(sequence=start, segments=segments)

    # Fetch the actual segments with a fake byte payload
    for segment in hls.get_segments():
        segment.init = INIT_BYTES
        segment.parts = [
            Part(
                duration=SEGMENT_DURATION,
                has_keyframe=True,
                data=FAKE_PAYLOAD,
            )
        ]

    # The segment that fell off the buffer is not accessible
    segment_response = await hls_client.get("/segment/1.m4s")
    segment_response = await hls_client.get("/segment/0.m4s")
    assert segment_response.status == 404

    # However all segments in the buffer are accessible, even those that were not in the playlist.
    for sequence in range(2, MAX_SEGMENTS + 2):
    for sequence in range(1, MAX_SEGMENTS + 1):
        segment_response = await hls_client.get(f"/segment/{sequence}.m4s")
        assert segment_response.status == 200
@@ -342,36 +351,21 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
    stream_worker_sync.pause()
    hls = stream.add_provider(HLS_PROVIDER)

    hls.put(
        Segment(
            1,
            INIT_BYTES,
            MOOF_BYTES,
            SEGMENT_DURATION,
            stream_id=0,
            start_time=FAKE_TIME,
        )
    segment = Segment(
        sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
    )
    hls.put(
        Segment(
            2,
            INIT_BYTES,
            MOOF_BYTES,
            SEGMENT_DURATION,
            stream_id=0,
            start_time=FAKE_TIME,
        )
    hls.put(segment)
    segment = Segment(
        sequence=1, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
    )
    hls.put(
        Segment(
            3,
            INIT_BYTES,
            MOOF_BYTES,
            SEGMENT_DURATION,
            stream_id=1,
            start_time=FAKE_TIME,
        )
    hls.put(segment)
    segment = Segment(
        sequence=2,
        stream_id=1,
        duration=SEGMENT_DURATION,
        start_time=FAKE_TIME,
    )
    hls.put(segment)
    await hass.async_block_till_done()

    hls_client = await hls_stream(stream)
@@ -379,11 +373,11 @@ async def test_hls_playlist_view_discontinuity(hass, hls_stream, stream_worker_s
    resp = await hls_client.get("/playlist.m3u8")
    assert resp.status == 200
    assert await resp.text() == make_playlist(
        sequence=1,
        sequence=0,
        segments=[
            make_segment(0),
            make_segment(1),
            make_segment(2),
            make_segment(3, discontinuity=True),
            make_segment(2, discontinuity=True),
        ],
    )
@@ -401,29 +395,20 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy

    hls_client = await hls_stream(stream)

    hls.put(
        Segment(
            1,
            INIT_BYTES,
            MOOF_BYTES,
            SEGMENT_DURATION,
            stream_id=0,
            start_time=FAKE_TIME,
        )
    segment = Segment(
        sequence=0, stream_id=0, duration=SEGMENT_DURATION, start_time=FAKE_TIME
    )
    hls.put(segment)

    # Produce enough segments to overfill the output buffer by one
    for sequence in range(1, MAX_SEGMENTS + 2):
        hls.put(
            Segment(
                sequence,
                INIT_BYTES,
                MOOF_BYTES,
                SEGMENT_DURATION,
                stream_id=1,
                start_time=FAKE_TIME,
            )
    for sequence in range(MAX_SEGMENTS + 1):
        segment = Segment(
            sequence=sequence,
            stream_id=1,
            duration=SEGMENT_DURATION,
            start_time=FAKE_TIME,
        )
        hls.put(segment)
        await hass.async_block_till_done()

    resp = await hls_client.get("/playlist.m3u8")
@@ -432,9 +417,9 @@ async def test_hls_max_segments_discontinuity(hass, hls_stream, stream_worker_sy
    # Only NUM_PLAYLIST_SEGMENTS are returned in the playlist causing the
    # EXT-X-DISCONTINUITY tag to be omitted and EXT-X-DISCONTINUITY-SEQUENCE
    # returned instead.
    start = MAX_SEGMENTS + 2 - NUM_PLAYLIST_SEGMENTS
    start = MAX_SEGMENTS + 1 - NUM_PLAYLIST_SEGMENTS
    segments = []
    for sequence in range(start, MAX_SEGMENTS + 2):
    for sequence in range(start, MAX_SEGMENTS + 1):
        segments.append(make_segment(sequence))
    assert await resp.text() == make_playlist(
        sequence=start,
@@ -1,23 +1,16 @@
"""The tests for hls streams."""
from __future__ import annotations

import asyncio
from collections import deque
from datetime import timedelta
from io import BytesIO
import logging
import os
import threading
from unittest.mock import patch

import async_timeout
import av
import pytest

from homeassistant.components.stream import create_stream
from homeassistant.components.stream.const import HLS_PROVIDER, RECORDER_PROVIDER
from homeassistant.components.stream.core import Segment
from homeassistant.components.stream.fmp4utils import get_init_and_moof_data
from homeassistant.components.stream.core import Part, Segment
from homeassistant.components.stream.fmp4utils import find_box
from homeassistant.components.stream.recorder import recorder_save_worker
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@@ -26,63 +19,9 @@ import homeassistant.util.dt as dt_util
from tests.common import async_fire_time_changed
from tests.components.stream.common import generate_h264_video

TEST_TIMEOUT = 7.0 # Lower than 9s home assistant timeout
MAX_ABORT_SEGMENTS = 20 # Abort test to avoid looping forever


class SaveRecordWorkerSync:
    """
    Test fixture to manage RecordOutput thread for recorder_save_worker.

    This is used to assert that the worker is started and stopped cleanly
    to avoid thread leaks in tests.
    """

    def __init__(self):
        """Initialize SaveRecordWorkerSync."""
        self.reset()
        self._segments = None
        self._save_thread = None

    def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
        """Mock method for patch."""
        logging.debug("recorder_save_worker thread started")
        assert self._save_thread is None
        self._segments = segments
        self._save_thread = threading.current_thread()
        self._save_event.set()

    async def get_segments(self):
        """Return the recorded video segments."""
        with async_timeout.timeout(TEST_TIMEOUT):
            await self._save_event.wait()
        return self._segments

    async def join(self):
        """Verify save worker was invoked and block on shutdown."""
        with async_timeout.timeout(TEST_TIMEOUT):
            await self._save_event.wait()
        self._save_thread.join(timeout=TEST_TIMEOUT)
        assert not self._save_thread.is_alive()

    def reset(self):
        """Reset callback state for reuse in tests."""
        self._save_thread = None
        self._save_event = asyncio.Event()


@pytest.fixture()
def record_worker_sync(hass):
    """Patch recorder_save_worker for clean thread shutdown for test."""
    sync = SaveRecordWorkerSync()
    with patch(
        "homeassistant.components.stream.recorder.recorder_save_worker",
        side_effect=sync.recorder_save_worker,
        autospec=True,
    ):
        yield sync


async def test_record_stream(hass, hass_client, record_worker_sync):
    """
    Test record stream.
@@ -179,6 +118,21 @@ async def test_record_path_not_allowed(hass, hass_client):
        await stream.async_record("/example/path")


def add_parts_to_segment(segment, source):
    """Add relevant part data to segment for testing recorder."""
    moof_locs = list(find_box(source.getbuffer(), b"moof")) + [len(source.getbuffer())]
    segment.init = source.getbuffer()[: moof_locs[0]].tobytes()
    segment.parts = [
        Part(
            duration=None,
            has_keyframe=None,
            http_range_start=None,
            data=source.getbuffer()[moof_locs[i] : moof_locs[i + 1]],
        )
        for i in range(1, len(moof_locs) - 1)
    ]


async def test_recorder_save(tmpdir):
    """Test recorder save."""
    # Setup
@@ -186,9 +140,10 @@ async def test_recorder_save(tmpdir):
    filename = f"{tmpdir}/test.mp4"

    # Run
    recorder_save_worker(
        filename, [Segment(1, *get_init_and_moof_data(source.getbuffer()), 4)]
    )
    segment = Segment(sequence=1)
    add_parts_to_segment(segment, source)
    segment.duration = 4
    recorder_save_worker(filename, [segment])

    # Assert
    assert os.path.exists(filename)
@@ -201,15 +156,13 @@ async def test_recorder_discontinuity(tmpdir):
    filename = f"{tmpdir}/test.mp4"

    # Run
    init, moof_data = get_init_and_moof_data(source.getbuffer())
    recorder_save_worker(
        filename,
        [
            Segment(1, init, moof_data, 4, 0),
            Segment(2, init, moof_data, 4, 1),
        ],
    )

    segment_1 = Segment(sequence=1, stream_id=0)
    add_parts_to_segment(segment_1, source)
    segment_1.duration = 4
    segment_2 = Segment(sequence=2, stream_id=1)
    add_parts_to_segment(segment_2, source)
    segment_2.duration = 4
    recorder_save_worker(filename, [segment_1, segment_2])
    # Assert
    assert os.path.exists(filename)
@@ -263,7 +216,9 @@ async def test_record_stream_audio(
    stream_worker_sync.resume()

    result = av.open(
        BytesIO(last_segment.init + last_segment.moof_data), "r", format="mp4"
        BytesIO(last_segment.init + last_segment.get_bytes_without_init()),
        "r",
        format="mp4",
    )

    assert len(result.streams.audio) == expected_audio_streams
@@ -21,7 +21,7 @@ from unittest.mock import patch

import av

from homeassistant.components.stream import Stream
from homeassistant.components.stream import Stream, create_stream
from homeassistant.components.stream.const import (
    HLS_PROVIDER,
    MAX_MISSING_DTS,
@@ -29,6 +29,9 @@ from homeassistant.components.stream.const import (
    TARGET_SEGMENT_DURATION,
)
from homeassistant.components.stream.worker import SegmentBuffer, stream_worker
from homeassistant.setup import async_setup_component

from tests.components.stream.common import generate_h264_video

STREAM_SOURCE = "some-stream-source"
# Formats here are arbitrary, not exercised by tests
@@ -99,9 +102,9 @@ class PacketSequence:
                super().__init__(3)

            time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
            dts = self.packet * PACKET_DURATION / time_base
            pts = self.packet * PACKET_DURATION / time_base
            duration = PACKET_DURATION / time_base
            dts = int(self.packet * PACKET_DURATION / time_base)
            pts = int(self.packet * PACKET_DURATION / time_base)
            duration = int(PACKET_DURATION / time_base)
            stream = VIDEO_STREAM
            # Pretend we get 1 keyframe every second
            is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
@@ -177,6 +180,11 @@ class FakePyAvBuffer:
        """Capture the output segment for tests to inspect."""
        self.segments.append(segment)

    @property
    def complete_segments(self):
        """Return only the complete segments."""
        return [segment for segment in self.segments if segment.complete]


class MockPyAv:
    """Mocks out av.open."""
@@ -197,6 +205,19 @@ class MockPyAv:
        return self.container


class MockFlushPart:
    """Class to hold a wrapper function for check_flush_part."""

    # Wrap this method with a preceding write so the BytesIO pointer moves
    check_flush_part = SegmentBuffer.check_flush_part

    @classmethod
    def wrapped_check_flush_part(cls, segment_buffer, packet):
        """Wrap check_flush_part to also advance the memory_file pointer."""
        segment_buffer._memory_file.write(b"0")
        return cls.check_flush_part(segment_buffer, packet)


async def async_decode_stream(hass, packets, py_av=None):
    """Start a stream worker that decodes incoming stream packets into output segments."""
    stream = Stream(hass, STREAM_SOURCE)
@@ -209,6 +230,10 @@ async def async_decode_stream(hass, packets, py_av=None):
    with patch("av.open", new=py_av.open), patch(
        "homeassistant.components.stream.core.StreamOutput.put",
        side_effect=py_av.capture_buffer.capture_output_segment,
    ), patch(
        "homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
        side_effect=MockFlushPart.wrapped_check_flush_part,
        autospec=True,
    ):
        segment_buffer = SegmentBuffer(stream.outputs)
        stream_worker(STREAM_SOURCE, {}, segment_buffer, threading.Event())
@@ -235,13 +260,16 @@ async def test_stream_worker_success(hass):
        hass, PacketSequence(TEST_SEQUENCE_LENGTH)
    )
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check number of segments. A segment is only formed when a packet from the next
    # segment arrives, hence the subtraction of one from the sequence length.
    assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int(
        (TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
    )
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations
    assert all(s.duration == SEGMENT_DURATION for s in segments)
    assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert len(decoded_stream.video_packets) == TEST_SEQUENCE_LENGTH
    assert len(decoded_stream.audio_packets) == 0
@@ -259,6 +287,7 @@ async def test_skip_out_of_order_packet(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # If skipped packet would have been the first packet of a segment, the previous
@@ -273,12 +302,14 @@ async def test_skip_out_of_order_packet(hass):
        )
        del segments[longer_segment_index]
        # Check number of segments
        assert len(segments) == int((len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1)
        assert len(complete_segments) == int(
            (len(packets) - 1 - 1) * SEGMENTS_PER_PACKET - 1
        )
    else: # Otherwise segment durations and number of segments are unaffected
        # Check number of segments
        assert len(segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET)
        assert len(complete_segments) == int((len(packets) - 1) * SEGMENTS_PER_PACKET)
        # Check remaining segment durations
        assert all(s.duration == SEGMENT_DURATION for s in segments)
        assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert len(decoded_stream.video_packets) == len(packets) - 1
    assert len(decoded_stream.audio_packets) == 0
@@ -292,12 +323,15 @@ async def test_discard_old_packets(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check number of segments
    assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int(
        (OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
    )
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations
    assert all(s.duration == SEGMENT_DURATION for s in segments)
    assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
    assert len(decoded_stream.audio_packets) == 0
@@ -311,12 +345,15 @@ async def test_packet_overflow(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check number of segments
    assert len(segments) == int((OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int(
        (OUT_OF_ORDER_PACKET_INDEX - 1) * SEGMENTS_PER_PACKET
    )
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations
    assert all(s.duration == SEGMENT_DURATION for s in segments)
    assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert len(decoded_stream.video_packets) == OUT_OF_ORDER_PACKET_INDEX
    assert len(decoded_stream.audio_packets) == 0
@@ -332,10 +369,11 @@ async def test_skip_initial_bad_packets(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations
    assert all(s.duration == SEGMENT_DURATION for s in segments)
    assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert (
        len(decoded_stream.video_packets)
        == num_packets
@@ -344,7 +382,7 @@ async def test_skip_initial_bad_packets(hass):
        * KEYFRAME_INTERVAL
    )
    # Check number of segments
    assert len(segments) == int(
    assert len(complete_segments) == int(
        (len(decoded_stream.video_packets) - 1) * SEGMENTS_PER_PACKET
    )
    assert len(decoded_stream.audio_packets) == 0
@@ -381,13 +419,11 @@ async def test_skip_missing_dts(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations (not counting the last segment)
    assert (
        sum([segments[i].duration == SEGMENT_DURATION for i in range(len(segments))])
        >= len(segments) - 1
    )
    assert sum(segment.duration for segment in complete_segments) >= len(segments) - 1
    assert len(decoded_stream.video_packets) == num_packets - num_bad_packets
    assert len(decoded_stream.audio_packets) == 0
@@ -403,8 +439,8 @@ async def test_too_many_bad_packets(hass):
        packets[i].dts = None

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    assert len(segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
    complete_segments = decoded_stream.complete_segments
    assert len(complete_segments) == int((bad_packet_start - 1) * SEGMENTS_PER_PACKET)
    assert len(decoded_stream.video_packets) == bad_packet_start
    assert len(decoded_stream.audio_packets) == 0
@@ -431,8 +467,8 @@ async def test_audio_packets_not_found(hass):
    packets = PacketSequence(num_packets) # Contains only video packets

    decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
    segments = decoded_stream.segments
    assert len(segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
    complete_segments = decoded_stream.complete_segments
    assert len(complete_segments) == int((num_packets - 1) * SEGMENTS_PER_PACKET)
    assert len(decoded_stream.video_packets) == num_packets
    assert len(decoded_stream.audio_packets) == 0
@@ -444,8 +480,8 @@ async def test_adts_aac_audio(hass):
    num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
    packets = list(PacketSequence(num_packets))
    packets[1].stream = AUDIO_STREAM
    packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    # The following is packet data is a sign of ADTS AAC
    packets[1][0] = 255
    packets[1][1] = 241
@@ -462,17 +498,17 @@ async def test_audio_is_first_packet(hass):
    packets = list(PacketSequence(num_packets))
    # Pair up an audio packet for each video packet
    packets[0].stream = AUDIO_STREAM
    packets[0].dts = packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[0].pts = packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
    packets[2].stream = AUDIO_STREAM
    packets[2].dts = packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[2].pts = packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)

    decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # The audio packets are segmented with the video packets
    assert len(segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int((num_packets - 2 - 1) * SEGMENTS_PER_PACKET)
    assert len(decoded_stream.video_packets) == num_packets - 2
    assert len(decoded_stream.audio_packets) == 1
@@ -484,13 +520,13 @@ async def test_audio_packets_found(hass):
    num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
    packets = list(PacketSequence(num_packets))
    packets[1].stream = AUDIO_STREAM
    packets[1].dts = packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[1].pts = packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE
    packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
    packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)

    decoded_stream = await async_decode_stream(hass, iter(packets), py_av=py_av)
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # The audio packet above is buffered with the video packet
    assert len(segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int((num_packets - 1 - 1) * SEGMENTS_PER_PACKET)
    assert len(decoded_stream.video_packets) == num_packets - 1
    assert len(decoded_stream.audio_packets) == 1
@@ -507,12 +543,15 @@ async def test_pts_out_of_order(hass):

    decoded_stream = await async_decode_stream(hass, iter(packets))
    segments = decoded_stream.segments
    complete_segments = decoded_stream.complete_segments
    # Check number of segments
    assert len(segments) == int((TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET)
    assert len(complete_segments) == int(
        (TEST_SEQUENCE_LENGTH - 1) * SEGMENTS_PER_PACKET
    )
    # Check sequence numbers
    assert all(segments[i].sequence == i for i in range(len(segments)))
    # Check segment durations
    assert all(s.duration == SEGMENT_DURATION for s in segments)
    assert all(s.duration == SEGMENT_DURATION for s in complete_segments)
    assert len(decoded_stream.video_packets) == len(packets)
    assert len(decoded_stream.audio_packets) == 0
@@ -573,7 +612,11 @@ async def test_update_stream_source(hass):
        worker_wake.wait()
        return py_av.open(stream_source, args, kwargs)

    with patch("av.open", new=blocking_open):
    with patch("av.open", new=blocking_open), patch(
        "homeassistant.components.stream.worker.SegmentBuffer.check_flush_part",
        side_effect=MockFlushPart.wrapped_check_flush_part,
        autospec=True,
    ):
        stream.start()
        assert worker_open.wait(TIMEOUT)
        assert last_stream_source == STREAM_SOURCE
@@ -604,3 +647,74 @@ async def test_worker_log(hass, caplog):
        await hass.async_block_till_done()
    assert "https://abcd:efgh@foo.bar" not in caplog.text
    assert "https://****:****@foo.bar" in caplog.text


async def test_durations(hass, record_worker_sync):
    """Test that the duration metadata matches the media."""
    await async_setup_component(hass, "stream", {"stream": {}})

    source = generate_h264_video()
    stream = create_stream(hass, source)

    # use record_worker_sync to grab output segments
    with patch.object(hass.config, "is_allowed_path", return_value=True):
        await stream.async_record("/example/path")

    complete_segments = list(await record_worker_sync.get_segments())[:-1]
    assert len(complete_segments) >= 1

    # check that the Part duration metadata matches the durations in the media
    running_metadata_duration = 0
    for segment in complete_segments:
        for part in segment.parts:
            av_part = av.open(io.BytesIO(segment.init + part.data))
            running_metadata_duration += part.duration
            # av_part.duration will just return the largest dts in av_part.
            # When we normalize by av.time_base this should equal the running duration
            assert math.isclose(
                running_metadata_duration,
                av_part.duration / av.time_base,
                abs_tol=1e-6,
            )
            av_part.close()
    # check that the Part durations are consistent with the Segment durations
    for segment in complete_segments:
        assert math.isclose(
            sum(part.duration for part in segment.parts), segment.duration, abs_tol=1e-6
        )

    await record_worker_sync.join()

    stream.stop()


async def test_has_keyframe(hass, record_worker_sync):
    """Test that the has_keyframe metadata matches the media."""
    await async_setup_component(hass, "stream", {"stream": {}})

    source = generate_h264_video()
    stream = create_stream(hass, source)

    # use record_worker_sync to grab output segments
    with patch.object(hass.config, "is_allowed_path", return_value=True):
        await stream.async_record("/example/path")

    # Our test video has keyframes every second. Use smaller parts so we have more
    # part boundaries to better test keyframe logic.
    with patch("homeassistant.components.stream.worker.TARGET_PART_DURATION", 0.25):
        complete_segments = list(await record_worker_sync.get_segments())[:-1]
        assert len(complete_segments) >= 1

    # check that the Part has_keyframe metadata matches the keyframes in the media
    for segment in complete_segments:
        for part in segment.parts:
            av_part = av.open(io.BytesIO(segment.init + part.data))
            media_has_keyframe = any(
                packet.is_keyframe for packet in av_part.demux(av_part.streams.video[0])
            )
            av_part.close()
            assert part.has_keyframe == media_has_keyframe

    await record_worker_sync.join()

    stream.stop()