Standardize imports and type hints
parent
9440e51324
commit
2a3ff7dc26
|
@ -1,9 +1,10 @@
|
|||
# Python 3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
|
||||
import json
|
||||
from os.path import isfile
|
||||
from typing import Tuple, List, Any
|
||||
from argparse import ArgumentParser
|
||||
from os.path import isfile
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -128,7 +129,7 @@ def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
|
|||
return np.array(inputs) if inputs else np.empty((0, pr.n_features, pr.feature_size))
|
||||
|
||||
|
||||
def load_vector(name: str, vectorizer=vectorize) -> np.ndarray:
|
||||
def load_vector(name: str, vectorizer: Callable = vectorize) -> np.ndarray:
|
||||
"""Loads and caches a vector input from a wav or npy file"""
|
||||
import os
|
||||
|
||||
|
@ -170,13 +171,13 @@ def save_audio(filename: str, audio: np.ndarray):
|
|||
wavio.write(filename, save_audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')
|
||||
|
||||
|
||||
def glob_all(folder: str, filter: str) -> List[str]:
|
||||
def glob_all(folder: str, filt: str) -> List[str]:
|
||||
"""Recursive glob"""
|
||||
import os
|
||||
import fnmatch
|
||||
matches = []
|
||||
for root, dirnames, filenames in os.walk(folder):
|
||||
for filename in fnmatch.filter(filenames, filter):
|
||||
for filename in fnmatch.filter(filenames, filt):
|
||||
matches.append(os.path.join(root, filename))
|
||||
return matches
|
||||
|
||||
|
@ -211,17 +212,17 @@ def weighted_mse_loss(yt, yp) -> Any:
|
|||
return weight * neg_loss + (1. - weight) * pos_loss
|
||||
|
||||
|
||||
def false_pos(yt, yp):
|
||||
def false_pos(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.sum(1 - yt)
|
||||
|
||||
|
||||
def false_neg(yt, yp):
|
||||
def false_neg(yt, yp) -> Any:
|
||||
from keras import backend as K
|
||||
return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.sum(0 + yt)
|
||||
|
||||
|
||||
def load_keras():
|
||||
def load_keras() -> Any:
|
||||
import keras
|
||||
keras.losses.weighted_log_loss = weighted_log_loss
|
||||
keras.metrics.false_pos = false_pos
|
||||
|
|
|
@ -1,13 +1,22 @@
|
|||
# Python 3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from importlib import import_module
|
||||
from os.path import splitext
|
||||
from typing import BinaryIO, Union
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
from precise.common import buffer_to_audio, load_precise_model, inject_params
|
||||
|
||||
|
||||
class TensorflowRunner:
|
||||
class Runner(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
pass
|
||||
|
||||
|
||||
class TensorflowRunner(Runner):
|
||||
def __init__(self, model_name: str):
|
||||
if model_name.endswith('.net'):
|
||||
print('Warning: ', model_name, 'looks like a Keras model.')
|
||||
|
@ -19,7 +28,7 @@ class TensorflowRunner:
|
|||
|
||||
self.sess = self.tf.Session(graph=self.graph)
|
||||
|
||||
def load_graph(self, model_file: str): # returns: tf.Graph
|
||||
def load_graph(self, model_file: str) -> 'tf.Graph':
|
||||
graph = self.tf.Graph()
|
||||
graph_def = self.tf.GraphDef()
|
||||
|
||||
|
@ -34,13 +43,13 @@ class TensorflowRunner:
|
|||
return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]
|
||||
|
||||
|
||||
class KerasRunner:
|
||||
class KerasRunner(Runner):
|
||||
def __init__(self, model_name: str):
|
||||
import tensorflow as tf
|
||||
self.model = load_precise_model(model_name)
|
||||
self.graph = tf.get_default_graph()
|
||||
|
||||
def run(self, inp: np.ndarray):
|
||||
def run(self, inp: np.ndarray) -> float:
|
||||
with self.graph.as_default():
|
||||
return self.model.predict(np.array([inp]))[0][0]
|
||||
|
||||
|
@ -56,7 +65,7 @@ class Listener:
|
|||
self.mfcc = import_module('speechpy.feature').mfcc
|
||||
|
||||
@staticmethod
|
||||
def find_runner(model_name):
|
||||
def find_runner(model_name: str) -> Type[Runner]:
|
||||
runners = {
|
||||
'.net': KerasRunner,
|
||||
'.pb': TensorflowRunner
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
|
||||
from collections import namedtuple
|
||||
from math import floor
|
||||
from typing import *
|
||||
|
||||
|
||||
def _make_cls() -> type:
|
||||
cls = namedtuple('ListenerParams',
|
||||
'window_t hop_t buffer_t sample_rate sample_depth n_mfcc n_filt n_fft')
|
||||
|
||||
def add_prop(name, fn):
|
||||
def add_prop(name: str, fn: Callable):
|
||||
setattr(cls, name, property(fn))
|
||||
|
||||
import numpy as np
|
||||
|
|
|
@ -7,7 +7,6 @@ sys.path += ['.', 'runner'] # noqa
|
|||
|
||||
from threading import Event
|
||||
from random import randint
|
||||
from argparse import ArgumentParser
|
||||
from os.path import join
|
||||
from subprocess import call
|
||||
import numpy as np
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
|
||||
import sys
|
||||
from time import sleep
|
||||
|
||||
sys.path += ['.'] # noqa
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
|
||||
import sys
|
||||
|
||||
sys.path += ['.'] # noqa
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
# Python 3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
|
||||
import json
|
||||
from argparse import ArgumentParser
|
||||
from hashlib import md5
|
||||
from os.path import join, isfile
|
||||
from typing import Tuple, Callable, List
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
|
||||
from precise.common import find_wavs, load_vector, vectorize_inhibit, vectorize, pr
|
||||
from precise.common import find_wavs, load_vector, vectorize_inhibit, vectorize
|
||||
|
||||
|
||||
class TrainData:
|
||||
|
@ -15,13 +18,13 @@ class TrainData:
|
|||
self.train_files, self.test_files = train_files, test_files
|
||||
|
||||
@classmethod
|
||||
def from_folder(cls, prefix):
|
||||
def from_folder(cls, prefix: str) -> 'TrainData':
|
||||
return cls(find_wavs(prefix), find_wavs(join(prefix, 'test')))
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_file, db_folder):
|
||||
def from_db(cls, db_file: str, db_folder: str) -> 'TrainData':
|
||||
if not db_file:
|
||||
return
|
||||
return cls(([], []), ([], []))
|
||||
if not isfile(db_file):
|
||||
raise RuntimeError('Database file does not exist: ' + db_file)
|
||||
import dataset
|
||||
|
@ -59,16 +62,17 @@ class TrainData:
|
|||
return cls(train_files, test_files)
|
||||
|
||||
@classmethod
|
||||
def from_both(cls, db_file, db_folder, data_dir):
|
||||
def from_both(cls, db_file: str, db_folder: str, data_dir: str) -> 'TrainData':
|
||||
return cls.from_db(db_file, db_folder) + cls.from_folder(data_dir)
|
||||
|
||||
def load(self, skip_test=False):
|
||||
def load(self, skip_test=False) -> tuple:
|
||||
return self.__load(self.__load_files, skip_test)
|
||||
|
||||
def load_inhibit(self, skip_test=False):
|
||||
def load_inhibit(self, skip_test=False) -> tuple:
|
||||
"""Generate data with inhibitory inputs created from keyword samples"""
|
||||
|
||||
def loader(kws, nkws):
|
||||
def loader(kws: list, nkws: list):
|
||||
from precise.common import pr
|
||||
inputs = np.empty((0, pr.n_features, pr.feature_size))
|
||||
outputs = np.zeros((len(kws), 1))
|
||||
for f in kws:
|
||||
|
@ -82,13 +86,11 @@ class TrainData:
|
|||
return self.__load(loader, skip_test)
|
||||
|
||||
@staticmethod
|
||||
def merge(data_a, data_b):
|
||||
if None in (data_a, data_b):
|
||||
return None
|
||||
def merge(data_a: tuple, data_b: tuple) -> tuple:
|
||||
return np.concatenate((data_a[0], data_b[0])), np.concatenate((data_a[1], data_b[1]))
|
||||
|
||||
@staticmethod
|
||||
def parse_args(parser):
|
||||
def parse_args(parser: ArgumentParser) -> Any:
|
||||
"""Return parsed args from parser, adding options for train data inputs"""
|
||||
parser.add_argument('db_folder', help='Folder to load database references from')
|
||||
parser.add_argument('-db', '--db-file', default='', help='Database file to use')
|
||||
|
@ -98,7 +100,7 @@ class TrainData:
|
|||
args.data_dir = args.data_dir.format(db_folder=args.db_folder)
|
||||
return args
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
string = '<TrainData wake_words={kws} not_wake_words={nkws}' \
|
||||
' test_wake_words={test_kws} test_not_wake_words={test_nkws}>'
|
||||
return string.format(
|
||||
|
@ -106,7 +108,7 @@ class TrainData:
|
|||
test_kws=len(self.test_files[0]), test_nkws=len(self.test_files[1])
|
||||
)
|
||||
|
||||
def __add__(self, other):
|
||||
def __add__(self, other: 'TrainData') -> 'TrainData':
|
||||
if not isinstance(other, TrainData):
|
||||
raise TypeError('Can only add TrainData to TrainData')
|
||||
return TrainData((self.train_files[0] + other.train_files[0],
|
||||
|
@ -114,15 +116,14 @@ class TrainData:
|
|||
(self.test_files[0] + other.test_files[0],
|
||||
self.test_files[1] + other.test_files[1]))
|
||||
|
||||
def __load(self, loader, skip_test):
|
||||
return [
|
||||
loader(*files)
|
||||
for files in [self.train_files] + (not skip_test) * [self.test_files]
|
||||
] + [None] * skip_test
|
||||
def __load(self, loader: Callable, skip_test: bool) -> tuple:
|
||||
return tuple([
|
||||
loader(*files)
|
||||
for files in [self.train_files] + (not skip_test) * [self.test_files]
|
||||
] + [None] * skip_test)
|
||||
|
||||
@staticmethod
|
||||
def __load_files(kw_files, nkw_files, vectorizer: Callable = vectorize) -> \
|
||||
Tuple[np.array, np.array]:
|
||||
def __load_files(kw_files: list, nkw_files: list, vectorizer: Callable = vectorize) -> tuple:
|
||||
inputs = []
|
||||
outputs = []
|
||||
|
||||
|
@ -136,6 +137,7 @@ class TrainData:
|
|||
print('Loading not-keyword...')
|
||||
add(nkw_files, 0.0)
|
||||
|
||||
from precise.common import pr
|
||||
return (
|
||||
np.array(inputs) if inputs else np.empty((0, pr.n_features, pr.feature_size)),
|
||||
np.array(outputs) if outputs else np.empty((0, 1))
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2017 Mycroft AI Inc.
|
||||
# This script trains the network, selectively choosing
|
||||
# segments from data/random that cause an activation. These
|
||||
# segments are moved into data/not-keyword and the network is retrained
|
||||
|
||||
import sys
|
||||
|
||||
|
@ -13,6 +10,7 @@ from os import makedirs
|
|||
from random import random
|
||||
from glob import glob
|
||||
from os.path import basename, splitext, isfile, join
|
||||
from typing import *
|
||||
|
||||
from precise.train_data import TrainData
|
||||
from precise.network_runner import Listener, KerasRunner
|
||||
|
@ -53,12 +51,12 @@ marking false activations and retraining
|
|||
"""
|
||||
|
||||
|
||||
def chunk_audio(audio: np.array, chunk_size: int):
|
||||
def chunk_audio(audio: np.ndarray, chunk_size: int) -> Generator[np.ndarray]:
|
||||
for i in range(chunk_size, len(audio), chunk_size):
|
||||
yield audio[i - chunk_size:i]
|
||||
|
||||
|
||||
def load_trained_fns(model_name):
|
||||
def load_trained_fns(model_name: str) -> list:
|
||||
progress_file = model_name.replace('.net', '') + '.trained.txt'
|
||||
if isfile(progress_file):
|
||||
print('Starting from saved position in', progress_file)
|
||||
|
@ -67,7 +65,7 @@ def load_trained_fns(model_name):
|
|||
return []
|
||||
|
||||
|
||||
def save_trained_fns(trained_fns, model_name):
|
||||
def save_trained_fns(trained_fns: list, model_name: str):
|
||||
with open(model_name.replace('.net', '') + '.trained.txt', 'wb') as f:
|
||||
f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))
|
||||
|
||||
|
|
Loading…
Reference in New Issue