diff --git a/homeassistant/components/stream/const.py b/homeassistant/components/stream/const.py index 35af633435e..eb954a6a8f5 100644 --- a/homeassistant/components/stream/const.py +++ b/homeassistant/components/stream/const.py @@ -31,7 +31,7 @@ EXT_X_START_LL_HLS = 2 PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio -MAX_TIMESTAMP_GAP = 10000 # seconds - anything from 10 to 50000 is probably reasonable +MAX_TIMESTAMP_GAP = 30 # seconds - anything from 10 to 50000 is probably reasonable MAX_MISSING_DTS = 6 # Number of packets missing DTS to allow SOURCE_TIMEOUT = 30 # Timeout for reading stream source diff --git a/homeassistant/components/stream/worker.py b/homeassistant/components/stream/worker.py index aefbbf698f1..f7908ca469d 100644 --- a/homeassistant/components/stream/worker.py +++ b/homeassistant/components/stream/worker.py @@ -38,6 +38,7 @@ from .fmp4utils import read_init from .hls import HlsStreamOutput _LOGGER = logging.getLogger(__name__) +NEGATIVE_INF = float("-inf") class StreamWorkerError(Exception): @@ -416,14 +417,15 @@ class PeekIterator(Iterator): class TimestampValidator: """Validate ordering of timestamps for packets in a stream.""" - def __init__(self) -> None: + def __init__(self, inv_video_time_base: int) -> None: """Initialize the TimestampValidator.""" # Decompression timestamp of last packet in each stream self._last_dts: dict[av.stream.Stream, int | float] = defaultdict( - lambda: float("-inf") + lambda: NEGATIVE_INF ) # Number of consecutive missing decompression timestamps self._missing_dts = 0 + self._max_dts_gap = MAX_TIMESTAMP_GAP * inv_video_time_base def is_valid(self, packet: av.Packet) -> bool: """Validate the packet timestamp based on ordering within the stream.""" @@ -438,13 +440,12 @@ class TimestampValidator: self._missing_dts = 0 # Discard when dts is not monotonic. Terminate if gap is too wide. prev_dts = self._last_dts[packet.stream] + if abs(prev_dts - packet.dts) > self._max_dts_gap and prev_dts != NEGATIVE_INF: + raise StreamWorkerError( + f"Timestamp discontinuity detected: last dts = {prev_dts}, dts =" + f" {packet.dts}" + ) if packet.dts <= prev_dts: - gap = packet.time_base * (prev_dts - packet.dts) - if gap > MAX_TIMESTAMP_GAP: - raise StreamWorkerError( - f"Timestamp overflow detected: last dts = {prev_dts}, dts =" - f" {packet.dts}" - ) return False self._last_dts[packet.stream] = packet.dts return True @@ -527,7 +528,7 @@ def stream_worker( if audio_stream: stream_state.diagnostics.set_value("audio_codec", audio_stream.name) - dts_validator = TimestampValidator() + dts_validator = TimestampValidator(int(1 / video_stream.time_base)) container_packets = PeekIterator( filter(dts_validator.is_valid, container.demux((video_stream, audio_stream))) ) diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index a5a1f00d90a..22a7627c062 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -58,6 +58,7 @@ STREAM_SOURCE = "some-stream-source" AUDIO_STREAM_FORMAT = "mp3" VIDEO_STREAM_FORMAT = "h264" VIDEO_FRAME_RATE = 12 +VIDEO_TIME_BASE = fractions.Fraction(1 / 90000) AUDIO_SAMPLE_RATE = 11025 KEYFRAME_INTERVAL = 1 # in seconds PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds @@ -97,10 +98,10 @@ def mock_stream_settings(hass): class FakeAvInputStream: """A fake pyav Stream.""" - def __init__(self, name, rate): + def __init__(self, name, time_base): """Initialize the stream.""" self.name = name - self.time_base = fractions.Fraction(1, rate) + self.time_base = time_base self.profile = "ignored-profile" class FakeCodec: @@ -124,8 +125,10 @@ class FakeAvInputStream: return f"FakePyAvStream<{self.name}, {self.time_base}>" -VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE) -AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE) +VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_TIME_BASE) +AUDIO_STREAM = FakeAvInputStream( + AUDIO_STREAM_FORMAT, fractions.Fraction(1 / AUDIO_SAMPLE_RATE) +) class PacketSequence: @@ -158,10 +161,10 @@ class PacketSequence: def __init__(self): super().__init__(3) - time_base = fractions.Fraction(1, VIDEO_FRAME_RATE) - dts = int(self.packet * PACKET_DURATION / time_base) - pts = int(self.packet * PACKET_DURATION / time_base) - duration = int(PACKET_DURATION / time_base) + time_base = VIDEO_TIME_BASE + dts = round(self.packet * PACKET_DURATION / time_base) + pts = round(self.packet * PACKET_DURATION / time_base) + duration = round(PACKET_DURATION / time_base) stream = VIDEO_STREAM # Pretend we get 1 keyframe every second is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL) @@ -405,7 +408,9 @@ async def test_discard_old_packets(hass): packets = list(PacketSequence(TEST_SEQUENCE_LENGTH)) # Packets after this one are considered out of order - packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090 + packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = round( + TEST_SEQUENCE_LENGTH / VIDEO_FRAME_RATE / VIDEO_TIME_BASE + ) decoded_stream = await async_decode_stream(hass, packets) segments = decoded_stream.segments @@ -430,7 +435,7 @@ async def test_packet_overflow(hass): packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000 py_av = MockPyAv() - with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"): + with pytest.raises(StreamWorkerError, match=r"Timestamp discontinuity detected"): await async_decode_stream(hass, packets, py_av=py_av) decoded_stream = py_av.capture_buffer segments = decoded_stream.segments @@ -578,12 +583,12 @@ async def test_audio_is_first_packet(hass): packets = list(PacketSequence(num_packets)) # Pair up an audio packet for each video packet packets[0].stream = AUDIO_STREAM - packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) - packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[0].dts = round(packets[1].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) + packets[0].pts = round(packets[1].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1 packets[2].stream = AUDIO_STREAM - packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) - packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[2].dts = round(packets[3].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) + packets[2].pts = round(packets[3].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) complete_segments = decoded_stream.complete_segments @@ -600,8 +605,8 @@ async def test_audio_packets_found(hass): num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1 packets = list(PacketSequence(num_packets)) packets[1].stream = AUDIO_STREAM - packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) - packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE) + packets[1].dts = round(packets[0].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) + packets[1].pts = round(packets[0].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE) decoded_stream = await async_decode_stream(hass, packets, py_av=py_av) complete_segments = decoded_stream.complete_segments