parent
ab76b3ffb3
commit
d0153f5031
|
@ -31,7 +31,7 @@ EXT_X_START_LL_HLS = 2
|
|||
|
||||
|
||||
PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio
|
||||
MAX_TIMESTAMP_GAP = 10000 # seconds - anything from 10 to 50000 is probably reasonable
|
||||
MAX_TIMESTAMP_GAP = 30 # seconds - anything from 10 to 50000 is probably reasonable
|
||||
|
||||
MAX_MISSING_DTS = 6 # Number of packets missing DTS to allow
|
||||
SOURCE_TIMEOUT = 30 # Timeout for reading stream source
|
||||
|
|
|
@ -38,6 +38,7 @@ from .fmp4utils import read_init
|
|||
from .hls import HlsStreamOutput
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
NEGATIVE_INF = float("-inf")
|
||||
|
||||
|
||||
class StreamWorkerError(Exception):
|
||||
|
@ -416,14 +417,15 @@ class PeekIterator(Iterator):
|
|||
class TimestampValidator:
|
||||
"""Validate ordering of timestamps for packets in a stream."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, inv_video_time_base: int) -> None:
|
||||
"""Initialize the TimestampValidator."""
|
||||
# Decompression timestamp of last packet in each stream
|
||||
self._last_dts: dict[av.stream.Stream, int | float] = defaultdict(
|
||||
lambda: float("-inf")
|
||||
lambda: NEGATIVE_INF
|
||||
)
|
||||
# Number of consecutive missing decompression timestamps
|
||||
self._missing_dts = 0
|
||||
self._max_dts_gap = MAX_TIMESTAMP_GAP * inv_video_time_base
|
||||
|
||||
def is_valid(self, packet: av.Packet) -> bool:
|
||||
"""Validate the packet timestamp based on ordering within the stream."""
|
||||
|
@ -438,13 +440,12 @@ class TimestampValidator:
|
|||
self._missing_dts = 0
|
||||
# Discard when dts is not monotonic. Terminate if gap is too wide.
|
||||
prev_dts = self._last_dts[packet.stream]
|
||||
if abs(prev_dts - packet.dts) > self._max_dts_gap and prev_dts != NEGATIVE_INF:
|
||||
raise StreamWorkerError(
|
||||
f"Timestamp discontinuity detected: last dts = {prev_dts}, dts ="
|
||||
f" {packet.dts}"
|
||||
)
|
||||
if packet.dts <= prev_dts:
|
||||
gap = packet.time_base * (prev_dts - packet.dts)
|
||||
if gap > MAX_TIMESTAMP_GAP:
|
||||
raise StreamWorkerError(
|
||||
f"Timestamp overflow detected: last dts = {prev_dts}, dts ="
|
||||
f" {packet.dts}"
|
||||
)
|
||||
return False
|
||||
self._last_dts[packet.stream] = packet.dts
|
||||
return True
|
||||
|
@ -527,7 +528,7 @@ def stream_worker(
|
|||
if audio_stream:
|
||||
stream_state.diagnostics.set_value("audio_codec", audio_stream.name)
|
||||
|
||||
dts_validator = TimestampValidator()
|
||||
dts_validator = TimestampValidator(int(1 / video_stream.time_base))
|
||||
container_packets = PeekIterator(
|
||||
filter(dts_validator.is_valid, container.demux((video_stream, audio_stream)))
|
||||
)
|
||||
|
|
|
@ -58,6 +58,7 @@ STREAM_SOURCE = "some-stream-source"
|
|||
AUDIO_STREAM_FORMAT = "mp3"
|
||||
VIDEO_STREAM_FORMAT = "h264"
|
||||
VIDEO_FRAME_RATE = 12
|
||||
VIDEO_TIME_BASE = fractions.Fraction(1 / 90000)
|
||||
AUDIO_SAMPLE_RATE = 11025
|
||||
KEYFRAME_INTERVAL = 1 # in seconds
|
||||
PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds
|
||||
|
@ -97,10 +98,10 @@ def mock_stream_settings(hass):
|
|||
class FakeAvInputStream:
|
||||
"""A fake pyav Stream."""
|
||||
|
||||
def __init__(self, name, rate):
|
||||
def __init__(self, name, time_base):
|
||||
"""Initialize the stream."""
|
||||
self.name = name
|
||||
self.time_base = fractions.Fraction(1, rate)
|
||||
self.time_base = time_base
|
||||
self.profile = "ignored-profile"
|
||||
|
||||
class FakeCodec:
|
||||
|
@ -124,8 +125,10 @@ class FakeAvInputStream:
|
|||
return f"FakePyAvStream<{self.name}, {self.time_base}>"
|
||||
|
||||
|
||||
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
|
||||
AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
|
||||
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_TIME_BASE)
|
||||
AUDIO_STREAM = FakeAvInputStream(
|
||||
AUDIO_STREAM_FORMAT, fractions.Fraction(1 / AUDIO_SAMPLE_RATE)
|
||||
)
|
||||
|
||||
|
||||
class PacketSequence:
|
||||
|
@ -158,10 +161,10 @@ class PacketSequence:
|
|||
def __init__(self):
|
||||
super().__init__(3)
|
||||
|
||||
time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
|
||||
dts = int(self.packet * PACKET_DURATION / time_base)
|
||||
pts = int(self.packet * PACKET_DURATION / time_base)
|
||||
duration = int(PACKET_DURATION / time_base)
|
||||
time_base = VIDEO_TIME_BASE
|
||||
dts = round(self.packet * PACKET_DURATION / time_base)
|
||||
pts = round(self.packet * PACKET_DURATION / time_base)
|
||||
duration = round(PACKET_DURATION / time_base)
|
||||
stream = VIDEO_STREAM
|
||||
# Pretend we get 1 keyframe every second
|
||||
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
|
||||
|
@ -405,7 +408,9 @@ async def test_discard_old_packets(hass):
|
|||
|
||||
packets = list(PacketSequence(TEST_SEQUENCE_LENGTH))
|
||||
# Packets after this one are considered out of order
|
||||
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
|
||||
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = round(
|
||||
TEST_SEQUENCE_LENGTH / VIDEO_FRAME_RATE / VIDEO_TIME_BASE
|
||||
)
|
||||
|
||||
decoded_stream = await async_decode_stream(hass, packets)
|
||||
segments = decoded_stream.segments
|
||||
|
@ -430,7 +435,7 @@ async def test_packet_overflow(hass):
|
|||
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
|
||||
|
||||
py_av = MockPyAv()
|
||||
with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"):
|
||||
with pytest.raises(StreamWorkerError, match=r"Timestamp discontinuity detected"):
|
||||
await async_decode_stream(hass, packets, py_av=py_av)
|
||||
decoded_stream = py_av.capture_buffer
|
||||
segments = decoded_stream.segments
|
||||
|
@ -578,12 +583,12 @@ async def test_audio_is_first_packet(hass):
|
|||
packets = list(PacketSequence(num_packets))
|
||||
# Pair up an audio packet for each video packet
|
||||
packets[0].stream = AUDIO_STREAM
|
||||
packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[0].dts = round(packets[1].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
packets[0].pts = round(packets[1].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
|
||||
packets[2].stream = AUDIO_STREAM
|
||||
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[2].dts = round(packets[3].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
packets[2].pts = round(packets[3].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
|
||||
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
||||
complete_segments = decoded_stream.complete_segments
|
||||
|
@ -600,8 +605,8 @@ async def test_audio_packets_found(hass):
|
|||
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
|
||||
packets = list(PacketSequence(num_packets))
|
||||
packets[1].stream = AUDIO_STREAM
|
||||
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
||||
packets[1].dts = round(packets[0].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
packets[1].pts = round(packets[0].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||
|
||||
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
||||
complete_segments = decoded_stream.complete_segments
|
||||
|
|
Loading…
Reference in New Issue