Detect timestamp discontinuity in stream (#86430)

fixes undefined
pull/86446/head
uvjustin 2023-01-23 20:09:46 +11:00 committed by GitHub
parent ab76b3ffb3
commit d0153f5031
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 26 deletions

View File

@ -31,7 +31,7 @@ EXT_X_START_LL_HLS = 2
PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio
MAX_TIMESTAMP_GAP = 10000 # seconds - anything from 10 to 50000 is probably reasonable
MAX_TIMESTAMP_GAP = 30 # seconds - anything from 10 to 50000 is probably reasonable
MAX_MISSING_DTS = 6 # Number of packets missing DTS to allow
SOURCE_TIMEOUT = 30 # Timeout for reading stream source

View File

@ -38,6 +38,7 @@ from .fmp4utils import read_init
from .hls import HlsStreamOutput
_LOGGER = logging.getLogger(__name__)
NEGATIVE_INF = float("-inf")
class StreamWorkerError(Exception):
@ -416,14 +417,15 @@ class PeekIterator(Iterator):
class TimestampValidator:
"""Validate ordering of timestamps for packets in a stream."""
def __init__(self) -> None:
def __init__(self, inv_video_time_base: int) -> None:
"""Initialize the TimestampValidator."""
# Decompression timestamp of last packet in each stream
self._last_dts: dict[av.stream.Stream, int | float] = defaultdict(
lambda: float("-inf")
lambda: NEGATIVE_INF
)
# Number of consecutive missing decompression timestamps
self._missing_dts = 0
self._max_dts_gap = MAX_TIMESTAMP_GAP * inv_video_time_base
def is_valid(self, packet: av.Packet) -> bool:
"""Validate the packet timestamp based on ordering within the stream."""
@ -438,13 +440,12 @@ class TimestampValidator:
self._missing_dts = 0
# Discard when dts is not monotonic. Terminate if gap is too wide.
prev_dts = self._last_dts[packet.stream]
if abs(prev_dts - packet.dts) > self._max_dts_gap and prev_dts != NEGATIVE_INF:
raise StreamWorkerError(
f"Timestamp discontinuity detected: last dts = {prev_dts}, dts ="
f" {packet.dts}"
)
if packet.dts <= prev_dts:
gap = packet.time_base * (prev_dts - packet.dts)
if gap > MAX_TIMESTAMP_GAP:
raise StreamWorkerError(
f"Timestamp overflow detected: last dts = {prev_dts}, dts ="
f" {packet.dts}"
)
return False
self._last_dts[packet.stream] = packet.dts
return True
@ -527,7 +528,7 @@ def stream_worker(
if audio_stream:
stream_state.diagnostics.set_value("audio_codec", audio_stream.name)
dts_validator = TimestampValidator()
dts_validator = TimestampValidator(int(1 / video_stream.time_base))
container_packets = PeekIterator(
filter(dts_validator.is_valid, container.demux((video_stream, audio_stream)))
)

View File

@ -58,6 +58,7 @@ STREAM_SOURCE = "some-stream-source"
# Codec names and timing parameters shared by the fake streams and packets.
AUDIO_STREAM_FORMAT = "mp3"
VIDEO_STREAM_FORMAT = "h264"
VIDEO_FRAME_RATE = 12
# Build the time base from integers: Fraction(1 / 90000) would capture the
# binary floating-point approximation of 1/90000 instead of the exact ratio.
VIDEO_TIME_BASE = fractions.Fraction(1, 90000)
AUDIO_SAMPLE_RATE = 11025
KEYFRAME_INTERVAL = 1 # in seconds
PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds
@ -97,10 +98,10 @@ def mock_stream_settings(hass):
class FakeAvInputStream:
"""A fake pyav Stream."""
def __init__(self, name, rate):
def __init__(self, name, time_base):
"""Initialize the stream."""
self.name = name
self.time_base = fractions.Fraction(1, rate)
self.time_base = time_base
self.profile = "ignored-profile"
class FakeCodec:
@ -124,8 +125,10 @@ class FakeAvInputStream:
return f"FakePyAvStream<{self.name}, {self.time_base}>"
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
# Module-level fake streams used by the packet fixtures.
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_TIME_BASE)
# Use the two-argument Fraction form so the audio time base is exactly
# 1/AUDIO_SAMPLE_RATE rather than the float approximation of that ratio.
AUDIO_STREAM = FakeAvInputStream(
    AUDIO_STREAM_FORMAT, fractions.Fraction(1, AUDIO_SAMPLE_RATE)
)
class PacketSequence:
@ -158,10 +161,10 @@ class PacketSequence:
def __init__(self):
super().__init__(3)
time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
dts = int(self.packet * PACKET_DURATION / time_base)
pts = int(self.packet * PACKET_DURATION / time_base)
duration = int(PACKET_DURATION / time_base)
time_base = VIDEO_TIME_BASE
dts = round(self.packet * PACKET_DURATION / time_base)
pts = round(self.packet * PACKET_DURATION / time_base)
duration = round(PACKET_DURATION / time_base)
stream = VIDEO_STREAM
# Pretend we get 1 keyframe every second
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
@ -405,7 +408,9 @@ async def test_discard_old_packets(hass):
packets = list(PacketSequence(TEST_SEQUENCE_LENGTH))
# Packets after this one are considered out of order
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = round(
TEST_SEQUENCE_LENGTH / VIDEO_FRAME_RATE / VIDEO_TIME_BASE
)
decoded_stream = await async_decode_stream(hass, packets)
segments = decoded_stream.segments
@ -430,7 +435,7 @@ async def test_packet_overflow(hass):
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
py_av = MockPyAv()
with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"):
with pytest.raises(StreamWorkerError, match=r"Timestamp discontinuity detected"):
await async_decode_stream(hass, packets, py_av=py_av)
decoded_stream = py_av.capture_buffer
segments = decoded_stream.segments
@ -578,12 +583,12 @@ async def test_audio_is_first_packet(hass):
packets = list(PacketSequence(num_packets))
# Pair up an audio packet for each video packet
packets[0].stream = AUDIO_STREAM
packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[0].dts = round(packets[1].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
packets[0].pts = round(packets[1].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
packets[2].stream = AUDIO_STREAM
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[2].dts = round(packets[3].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
packets[2].pts = round(packets[3].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments
@ -600,8 +605,8 @@ async def test_audio_packets_found(hass):
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
packets = list(PacketSequence(num_packets))
packets[1].stream = AUDIO_STREAM
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
packets[1].dts = round(packets[0].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
packets[1].pts = round(packets[0].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
complete_segments = decoded_stream.complete_segments