From e895b6cd428f02d5afb0e993389c2bb601ca8f92 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 7 Jul 2021 15:29:15 -0700 Subject: [PATCH] Refactor decompression timestamp validation logic in stream component (#52462) * Refactor dts validation logic into a separate function Create a decompression timestamp validation function to move the logic out of the worker into a separate class. This also uses the python itertools.chain to chain together the initial packets with the remaining packets in the container iterator, removing additional inline if statements. * Reset dts validator when container is reset * Fix typo in a comment * Reuse existing dts_validator when disabling audio stream --- homeassistant/components/stream/worker.py | 105 ++++++++++------------ 1 file changed, 49 insertions(+), 56 deletions(-) diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index 04be79e668e..ed1e1b9551d 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections import deque from collections.abc import Iterator, Mapping from io import BytesIO +import itertools import logging from threading import Event from typing import Any, Callable, cast @@ -201,7 +202,41 @@ class SegmentBuffer: self._memory_file.close() -def stream_worker( # noqa: C901 +class TimestampValidator: + """Validate ordering of timestamps for packets in a stream.""" + + def __init__(self) -> None: + """Initialize the TimestampValidator.""" + # Decompression timestamp of last packet in each stream + self._last_dts: dict[av.stream.Stream, float] = {} + # Number of consecutive missing decompression timestamps + self._missing_dts = 0 + + def is_valid(self, packet: av.Packet) -> float: + """Validate the packet timestamp based on ordering within the stream.""" + # Discard packets missing DTS. Terminate if too many are missing. + if packet.dts is None: + if self._missing_dts >= MAX_MISSING_DTS: + raise StopIteration( + f"No dts in {MAX_MISSING_DTS+1} consecutive packets" + ) + self._missing_dts += 1 + return False + self._missing_dts = 0 + # Discard when dts is not monotonic. Terminate if gap is too wide. + prev_dts = self._last_dts.get(packet.stream, float("-inf")) + if packet.dts <= prev_dts: + gap = packet.time_base * (prev_dts - packet.dts) + if gap > MAX_TIMESTAMP_GAP: + raise StopIteration( + f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}" + ) + return False + self._last_dts[packet.stream] = packet.dts + return True + + +def stream_worker( source: str, options: dict[str, str], segment_buffer: SegmentBuffer, @@ -234,10 +269,6 @@ def stream_worker( # noqa: C901 # Iterator for demuxing container_packets: Iterator[av.Packet] - # The decoder timestamps of the latest packet in each stream we processed - last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")} - # Keep track of consecutive packets without a dts to detect end of stream. - missing_dts = 0 # The video dts at the beginning of the segment segment_start_dts: int | None = None # Because of problems 1 and 2 below, we need to store the first few packets and replay them @@ -254,23 +285,17 @@ def stream_worker( # noqa: C901 Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists. """ nonlocal segment_start_dts, audio_stream, container_packets - missing_dts = 0 found_audio = False try: - container_packets = container.demux((video_stream, audio_stream)) + # Ensure packets are ordered correctly + dts_validator = TimestampValidator() + container_packets = filter( + dts_validator.is_valid, container.demux((video_stream, audio_stream)) + ) first_packet: av.Packet | None = None # Get to first video keyframe while first_packet is None: packet = next(container_packets) - if ( - packet.dts is None - ): # Allow MAX_MISSING_DTS packets with no dts, raise error on the next one - if missing_dts >= MAX_MISSING_DTS: - raise StopIteration( - f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing" - ) - missing_dts += 1 - continue if packet.stream == audio_stream: found_audio = True elif packet.is_keyframe: # video_keyframe @@ -283,15 +308,6 @@ def stream_worker( # noqa: C901 and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO ): packet = next(container_packets) - if ( - packet.dts is None - ): # Allow MAX_MISSING_DTS packet with no dts, raise error on the next one - if missing_dts >= MAX_MISSING_DTS: - raise StopIteration( - f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing" - ) - missing_dts += 1 - continue if packet.stream == audio_stream: # detect ADTS AAC and disable audio if audio_stream.codec.name == "aac" and packet.size > 2: @@ -300,7 +316,10 @@ def stream_worker( # noqa: C901 _LOGGER.warning( "ADTS AAC detected - disabling audio stream" ) - container_packets = container.demux(video_stream) + container_packets = filter( + dts_validator.is_valid, + container.demux(video_stream), + ) audio_stream = None continue found_audio = True @@ -330,42 +349,16 @@ def stream_worker( # noqa: C901 assert isinstance(segment_start_dts, int) segment_buffer.reset(segment_start_dts) + # Rewind the stream and iterate over the initial set of packets again + # filtering out any packets with timestamp ordering issues. + packets = itertools.chain(initial_packets, container_packets) while not quit_event.is_set(): try: - if len(initial_packets) > 0: - packet = initial_packets.popleft() - else: - packet = next(container_packets) - if packet.dts is None: - # Allow MAX_MISSING_DTS consecutive packets without dts. Terminate the stream on the next one. - if missing_dts >= MAX_MISSING_DTS: - raise StopIteration( - f"No dts in {MAX_MISSING_DTS+1} consecutive packets" - ) - missing_dts += 1 - continue - missing_dts = 0 + packet = next(packets) except (av.AVError, StopIteration) as ex: _LOGGER.error("Error demuxing stream: %s", str(ex)) break - # Discard packet if dts is not monotonic - if packet.dts <= last_dts[packet.stream]: - if ( - packet.time_base * (last_dts[packet.stream] - packet.dts) - > MAX_TIMESTAMP_GAP - ): - _LOGGER.warning( - "Timestamp overflow detected: last dts %s, dts = %s, resetting stream", - last_dts[packet.stream], - packet.dts, - ) - break - continue - - # Update last_dts processed - last_dts[packet.stream] = packet.dts - # Mux packets, and possibly write a segment to the output stream. # This mutates packet timestamps and stream segment_buffer.mux_packet(packet)