178 lines
6.6 KiB
Python
178 lines
6.6 KiB
Python
# Mycroft Server - Backend
|
|
# Copyright (C) 2020 Mycroft AI Inc
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
|
#
|
|
# This file is part of the Mycroft Server.
|
|
#
|
|
# The Mycroft Server is free software: you can redistribute it and/or
|
|
# modify it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
"""Public Device API endpoint for uploading a sample wake word for tagging."""
|
|
import json
|
|
from datetime import datetime
|
|
from http import HTTPStatus
|
|
from os import environ
|
|
from pathlib import Path
|
|
|
|
from flask import jsonify
|
|
from schematics import Model
|
|
from schematics.types import StringType
|
|
from schematics.exceptions import DataError
|
|
|
|
from selene.api import PublicEndpoint
|
|
from selene.data.account import Account, AccountRepository
|
|
from selene.data.tagging import (
|
|
build_tagging_file_name,
|
|
TaggingFileLocationRepository,
|
|
UPLOADED_STATUS,
|
|
WakeWordFile,
|
|
WakeWordFileRepository,
|
|
)
|
|
from selene.data.wake_word import WakeWordRepository
|
|
from selene.util.log import get_selene_logger
|
|
|
|
LOCAL_IP = "127.0.0.1"
|
|
|
|
_log = get_selene_logger(__name__)
|
|
|
|
|
|
class UploadRequest(Model):
|
|
"""Data class for validating the content of the POST request."""
|
|
|
|
wake_word = StringType(required=True)
|
|
engine = StringType(required=True)
|
|
timestamp = StringType(required=True)
|
|
model = StringType(required=True)
|
|
|
|
|
|
class WakeWordFileUpload(PublicEndpoint):
|
|
"""Endpoint for submitting and retrieving wake word sample files.
|
|
|
|
Samples will be saved to a temporary location on the API host until a daily batch
|
|
job moves them to a permanent one. Each file will be logged on the sample table
|
|
for their location and classification data.
|
|
"""
|
|
|
|
_file_location = None
|
|
_wake_word_repository = None
|
|
_wake_word = None
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.request_data = None
|
|
|
|
@property
|
|
def wake_word_repository(self):
|
|
"""Lazy instantiation of wake word repository object."""
|
|
if self._wake_word_repository is None:
|
|
self._wake_word_repository = WakeWordRepository(self.db)
|
|
|
|
return self._wake_word_repository
|
|
|
|
@property
|
|
def wake_word(self):
|
|
"""Build and return a WakeWord object."""
|
|
if self._wake_word is None:
|
|
self._wake_word = self.wake_word_repository.ensure_wake_word_exists(
|
|
name=self.request_data["wake_word"].strip().replace("-", " "),
|
|
engine=self.request_data["engine"],
|
|
)
|
|
|
|
return self._wake_word
|
|
|
|
@property
|
|
def file_location(self):
|
|
"""Build and return a TaggingFileLocation object."""
|
|
if self._file_location is None:
|
|
data_dir = Path(environ["SELENE_DATA_DIR"])
|
|
wake_word = self.request_data["wake_word"].replace(" ", "-")
|
|
wake_word_dir = data_dir.joinpath("wake-word").joinpath(wake_word)
|
|
wake_word_dir.mkdir(parents=True, exist_ok=True)
|
|
file_location_repository = TaggingFileLocationRepository(self.db)
|
|
self._file_location = file_location_repository.ensure_location_exists(
|
|
server=LOCAL_IP, directory=str(wake_word_dir)
|
|
)
|
|
|
|
return self._file_location
|
|
|
|
def post(self, device_id):
|
|
"""
|
|
Process a HTTP POST request submitting a wake word sample from a device.
|
|
|
|
:param device_id: UUID of the device that originated the request.
|
|
:return: HTTP response indicating status of the request.
|
|
"""
|
|
self._authenticate(device_id)
|
|
self._validate_post_request()
|
|
account = self._get_account(device_id)
|
|
file_contents = self.request.files["audio"].read()
|
|
hashed_file_name = build_tagging_file_name(file_contents)
|
|
new_file_name = self._add_wake_word_file(account, hashed_file_name)
|
|
if new_file_name is not None:
|
|
hashed_file_name = new_file_name
|
|
self._save_audio_file(hashed_file_name, file_contents)
|
|
|
|
return jsonify("Wake word sample uploaded successfully"), HTTPStatus.OK
|
|
|
|
def _validate_post_request(self):
|
|
"""Load the post request into the validation class and perform validations."""
|
|
if "audio" not in self.request.files:
|
|
raise DataError(dict(audio="No audio file included in request"))
|
|
if "metadata" not in self.request.files:
|
|
raise DataError(dict(metadata="No metadata file included in request"))
|
|
metadata = json.loads(self.request.files["metadata"].read().decode())
|
|
upload_request = UploadRequest(
|
|
dict(
|
|
wake_word=metadata.get("wake_word"),
|
|
engine=metadata.get("engine"),
|
|
timestamp=metadata.get("timestamp"),
|
|
model=metadata.get("model"),
|
|
)
|
|
)
|
|
upload_request.validate()
|
|
self.request_data = upload_request.to_native()
|
|
|
|
def _get_account(self, device_id: str):
|
|
"""Use the device ID to find the account.
|
|
|
|
:param device_id: The database ID for the device that made this API call
|
|
"""
|
|
account_repository = AccountRepository(self.db)
|
|
return account_repository.get_account_by_device_id(device_id)
|
|
|
|
def _save_audio_file(self, hashed_file_name: str, file_contents: bytes):
|
|
"""Build the file path for the audio file."""
|
|
file_path = Path(self.file_location.directory).joinpath(hashed_file_name)
|
|
with open(file_path, "wb") as audio_file:
|
|
audio_file.write(file_contents)
|
|
|
|
def _add_wake_word_file(self, account: Account, hashed_file_name: str):
|
|
"""Add the sample to the database for reference and classification.
|
|
|
|
:param account: the account from which sample originated
|
|
:param hashed_file_name: name of the audio file saved to file system
|
|
"""
|
|
sample = WakeWordFile(
|
|
account_id=account.id,
|
|
location=self.file_location,
|
|
name=hashed_file_name,
|
|
origin="mycroft",
|
|
submission_date=datetime.utcnow().date(),
|
|
wake_word=self.wake_word,
|
|
status=UPLOADED_STATUS,
|
|
)
|
|
file_repository = WakeWordFileRepository(self.db)
|
|
new_file_name = file_repository.add(sample)
|
|
|
|
return new_file_name
|