# selene-backend/api/public/public_api/endpoints/audio_transcription.py
#
# 135 lines
# 5.2 KiB
# Python

# Mycroft Server - Backend
# Copyright (C) 2019 Mycroft AI Inc
# SPDX-License-Identifier: AGPL-3.0-or-later
#
# This file is part of the Mycroft Server.
#
# The Mycroft Server is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Public API endpoint for official Mycroft-supported audio transcriptions.
When a device is configured to use the Mycroft STT plugin for transcribing audio,
this endpoint will be called to do the transcription anonymously.
"""
from datetime import datetime
from decimal import Decimal
from http import HTTPStatus
from io import BytesIO
from typing import Optional
import librosa
from google.cloud import speech
from selene.api import PublicEndpoint, track_account_activity
from selene.data.account import AccountRepository
from selene.data.metric import SttTranscriptionMetric, TranscriptionMetricRepository
from selene.util.log import get_selene_logger
# Sample rate (Hz) declared to the transcription API for the request audio and
# used by librosa when decoding that same audio to measure its duration.
SAMPLE_RATE = 16_000

_log = get_selene_logger(__name__)
class AudioTranscriptionEndpoint(PublicEndpoint):
    """Transcribes audio data to text and responds with the result.

    The request body is sent to Google Cloud Speech-to-Text declared as FLAC
    audio at SAMPLE_RATE Hz. Each transcription attempt is recorded as an STT
    transcription metric against the account that owns the requesting device.
    """

    def __init__(self):
        super().__init__()
        # Both durations are in seconds, kept as Decimal for the metric record.
        self.audio_duration = Decimal(0)
        self.transcription_duration = Decimal(0)

    def post(self):
        """Processes an HTTP POST request containing audio to transcribe.

        The device is authenticated, the audio is transcribed, and a metric
        describing the attempt is stored. Account activity is only tracked
        when a transcription was produced.

        :returns: a dict holding the transcription (None if transcription
            failed) and HTTP 200; a failed transcription is not an HTTP error.
        """
        self._authenticate()
        transcription = self._transcribe()
        self._add_transcription_metric(transcription)
        if transcription is not None:
            track_account_activity(self.db, self.device_id)

        return dict(transcription=transcription), HTTPStatus.OK

    def _transcribe(self) -> Optional[str]:
        """Transcribes the audio in the request to text.

        :returns: the transcription, or None if the transcription failed
        """
        response = self._call_transcription_api()
        return self._get_transcription(response)

    def _call_transcription_api(self) -> Optional[speech.RecognizeResponse]:
        """Calls the Google Cloud Speech-to-Text API with the request audio.

        The wall-clock time spent in the API call is stored in
        self.transcription_duration, whether or not the call succeeded.

        :returns: the API response, or None if the API call raised
        """
        response = None
        client = speech.SpeechClient()
        audio = speech.RecognitionAudio(content=self.request.data)
        config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.FLAC,
            sample_rate_hertz=SAMPLE_RATE,
            language_code="en-US",
        )
        start_timestamp = datetime.now()
        try:
            response = client.recognize(config=config, audio=audio)
        except Exception:
            # Any failure of the remote call is logged and surfaces to the
            # caller as a failed transcription (None) rather than an error.
            _log.exception(f"{self.request_id}: Transcription failed.")
        finally:
            # Measure failed calls too, so the metric reflects time spent.
            elapsed = datetime.now() - start_timestamp
            self.transcription_duration = Decimal(str(elapsed.total_seconds()))

        return response

    def _get_transcription(
        self, response: Optional[speech.RecognizeResponse]
    ) -> Optional[str]:
        """Picks the highest-confidence transcript from the API response.

        Only alternatives with a confidence above zero are considered. If
        several alternatives share the highest confidence, the first one
        encountered wins.

        NOTE(review): Google documents 0.0 as the "confidence not set" value,
        so an alternative without a reported confidence is never selected.

        :param response: the transcription service API response, or None if
            the API call failed
        :returns: the most confident transcript, or None if the audio could
            not be transcribed
        """
        if response is None:
            return None

        candidates = [
            alternative
            for result in response.results
            for alternative in result.alternatives
            if alternative.confidence > 0
        ]
        if not candidates:
            return None

        # max() returns the first maximal element, preserving the strict ">"
        # tie-breaking of a running-maximum loop.
        best = max(candidates, key=lambda alternative: alternative.confidence)
        return best.transcript

    def _add_transcription_metric(self, transcription: Optional[str]):
        """Adds metrics for this STT transcription attempt to the database.

        :param transcription: the transcription, or None if it failed; only
            its presence is recorded (as the metric's success flag)
        """
        account_repo = AccountRepository(self.db)
        account = account_repo.get_account_by_device_id(self.device_id)
        self.audio_duration = Decimal(str(self._determine_audio_duration()))
        transcription_metric = SttTranscriptionMetric(
            account_id=account.id,
            engine="Google Cloud",
            success=transcription is not None,
            audio_duration=self.audio_duration,
            transcription_duration=self.transcription_duration,
        )
        transcription_metric_repo = TranscriptionMetricRepository(self.db)
        transcription_metric_repo.add(transcription_metric)

    def _determine_audio_duration(self) -> float:
        """Determines the duration, in seconds, of the request audio.

        The audio is decoded by librosa, resampled to SAMPLE_RATE and mixed
        down to mono; the duration is computed from those decoded samples.
        """
        with BytesIO(self.request.data) as request_audio:
            audio, _ = librosa.load(request_audio, sr=SAMPLE_RATE, mono=True)
        return librosa.get_duration(y=audio, sr=SAMPLE_RATE)