write out the .flac and .stt files for users that opt-in for tagging purposes.

pull/191/head
Chris Veilleux 2019-07-01 21:57:22 -05:00
parent 613ea71d99
commit 05ba05f50d
5 changed files with 136 additions and 23 deletions

View File

@ -1,41 +1,112 @@
import os
from http import HTTPStatus
from io import BytesIO
from time import time
from speech_recognition import AudioFile, Recognizer
from selene.api import PublicEndpoint
from selene.data.account import AccountRepository, OPEN_DATASET
SELENE_DATA_DIR = '/opt/selene/data'
class GoogleSTTEndpoint(PublicEndpoint):
""" Endpoint to send a flac audio file with voice and get back a utterance"""
"""Endpoint to send a flac audio file with voice and get back a utterance"""
def __init__(self):
super(GoogleSTTEndpoint, self).__init__()
self.google_stt_key = self.config['GOOGLE_STT_KEY']
self.recognizer = Recognizer()
self.account = None
self.account_shares_data = False
def post(self):
self._authenticate()
self._get_account()
self._check_for_open_dataset_agreement()
self._write_flac_audio_file()
stt_response = self._call_google_stt()
response = self._build_response(stt_response)
self._write_stt_result_file(response)
return response, HTTPStatus.OK
def _get_account(self):
if self.device_id is not None:
account_repo = AccountRepository(self.db)
self.account = account_repo.get_account_by_device_id(self.device_id)
def _check_for_open_dataset_agreement(self):
for agreement in self.account.agreements:
if agreement.type == OPEN_DATASET:
self.account_shares_data = True
def _write_flac_audio_file(self):
"""Save the audio file for STT tagging"""
self._write_open_dataset_file(self.request.data, file_type='.flac')
def _write_stt_result_file(self, stt_result):
"""Save the STT results for tagging."""
file_contents = '\n'.join(stt_result)
self._write_open_dataset_file(file_contents.encode(), file_type='.stt')
def _write_open_dataset_file(self, content, file_type):
if self.account is not None:
file_name = '{account_id}_{time}.{file_type}'.format(
account_id=self.account.id,
file_type=file_type,
time=time()
)
file_path = os.path.join(SELENE_DATA_DIR, file_name)
with open(file_path, 'wb') as flac_file:
flac_file.write(content)
def _call_google_stt(self):
"""Use the audio data from the request to call the Google STT API
We need to replicate the first 16 bytes in the audio due a bug with
the Google speech recognition library that removes the first 16 bytes
from the flac file we are sending.
"""
lang = self.request.args['lang']
limit = int(self.request.args['limit'])
audio = self.request.data
# We need to replicate the first 16 bytes in the audio due a bug with the speech recognition library that
# removes the first 16 bytes from the flac file we are sending
with AudioFile(BytesIO(audio[:16] + audio)) as source:
data = self.recognizer.record(source)
response = self.recognizer.recognize_google(data, key=self.google_stt_key, language=lang, show_all=True)
if isinstance(response, dict):
alternative = response.get("alternative")
if 'confidence' in alternative:
# Sorting by confidence:
alternative = sorted(alternative, key=lambda alt: alt['confidence'], reverse=True)
alternative = [alt['transcript'] for alt in alternative]
# Return n transcripts with the higher confidence. That is useful for the case when send a ambiguous
# voice file and the correct utterance is not the utterance with highest confidence and the API
# client is interested in test the utterances found.
result = alternative if len(alternative) <= limit else alternative[:limit]
recording = self.recognizer.record(source)
response = self.recognizer.recognize_google(
recording,
key=self.google_stt_key,
language=lang,
show_all=True
)
return response
def _build_response(self, stt_response):
"""Build the response to return to the device.
Return n transcripts with the higher confidence. That is useful for
the case when send a ambiguous voice file and the correct utterance is
not the utterance with highest confidence and the API.
"""
limit = int(self.request.args['limit'])
if isinstance(stt_response, dict):
alternative = stt_response.get("alternative")
if 'confidence' in alternative:
# Sorting by confidence:
alternative = sorted(
alternative,
key=lambda alt: alt['confidence'],
reverse=True
)
alternative = [alt['transcript'] for alt in alternative]
# client is interested in test the utterances found.
if len(alternative) <= limit:
response = alternative
else:
result = [alternative[0]['transcript']]
response = alternative[:limit]
else:
result = []
return result
response = [alternative[0]['transcript']]
else:
response = []
return response

View File

@ -1,3 +1,5 @@
import os
from behave import fixture, use_fixture
from public_api.api import public
@ -150,8 +152,16 @@ def _add_device_specific_skill(context):
def after_tag(context, tag):
if tag == 'new_skill':
_delete_new_skill(context)
elif tag == 'stt':
_delete_stt_tagging_files()
def _delete_new_skill(context):
remove_device_skill(context.db, context.new_manifest_skill)
remove_skill(context.db, context.new_skill)
def _delete_stt_tagging_files():
data_dir = '/opt/selene/data'
for file_name in os.listdir(data_dir):
os.remove(os.path.join(data_dir, file_name))

View File

@ -1,6 +1,7 @@
Feature: Get an utterance
Test the google STT integration
@stt
Scenario: A valid flac audio with a voice record is passed
When A flac audio with the utterance "tell me a joke" is passed
Then return the utterance "tell me a joke"

View File

@ -1,17 +1,18 @@
import json
import os
from http import HTTPStatus
from io import BytesIO
from os import path
from behave import When, Then
from hamcrest import assert_that, equal_to
from hamcrest import assert_that, equal_to, not_none
@When('A flac audio with the utterance "tell me a joke" is passed')
def call_google_stt_endpoint(context):
access_token = context.device_login['accessToken']
headers = dict(Authorization='Bearer {token}'.format(token=access_token))
with open(path.join(path.dirname(__file__), 'resources/test_stt.flac'), 'rb') as flac:
resources_dir = os.path.join(os.path.dirname(__file__), 'resources')
with open(os.path.join(resources_dir, 'test_stt.flac'), 'rb') as flac:
audio = BytesIO(flac.read())
context.response = context.client.post(
'/v1/stt?lang=en-US&limit=1',
@ -26,3 +27,31 @@ def validate_response(context):
response_data = json.loads(context.response.data)
expected_response = ['tell me a joke']
assert_that(response_data, equal_to(expected_response))
resources_dir = os.path.join(os.path.dirname(__file__), 'resources')
with open(os.path.join(resources_dir, 'test_stt.flac'), 'rb') as input_file:
input_file_content = input_file.read()
flac_file_path = _get_stt_result_file(context.account.id, '.flac')
assert_that(flac_file_path, not_none())
with open(flac_file_path, 'rb') as output_file:
output_file_content = output_file.read()
assert_that(input_file_content, equal_to(output_file_content))
stt_file_path = _get_stt_result_file(context.account.id, '.stt')
assert_that(stt_file_path, not_none())
with open(stt_file_path, 'rb') as output_file:
output_file_content = output_file.read()
assert_that(b'tell me a joke', equal_to(output_file_content))
def _get_stt_result_file(account_id, file_suffix):
file_path = None
for stt_file_name in os.listdir('/opt/selene/data'):
file_name_match = (
stt_file_name.startswith(account_id)
and stt_file_name.endswith(file_suffix)
)
if file_name_match:
file_path = os.path.join('/opt/selene/data/', stt_file_name)
return file_path

View File

@ -89,6 +89,7 @@ class PublicEndpoint(MethodView):
self.cache: SeleneCache = self.config['SELENE_CACHE']
global_context.cache = self.cache
self.etag_manager: ETagManager = ETagManager(self.cache, self.config)
self.device_id = None
@property
def db(self):
@ -112,6 +113,7 @@ class PublicEndpoint(MethodView):
session = json.loads(session)
device_uuid = session['uuid']
global_context.device_id = device_uuid
self.device_id = device_uuid
if device_id is not None:
device_authenticated = (device_id == device_uuid)
else: