write out the .flac and .stt files for users that opt in, for tagging purposes.
parent 613ea71d99
commit 05ba05f50d
@@ -1,41 +1,112 @@
+import os
 from http import HTTPStatus
 from io import BytesIO
+from time import time

 from speech_recognition import AudioFile, Recognizer

 from selene.api import PublicEndpoint
+from selene.data.account import AccountRepository, OPEN_DATASET

+SELENE_DATA_DIR = '/opt/selene/data'
+


 class GoogleSTTEndpoint(PublicEndpoint):
-    """ Endpoint to send a flac audio file with voice and get back a utterance"""
+    """Endpoint to send a flac audio file with voice and get back an utterance"""
     def __init__(self):
         super(GoogleSTTEndpoint, self).__init__()
         self.google_stt_key = self.config['GOOGLE_STT_KEY']
         self.recognizer = Recognizer()
+        self.account = None
+        self.account_shares_data = False

     def post(self):
         self._authenticate()
+        self._get_account()
+        self._check_for_open_dataset_agreement()
+        self._write_flac_audio_file()
+        stt_response = self._call_google_stt()
+        response = self._build_response(stt_response)
+        self._write_stt_result_file(response)
+
+        return response, HTTPStatus.OK
+
+    def _get_account(self):
+        if self.device_id is not None:
+            account_repo = AccountRepository(self.db)
+            self.account = account_repo.get_account_by_device_id(self.device_id)
+
+    def _check_for_open_dataset_agreement(self):
+        for agreement in self.account.agreements:
+            if agreement.type == OPEN_DATASET:
+                self.account_shares_data = True
+
+    def _write_flac_audio_file(self):
+        """Save the audio file for STT tagging."""
+        self._write_open_dataset_file(self.request.data, file_type='.flac')
+
+    def _write_stt_result_file(self, stt_result):
+        """Save the STT results for tagging."""
+        file_contents = '\n'.join(stt_result)
+        self._write_open_dataset_file(file_contents.encode(), file_type='.stt')
+
+    def _write_open_dataset_file(self, content, file_type):
+        if self.account is not None:
+            file_name = '{account_id}_{time}{file_type}'.format(
+                account_id=self.account.id,
+                file_type=file_type,
+                time=time()
+            )
+            file_path = os.path.join(SELENE_DATA_DIR, file_name)
+            with open(file_path, 'wb') as flac_file:
+                flac_file.write(content)
+
+    def _call_google_stt(self):
+        """Use the audio data from the request to call the Google STT API.
+
+        We need to replicate the first 16 bytes of the audio due to a bug in
+        the Google speech recognition library that removes the first 16 bytes
+        from the flac file we are sending.
+        """
         lang = self.request.args['lang']
-        limit = int(self.request.args['limit'])
         audio = self.request.data
-        # We need to replicate the first 16 bytes in the audio due a bug with the speech recognition library that
-        # removes the first 16 bytes from the flac file we are sending
         with AudioFile(BytesIO(audio[:16] + audio)) as source:
-            data = self.recognizer.record(source)
-        response = self.recognizer.recognize_google(data, key=self.google_stt_key, language=lang, show_all=True)
-        if isinstance(response, dict):
-            alternative = response.get("alternative")
-            if 'confidence' in alternative:
-                # Sorting by confidence:
-                alternative = sorted(alternative, key=lambda alt: alt['confidence'], reverse=True)
-                alternative = [alt['transcript'] for alt in alternative]
-                # Return n transcripts with the higher confidence. That is useful for the case when send a ambiguous
-                # voice file and the correct utterance is not the utterance with highest confidence and the API
-                # client is interested in test the utterances found.
-                result = alternative if len(alternative) <= limit else alternative[:limit]
-            else:
-                result = [alternative[0]['transcript']]
-        else:
-            result = []
-        return result
+            recording = self.recognizer.record(source)
+        response = self.recognizer.recognize_google(
+            recording,
+            key=self.google_stt_key,
+            language=lang,
+            show_all=True
+        )
+
+        return response
+
+    def _build_response(self, stt_response):
+        """Build the response to return to the device.
+
+        Return up to n transcripts, ordered by confidence. That is useful
+        when an ambiguous voice file is sent and the correct utterance is
+        not the one with the highest confidence, but the API client is
+        interested in testing the utterances found.
+        """
+        limit = int(self.request.args['limit'])
+        if isinstance(stt_response, dict):
+            alternative = stt_response.get("alternative")
+            if 'confidence' in alternative:
+                # Sorting by confidence:
+                alternative = sorted(
+                    alternative,
+                    key=lambda alt: alt['confidence'],
+                    reverse=True
+                )
+                alternative = [alt['transcript'] for alt in alternative]
+                if len(alternative) <= limit:
+                    response = alternative
+                else:
+                    response = alternative[:limit]
+            else:
+                response = [alternative[0]['transcript']]
+        else:
+            response = []
+
+        return response
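A quick illustration of the transcript selection described in the _build_response docstring above: the alternatives are sorted by confidence and then truncated to the limit query parameter. This is a minimal standalone sketch; the sample alternatives and the helper name top_transcripts are invented for illustration and are not part of the change.

# Standalone sketch of the "sort by confidence, then truncate to limit"
# behaviour described in the _build_response docstring. The sample
# alternatives below are invented; a real list comes from the 'alternative'
# key of the recognize_google(..., show_all=True) response.
sample_alternatives = [
    {'transcript': 'tell me a joke', 'confidence': 0.83},
    {'transcript': 'tell me a bloke', 'confidence': 0.41},
    {'transcript': 'tell me a yolk', 'confidence': 0.57},
]


def top_transcripts(alternatives, limit):
    """Return up to `limit` transcripts, highest confidence first."""
    ranked = sorted(alternatives, key=lambda alt: alt['confidence'], reverse=True)
    transcripts = [alt['transcript'] for alt in ranked]
    return transcripts if len(transcripts) <= limit else transcripts[:limit]


print(top_transcripts(sample_alternatives, limit=2))
# ['tell me a joke', 'tell me a yolk']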
@@ -1,3 +1,5 @@
+import os
+
 from behave import fixture, use_fixture

 from public_api.api import public
@@ -150,8 +152,16 @@ def _add_device_specific_skill(context):
 def after_tag(context, tag):
     if tag == 'new_skill':
         _delete_new_skill(context)
+    elif tag == 'stt':
+        _delete_stt_tagging_files()


 def _delete_new_skill(context):
     remove_device_skill(context.db, context.new_manifest_skill)
     remove_skill(context.db, context.new_skill)
+
+
+def _delete_stt_tagging_files():
+    data_dir = '/opt/selene/data'
+    for file_name in os.listdir(data_dir):
+        os.remove(os.path.join(data_dir, file_name))
@@ -1,6 +1,7 @@
 Feature: Get an utterance
   Test the google STT integration

+  @stt
   Scenario: A valid flac audio with a voice record is passed
     When A flac audio with the utterance "tell me a joke" is passed
     Then return the utterance "tell me a joke"
@@ -1,17 +1,18 @@
 import json
+import os
 from http import HTTPStatus
 from io import BytesIO
-from os import path

 from behave import When, Then
-from hamcrest import assert_that, equal_to
+from hamcrest import assert_that, equal_to, not_none


 @When('A flac audio with the utterance "tell me a joke" is passed')
 def call_google_stt_endpoint(context):
     access_token = context.device_login['accessToken']
     headers = dict(Authorization='Bearer {token}'.format(token=access_token))
-    with open(path.join(path.dirname(__file__), 'resources/test_stt.flac'), 'rb') as flac:
+    resources_dir = os.path.join(os.path.dirname(__file__), 'resources')
+    with open(os.path.join(resources_dir, 'test_stt.flac'), 'rb') as flac:
         audio = BytesIO(flac.read())
     context.response = context.client.post(
         '/v1/stt?lang=en-US&limit=1',
@@ -26,3 +27,31 @@ def validate_response(context):
     response_data = json.loads(context.response.data)
     expected_response = ['tell me a joke']
     assert_that(response_data, equal_to(expected_response))
+
+    resources_dir = os.path.join(os.path.dirname(__file__), 'resources')
+    with open(os.path.join(resources_dir, 'test_stt.flac'), 'rb') as input_file:
+        input_file_content = input_file.read()
+    flac_file_path = _get_stt_result_file(context.account.id, '.flac')
+    assert_that(flac_file_path, not_none())
+    with open(flac_file_path, 'rb') as output_file:
+        output_file_content = output_file.read()
+    assert_that(input_file_content, equal_to(output_file_content))
+
+    stt_file_path = _get_stt_result_file(context.account.id, '.stt')
+    assert_that(stt_file_path, not_none())
+    with open(stt_file_path, 'rb') as output_file:
+        output_file_content = output_file.read()
+    assert_that(b'tell me a joke', equal_to(output_file_content))
+
+
+def _get_stt_result_file(account_id, file_suffix):
+    file_path = None
+    for stt_file_name in os.listdir('/opt/selene/data'):
+        file_name_match = (
+            stt_file_name.startswith(account_id)
+            and stt_file_name.endswith(file_suffix)
+        )
+        if file_name_match:
+            file_path = os.path.join('/opt/selene/data/', stt_file_name)
+
+    return file_path
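For reference, the behave step above drives the endpoint through the Flask test client; a rough equivalent over plain HTTP is sketched below. The base URL, the token value, and the use of the requests library are assumptions for illustration, not part of this change.

# Hypothetical client-side call mirroring the behave step: POST raw flac
# bytes to /v1/stt with a device access token and lang/limit query params.
import requests  # any HTTP client would do; requests is assumed here

BASE_URL = 'https://public-api.example.com'  # placeholder host
ACCESS_TOKEN = '<device access token>'       # obtained from the device login

with open('test_stt.flac', 'rb') as flac_file:
    audio = flac_file.read()

response = requests.post(
    BASE_URL + '/v1/stt',
    params=dict(lang='en-US', limit=1),
    headers=dict(Authorization='Bearer {}'.format(ACCESS_TOKEN)),
    data=audio
)
print(response.json())  # e.g. ['tell me a joke']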
@@ -89,6 +89,7 @@ class PublicEndpoint(MethodView):
         self.cache: SeleneCache = self.config['SELENE_CACHE']
         global_context.cache = self.cache
         self.etag_manager: ETagManager = ETagManager(self.cache, self.config)
+        self.device_id = None

     @property
     def db(self):
@@ -112,6 +113,7 @@ class PublicEndpoint(MethodView):
         session = json.loads(session)
         device_uuid = session['uuid']
         global_context.device_id = device_uuid
+        self.device_id = device_uuid
         if device_id is not None:
             device_authenticated = (device_id == device_uuid)
         else: