parent
ab76b3ffb3
commit
d0153f5031
|
@ -31,7 +31,7 @@ EXT_X_START_LL_HLS = 2
|
||||||
|
|
||||||
|
|
||||||
PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio
|
PACKETS_TO_WAIT_FOR_AUDIO = 20 # Some streams have an audio stream with no audio
|
||||||
MAX_TIMESTAMP_GAP = 10000 # seconds - anything from 10 to 50000 is probably reasonable
|
MAX_TIMESTAMP_GAP = 30 # seconds - anything from 10 to 50000 is probably reasonable
|
||||||
|
|
||||||
MAX_MISSING_DTS = 6 # Number of packets missing DTS to allow
|
MAX_MISSING_DTS = 6 # Number of packets missing DTS to allow
|
||||||
SOURCE_TIMEOUT = 30 # Timeout for reading stream source
|
SOURCE_TIMEOUT = 30 # Timeout for reading stream source
|
||||||
|
|
|
@ -38,6 +38,7 @@ from .fmp4utils import read_init
|
||||||
from .hls import HlsStreamOutput
|
from .hls import HlsStreamOutput
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
NEGATIVE_INF = float("-inf")
|
||||||
|
|
||||||
|
|
||||||
class StreamWorkerError(Exception):
|
class StreamWorkerError(Exception):
|
||||||
|
@ -416,14 +417,15 @@ class PeekIterator(Iterator):
|
||||||
class TimestampValidator:
|
class TimestampValidator:
|
||||||
"""Validate ordering of timestamps for packets in a stream."""
|
"""Validate ordering of timestamps for packets in a stream."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, inv_video_time_base: int) -> None:
|
||||||
"""Initialize the TimestampValidator."""
|
"""Initialize the TimestampValidator."""
|
||||||
# Decompression timestamp of last packet in each stream
|
# Decompression timestamp of last packet in each stream
|
||||||
self._last_dts: dict[av.stream.Stream, int | float] = defaultdict(
|
self._last_dts: dict[av.stream.Stream, int | float] = defaultdict(
|
||||||
lambda: float("-inf")
|
lambda: NEGATIVE_INF
|
||||||
)
|
)
|
||||||
# Number of consecutive missing decompression timestamps
|
# Number of consecutive missing decompression timestamps
|
||||||
self._missing_dts = 0
|
self._missing_dts = 0
|
||||||
|
self._max_dts_gap = MAX_TIMESTAMP_GAP * inv_video_time_base
|
||||||
|
|
||||||
def is_valid(self, packet: av.Packet) -> bool:
|
def is_valid(self, packet: av.Packet) -> bool:
|
||||||
"""Validate the packet timestamp based on ordering within the stream."""
|
"""Validate the packet timestamp based on ordering within the stream."""
|
||||||
|
@ -438,13 +440,12 @@ class TimestampValidator:
|
||||||
self._missing_dts = 0
|
self._missing_dts = 0
|
||||||
# Discard when dts is not monotonic. Terminate if gap is too wide.
|
# Discard when dts is not monotonic. Terminate if gap is too wide.
|
||||||
prev_dts = self._last_dts[packet.stream]
|
prev_dts = self._last_dts[packet.stream]
|
||||||
|
if abs(prev_dts - packet.dts) > self._max_dts_gap and prev_dts != NEGATIVE_INF:
|
||||||
|
raise StreamWorkerError(
|
||||||
|
f"Timestamp discontinuity detected: last dts = {prev_dts}, dts ="
|
||||||
|
f" {packet.dts}"
|
||||||
|
)
|
||||||
if packet.dts <= prev_dts:
|
if packet.dts <= prev_dts:
|
||||||
gap = packet.time_base * (prev_dts - packet.dts)
|
|
||||||
if gap > MAX_TIMESTAMP_GAP:
|
|
||||||
raise StreamWorkerError(
|
|
||||||
f"Timestamp overflow detected: last dts = {prev_dts}, dts ="
|
|
||||||
f" {packet.dts}"
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
self._last_dts[packet.stream] = packet.dts
|
self._last_dts[packet.stream] = packet.dts
|
||||||
return True
|
return True
|
||||||
|
@ -527,7 +528,7 @@ def stream_worker(
|
||||||
if audio_stream:
|
if audio_stream:
|
||||||
stream_state.diagnostics.set_value("audio_codec", audio_stream.name)
|
stream_state.diagnostics.set_value("audio_codec", audio_stream.name)
|
||||||
|
|
||||||
dts_validator = TimestampValidator()
|
dts_validator = TimestampValidator(int(1 / video_stream.time_base))
|
||||||
container_packets = PeekIterator(
|
container_packets = PeekIterator(
|
||||||
filter(dts_validator.is_valid, container.demux((video_stream, audio_stream)))
|
filter(dts_validator.is_valid, container.demux((video_stream, audio_stream)))
|
||||||
)
|
)
|
||||||
|
|
|
@ -58,6 +58,7 @@ STREAM_SOURCE = "some-stream-source"
|
||||||
AUDIO_STREAM_FORMAT = "mp3"
|
AUDIO_STREAM_FORMAT = "mp3"
|
||||||
VIDEO_STREAM_FORMAT = "h264"
|
VIDEO_STREAM_FORMAT = "h264"
|
||||||
VIDEO_FRAME_RATE = 12
|
VIDEO_FRAME_RATE = 12
|
||||||
|
VIDEO_TIME_BASE = fractions.Fraction(1 / 90000)
|
||||||
AUDIO_SAMPLE_RATE = 11025
|
AUDIO_SAMPLE_RATE = 11025
|
||||||
KEYFRAME_INTERVAL = 1 # in seconds
|
KEYFRAME_INTERVAL = 1 # in seconds
|
||||||
PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds
|
PACKET_DURATION = fractions.Fraction(1, VIDEO_FRAME_RATE) # in seconds
|
||||||
|
@ -97,10 +98,10 @@ def mock_stream_settings(hass):
|
||||||
class FakeAvInputStream:
|
class FakeAvInputStream:
|
||||||
"""A fake pyav Stream."""
|
"""A fake pyav Stream."""
|
||||||
|
|
||||||
def __init__(self, name, rate):
|
def __init__(self, name, time_base):
|
||||||
"""Initialize the stream."""
|
"""Initialize the stream."""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.time_base = fractions.Fraction(1, rate)
|
self.time_base = time_base
|
||||||
self.profile = "ignored-profile"
|
self.profile = "ignored-profile"
|
||||||
|
|
||||||
class FakeCodec:
|
class FakeCodec:
|
||||||
|
@ -124,8 +125,10 @@ class FakeAvInputStream:
|
||||||
return f"FakePyAvStream<{self.name}, {self.time_base}>"
|
return f"FakePyAvStream<{self.name}, {self.time_base}>"
|
||||||
|
|
||||||
|
|
||||||
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_FRAME_RATE)
|
VIDEO_STREAM = FakeAvInputStream(VIDEO_STREAM_FORMAT, VIDEO_TIME_BASE)
|
||||||
AUDIO_STREAM = FakeAvInputStream(AUDIO_STREAM_FORMAT, AUDIO_SAMPLE_RATE)
|
AUDIO_STREAM = FakeAvInputStream(
|
||||||
|
AUDIO_STREAM_FORMAT, fractions.Fraction(1 / AUDIO_SAMPLE_RATE)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PacketSequence:
|
class PacketSequence:
|
||||||
|
@ -158,10 +161,10 @@ class PacketSequence:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(3)
|
super().__init__(3)
|
||||||
|
|
||||||
time_base = fractions.Fraction(1, VIDEO_FRAME_RATE)
|
time_base = VIDEO_TIME_BASE
|
||||||
dts = int(self.packet * PACKET_DURATION / time_base)
|
dts = round(self.packet * PACKET_DURATION / time_base)
|
||||||
pts = int(self.packet * PACKET_DURATION / time_base)
|
pts = round(self.packet * PACKET_DURATION / time_base)
|
||||||
duration = int(PACKET_DURATION / time_base)
|
duration = round(PACKET_DURATION / time_base)
|
||||||
stream = VIDEO_STREAM
|
stream = VIDEO_STREAM
|
||||||
# Pretend we get 1 keyframe every second
|
# Pretend we get 1 keyframe every second
|
||||||
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
|
is_keyframe = not (self.packet - 1) % (VIDEO_FRAME_RATE * KEYFRAME_INTERVAL)
|
||||||
|
@ -405,7 +408,9 @@ async def test_discard_old_packets(hass):
|
||||||
|
|
||||||
packets = list(PacketSequence(TEST_SEQUENCE_LENGTH))
|
packets = list(PacketSequence(TEST_SEQUENCE_LENGTH))
|
||||||
# Packets after this one are considered out of order
|
# Packets after this one are considered out of order
|
||||||
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = 9090
|
packets[OUT_OF_ORDER_PACKET_INDEX - 1].dts = round(
|
||||||
|
TEST_SEQUENCE_LENGTH / VIDEO_FRAME_RATE / VIDEO_TIME_BASE
|
||||||
|
)
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets)
|
decoded_stream = await async_decode_stream(hass, packets)
|
||||||
segments = decoded_stream.segments
|
segments = decoded_stream.segments
|
||||||
|
@ -430,7 +435,7 @@ async def test_packet_overflow(hass):
|
||||||
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
|
packets[OUT_OF_ORDER_PACKET_INDEX].dts = -9000000
|
||||||
|
|
||||||
py_av = MockPyAv()
|
py_av = MockPyAv()
|
||||||
with pytest.raises(StreamWorkerError, match=r"Timestamp overflow detected"):
|
with pytest.raises(StreamWorkerError, match=r"Timestamp discontinuity detected"):
|
||||||
await async_decode_stream(hass, packets, py_av=py_av)
|
await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
decoded_stream = py_av.capture_buffer
|
decoded_stream = py_av.capture_buffer
|
||||||
segments = decoded_stream.segments
|
segments = decoded_stream.segments
|
||||||
|
@ -578,12 +583,12 @@ async def test_audio_is_first_packet(hass):
|
||||||
packets = list(PacketSequence(num_packets))
|
packets = list(PacketSequence(num_packets))
|
||||||
# Pair up an audio packet for each video packet
|
# Pair up an audio packet for each video packet
|
||||||
packets[0].stream = AUDIO_STREAM
|
packets[0].stream = AUDIO_STREAM
|
||||||
packets[0].dts = int(packets[1].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[0].dts = round(packets[1].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
packets[0].pts = int(packets[1].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[0].pts = round(packets[1].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
|
packets[1].is_keyframe = True # Move the video keyframe from packet 0 to packet 1
|
||||||
packets[2].stream = AUDIO_STREAM
|
packets[2].stream = AUDIO_STREAM
|
||||||
packets[2].dts = int(packets[3].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[2].dts = round(packets[3].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
packets[2].pts = int(packets[3].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[2].pts = round(packets[3].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
complete_segments = decoded_stream.complete_segments
|
complete_segments = decoded_stream.complete_segments
|
||||||
|
@ -600,8 +605,8 @@ async def test_audio_packets_found(hass):
|
||||||
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
|
num_packets = PACKETS_TO_WAIT_FOR_AUDIO + 1
|
||||||
packets = list(PacketSequence(num_packets))
|
packets = list(PacketSequence(num_packets))
|
||||||
packets[1].stream = AUDIO_STREAM
|
packets[1].stream = AUDIO_STREAM
|
||||||
packets[1].dts = int(packets[0].dts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[1].dts = round(packets[0].dts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
packets[1].pts = int(packets[0].pts / VIDEO_FRAME_RATE * AUDIO_SAMPLE_RATE)
|
packets[1].pts = round(packets[0].pts * VIDEO_TIME_BASE * AUDIO_SAMPLE_RATE)
|
||||||
|
|
||||||
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
decoded_stream = await async_decode_stream(hass, packets, py_av=py_av)
|
||||||
complete_segments = decoded_stream.complete_segments
|
complete_segments = decoded_stream.complete_segments
|
||||||
|
|
Loading…
Reference in New Issue