Refactor decompression timestamp validation logic in stream component (#52462)
* Refactor dts validation logic into a separate function Create a decompression timestamp validation function to move the logic out of the worker into a separate class. This also uses the python itertools.chain to chain together the initial packets with the remaining packets in the container iterator, removing additional inline if statements. * Reset dts validator when container is reset * Fix typo in a comment * Reuse existing dts_validator when disabling audio streampull/52684/head
parent
02d8d25d1d
commit
e895b6cd42
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
from collections import deque
|
||||
from collections.abc import Iterator, Mapping
|
||||
from io import BytesIO
|
||||
import itertools
|
||||
import logging
|
||||
from threading import Event
|
||||
from typing import Any, Callable, cast
|
||||
|
@ -201,7 +202,41 @@ class SegmentBuffer:
|
|||
self._memory_file.close()
|
||||
|
||||
|
||||
def stream_worker( # noqa: C901
|
||||
class TimestampValidator:
|
||||
"""Validate ordering of timestamps for packets in a stream."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the TimestampValidator."""
|
||||
# Decompression timestamp of last packet in each stream
|
||||
self._last_dts: dict[av.stream.Stream, float] = {}
|
||||
# Number of consecutive missing decompression timestamps
|
||||
self._missing_dts = 0
|
||||
|
||||
def is_valid(self, packet: av.Packet) -> float:
|
||||
"""Validate the packet timestamp based on ordering within the stream."""
|
||||
# Discard packets missing DTS. Terminate if too many are missing.
|
||||
if packet.dts is None:
|
||||
if self._missing_dts >= MAX_MISSING_DTS:
|
||||
raise StopIteration(
|
||||
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
|
||||
)
|
||||
self._missing_dts += 1
|
||||
return False
|
||||
self._missing_dts = 0
|
||||
# Discard when dts is not monotonic. Terminate if gap is too wide.
|
||||
prev_dts = self._last_dts.get(packet.stream, float("-inf"))
|
||||
if packet.dts <= prev_dts:
|
||||
gap = packet.time_base * (prev_dts - packet.dts)
|
||||
if gap > MAX_TIMESTAMP_GAP:
|
||||
raise StopIteration(
|
||||
f"Timestamp overflow detected: last dts = {prev_dts}, dts = {packet.dts}"
|
||||
)
|
||||
return False
|
||||
self._last_dts[packet.stream] = packet.dts
|
||||
return True
|
||||
|
||||
|
||||
def stream_worker(
|
||||
source: str,
|
||||
options: dict[str, str],
|
||||
segment_buffer: SegmentBuffer,
|
||||
|
@ -234,10 +269,6 @@ def stream_worker( # noqa: C901
|
|||
|
||||
# Iterator for demuxing
|
||||
container_packets: Iterator[av.Packet]
|
||||
# The decoder timestamps of the latest packet in each stream we processed
|
||||
last_dts = {video_stream: float("-inf"), audio_stream: float("-inf")}
|
||||
# Keep track of consecutive packets without a dts to detect end of stream.
|
||||
missing_dts = 0
|
||||
# The video dts at the beginning of the segment
|
||||
segment_start_dts: int | None = None
|
||||
# Because of problems 1 and 2 below, we need to store the first few packets and replay them
|
||||
|
@ -254,23 +285,17 @@ def stream_worker( # noqa: C901
|
|||
Also load the first video keyframe dts into segment_start_dts and check if the audio stream really exists.
|
||||
"""
|
||||
nonlocal segment_start_dts, audio_stream, container_packets
|
||||
missing_dts = 0
|
||||
found_audio = False
|
||||
try:
|
||||
container_packets = container.demux((video_stream, audio_stream))
|
||||
# Ensure packets are ordered correctly
|
||||
dts_validator = TimestampValidator()
|
||||
container_packets = filter(
|
||||
dts_validator.is_valid, container.demux((video_stream, audio_stream))
|
||||
)
|
||||
first_packet: av.Packet | None = None
|
||||
# Get to first video keyframe
|
||||
while first_packet is None:
|
||||
packet = next(container_packets)
|
||||
if (
|
||||
packet.dts is None
|
||||
): # Allow MAX_MISSING_DTS packets with no dts, raise error on the next one
|
||||
if missing_dts >= MAX_MISSING_DTS:
|
||||
raise StopIteration(
|
||||
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
|
||||
)
|
||||
missing_dts += 1
|
||||
continue
|
||||
if packet.stream == audio_stream:
|
||||
found_audio = True
|
||||
elif packet.is_keyframe: # video_keyframe
|
||||
|
@ -283,15 +308,6 @@ def stream_worker( # noqa: C901
|
|||
and len(initial_packets) < PACKETS_TO_WAIT_FOR_AUDIO
|
||||
):
|
||||
packet = next(container_packets)
|
||||
if (
|
||||
packet.dts is None
|
||||
): # Allow MAX_MISSING_DTS packet with no dts, raise error on the next one
|
||||
if missing_dts >= MAX_MISSING_DTS:
|
||||
raise StopIteration(
|
||||
f"Invalid data - got {MAX_MISSING_DTS+1} packets with missing DTS while initializing"
|
||||
)
|
||||
missing_dts += 1
|
||||
continue
|
||||
if packet.stream == audio_stream:
|
||||
# detect ADTS AAC and disable audio
|
||||
if audio_stream.codec.name == "aac" and packet.size > 2:
|
||||
|
@ -300,7 +316,10 @@ def stream_worker( # noqa: C901
|
|||
_LOGGER.warning(
|
||||
"ADTS AAC detected - disabling audio stream"
|
||||
)
|
||||
container_packets = container.demux(video_stream)
|
||||
container_packets = filter(
|
||||
dts_validator.is_valid,
|
||||
container.demux(video_stream),
|
||||
)
|
||||
audio_stream = None
|
||||
continue
|
||||
found_audio = True
|
||||
|
@ -330,42 +349,16 @@ def stream_worker( # noqa: C901
|
|||
assert isinstance(segment_start_dts, int)
|
||||
segment_buffer.reset(segment_start_dts)
|
||||
|
||||
# Rewind the stream and iterate over the initial set of packets again
|
||||
# filtering out any packets with timestamp ordering issues.
|
||||
packets = itertools.chain(initial_packets, container_packets)
|
||||
while not quit_event.is_set():
|
||||
try:
|
||||
if len(initial_packets) > 0:
|
||||
packet = initial_packets.popleft()
|
||||
else:
|
||||
packet = next(container_packets)
|
||||
if packet.dts is None:
|
||||
# Allow MAX_MISSING_DTS consecutive packets without dts. Terminate the stream on the next one.
|
||||
if missing_dts >= MAX_MISSING_DTS:
|
||||
raise StopIteration(
|
||||
f"No dts in {MAX_MISSING_DTS+1} consecutive packets"
|
||||
)
|
||||
missing_dts += 1
|
||||
continue
|
||||
missing_dts = 0
|
||||
packet = next(packets)
|
||||
except (av.AVError, StopIteration) as ex:
|
||||
_LOGGER.error("Error demuxing stream: %s", str(ex))
|
||||
break
|
||||
|
||||
# Discard packet if dts is not monotonic
|
||||
if packet.dts <= last_dts[packet.stream]:
|
||||
if (
|
||||
packet.time_base * (last_dts[packet.stream] - packet.dts)
|
||||
> MAX_TIMESTAMP_GAP
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"Timestamp overflow detected: last dts %s, dts = %s, resetting stream",
|
||||
last_dts[packet.stream],
|
||||
packet.dts,
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
# Update last_dts processed
|
||||
last_dts[packet.stream] = packet.dts
|
||||
|
||||
# Mux packets, and possibly write a segment to the output stream.
|
||||
# This mutates packet timestamps and stream
|
||||
segment_buffer.mux_packet(packet)
|
||||
|
|
Loading…
Reference in New Issue