Refactor decompression timestamp validation logic in stream component (#52462)

* Refactor dts validation logic into a separate function

Create a decompression timestamp (dts) validator to move the logic out of
the worker into a separate class. This also uses Python's `itertools.chain`
to chain together the initial packets with the remaining packets in the
container iterator, removing additional inline if statements.

* Reset dts validator when container is reset

* Fix typo in a comment

* Reuse existing dts_validator when disabling audio stream
pull/52684/head
Allen Porter 2021-07-07 15:29:15 -07:00 committed by GitHub
parent 02d8d25d1d
commit e895b6cd42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 49 additions and 56 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections import deque
from collections.abc import Iterator, Mapping
from io import BytesIO
import itertools
import logging
from threading import Event
from typing import Any, Callable, cast
@ -201,7 +202,41 @@ class SegmentBuffer:
self._memory_file.close()
def stream_worker( # noqa: C901
class TimestampValidator:
    """Validate ordering of timestamps for packets in a stream.

    The validator is stateful: it remembers the last accepted decompression
    timestamp (dts) of every stream it has seen, plus the number of
    consecutive packets that arrived without a dts. ``is_valid`` is meant to
    be used as a predicate (e.g. with ``filter``) over demuxed packets.
    """

    def __init__(self) -> None:
        """Initialize the TimestampValidator."""
        # Decompression timestamp of last packet in each stream
        self._last_dts: dict[av.stream.Stream, float] = {}
        # Number of consecutive missing decompression timestamps
        self._missing_dts = 0

    def is_valid(self, packet: av.Packet) -> bool:
        """Validate the packet timestamp based on ordering within the stream.

        Returns True when the packet has a dts that is strictly greater than
        the last accepted dts of the same stream, and False when the packet
        should be discarded (missing dts, or dts not increasing).

        Raises StopIteration when the stream should be terminated: either
        MAX_MISSING_DTS consecutive packets without a dts were already
        discarded before this one, or the dts moved backwards by more than
        MAX_TIMESTAMP_GAP. StopIteration is used so that an iterator built
        from this predicate simply ends for its consumer.
        """
        # Discard packets missing DTS. Terminate if too many are missing.
        if packet.dts is None:
            if self._missing_dts >= MAX_MISSING_DTS:
                raise StopIteration(
                    f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
                )
            self._missing_dts += 1
            return False
        # A packet with a dts breaks any run of missing timestamps.
        self._missing_dts = 0
        # Discard when dts is not monotonic. Terminate if gap is too wide.
        # Streams not seen before compare against -inf, so their first packet
        # with a dts is always accepted.
        prev_dts = self._last_dts.get(packet.stream, float("-inf"))
        if packet.dts <= prev_dts:
            # Backwards distance in dts ticks, scaled by the packet time_base.
            gap = packet.time_base * (prev_dts - packet.dts)
            if gap > MAX_TIMESTAMP_GAP:
                raise StopIteration(
                    f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
                )
            return False
        # Only accepted packets advance the per-stream high-water mark.
        self._last_dts[packet.stream] = packet.dts
        return True
def stream_worker(
source: str,
options: dict[str, str],
segment_buffer: SegmentBuffer,
@ -234,10 +269,6 @@ def stream_worker( # noqa: C901
# Iterator for demuxing
container_packets: Iterator[av.Packet]
# The decoder timestamps of the latest packet in each stream we processed
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
# Keep track of consecutive packets without a dts to detect end of stream.
missing_dts = 0
# The video dts at the beginning of the segment
segment_start_dts: int | None = None
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
@ -254,23 +285,17 @@ def stream_worker( # noqa: C901
Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
"""
nonlocal segment_start_dts, audio_stream, container_packets
missing_dts = 0
found_audio = False
try:
container_packets = container.demux((video_stream, audio_stream))
# Ensure packets are ordered correctly
dts_validator = TimestampValidator()
container_packets = filter(
dts_validator.is_valid, container.demux((video_stream, audio_stream))
)
first_packet: av.Packet | None = None
# Get to first video keyframe
while first_packet is None:
packet = next(container_packets)
if (
packet.dts is None
): # Allow MAX_MISSING_DTS packets with no dts, raise error on the next one
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
)
missing_dts += 1
continue
if packet.stream == audio_stream:
found_audio = True
elif packet.is_keyframe: # video_keyframe
@ -283,15 +308,6 @@ def stream_worker( # noqa: C901
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
):
packet = next(container_packets)
if (
packet.dts is None
): # Allow MAX_MISSING_DTS packet with no dts, raise error on the next one
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
)
missing_dts += 1
continue
if packet.stream == audio_stream:
# detect ADTS AAC and disable audio
if audio_stream.codec.name == "aac" and packet.size > 2:
@ -300,7 +316,10 @@ def stream_worker( # noqa: C901
_LOGGER.warning(
"ADTS AAC detected - disabling audio stream"
)
container_packets = container.demux(video_stream)
container_packets = filter(
dts_validator.is_valid,
container.demux(video_stream),
)
audio_stream = None
continue
found_audio = True
@ -330,42 +349,16 @@ def stream_worker( # noqa: C901
assert isinstance(segment_start_dts, int)
segment_buffer.reset(segment_start_dts)
# Rewind the stream and iterate over the initial set of packets again
# filtering out any packets with timestamp ordering issues.
packets = itertools.chain(initial_packets, container_packets)
while not quit_event.is_set():
try:
if len(initial_packets) > 0:
packet = initial_packets.popleft()
else:
packet = next(container_packets)
if packet.dts is None:
# Allow MAX_MISSING_DTS consecutive packets without dts. Terminate the stream on the next one.
if missing_dts >= MAX_MISSING_DTS:
raise StopIteration(
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
)
missing_dts += 1
continue
missing_dts = 0
packet = next(packets)
except (av.AVError, StopIteration) as ex:
_LOGGER.error("Error demuxing stream: %s", str(ex))
break
# Discard packet if dts is not monotonic
if packet.dts <= last_dts[packet.stream]:
if (
packet.time_base * (last_dts[packet.stream] - packet.dts)
> MAX_TIMESTAMP_GAP
):
_LOGGER.warning(
"Timestamp overflow detected: last dts %s, dts = %s, resetting stream",
last_dts[packet.stream],
packet.dts,
)
break
continue
# Update last_dts processed
last_dts[packet.stream] = packet.dts
# Mux packets, and possibly write a segment to the output stream.
# This mutates packet timestamps and stream
segment_buffer.mux_packet(packet)