Standardize imports and type hints

pull/1/head
Matthew D. Scholefield 2018-02-15 14:54:08 -06:00
parent 9440e51324
commit 2a3ff7dc26
8 changed files with 55 additions and 45 deletions

View File

@ -1,9 +1,10 @@
# Python 3
# Copyright (c) 2017 Mycroft AI Inc.
import json
from os.path import isfile
from typing import Tuple, List, Any
from argparse import ArgumentParser
from os.path import isfile
from typing import *
import numpy as np
@ -128,7 +129,7 @@ def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
return np.array(inputs) if inputs else np.empty((0, pr.n_features, pr.feature_size))
def load_vector(name: str, vectorizer=vectorize) -> np.ndarray:
def load_vector(name: str, vectorizer: Callable = vectorize) -> np.ndarray:
"""Loads and caches a vector input from a wav or npy file"""
import os
@ -170,13 +171,13 @@ def save_audio(filename: str, audio: np.ndarray):
wavio.write(filename, save_audio, pr.sample_rate, sampwidth=pr.sample_depth, scale='none')
def glob_all(folder: str, filt: str) -> List[str]:
    """Recursive glob

    Walks `folder` with os.walk and keeps every file whose name matches
    the fnmatch pattern `filt`.

    Args:
        folder: Directory to walk, including everything below it
        filt: fnmatch-style pattern applied to file names (not full paths)

    Returns:
        Matching paths, each joined onto the directory it was found in
    """
    import os
    import fnmatch
    return [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(folder)
        for filename in fnmatch.filter(filenames, filt)
    ]
@ -211,17 +212,17 @@ def weighted_mse_loss(yt, yp) -> Any:
return weight * neg_loss + (1. - weight) * pos_loss
def false_pos(yt, yp) -> Any:
    """Metric built from Keras backend ops

    Counts the elements where `yp * (1 - yt)` is above 0.5 and divides that
    count by the sum of `(1 - yt)`.

    NOTE(review): presumably yt holds 0/1 labels and yp the network output,
    which would make this the false positive rate - confirm against callers.
    """
    from keras import backend as K
    return K.sum(K.cast(yp * (1 - yt) > 0.5, 'float')) / K.sum(1 - yt)
def false_neg(yt, yp) -> Any:
    """Metric built from Keras backend ops

    Counts the elements where `(1 - yp) * (0 + yt)` is above 0.5 and divides
    that count by the sum of `(0 + yt)`.

    NOTE(review): presumably yt holds 0/1 labels and yp the network output,
    which would make this the false negative rate - confirm against callers.
    """
    from keras import backend as K
    return K.sum(K.cast((1 - yp) * (0 + yt) > 0.5, 'float')) / K.sum(0 + yt)
def load_keras():
def load_keras() -> Any:
import keras
keras.losses.weighted_log_loss = weighted_log_loss
keras.metrics.false_pos = false_pos

View File

@ -1,13 +1,22 @@
# Python 3
# Copyright (c) 2017 Mycroft AI Inc.
from abc import abstractmethod, ABCMeta
from importlib import import_module
from os.path import splitext
from typing import BinaryIO, Union
from typing import *
import numpy as np
from precise.common import buffer_to_audio, load_precise_model, inject_params
class TensorflowRunner:
class Runner(metaclass=ABCMeta):
    """Abstract base for objects that run a model on one input

    `run` is abstract, so only subclasses that implement it can be
    instantiated.
    """

    @abstractmethod
    def run(self, inp: np.ndarray) -> float:
        """Run the model on `inp` and return its output as a float"""
        pass
class TensorflowRunner(Runner):
def __init__(self, model_name: str):
if model_name.endswith('.net'):
print('Warning: ', model_name, 'looks like a Keras model.')
@ -19,7 +28,7 @@ class TensorflowRunner:
self.sess = self.tf.Session(graph=self.graph)
def load_graph(self, model_file: str): # returns: tf.Graph
def load_graph(self, model_file: str) -> 'tf.Graph':
graph = self.tf.Graph()
graph_def = self.tf.GraphDef()
@ -34,13 +43,13 @@ class TensorflowRunner:
return self.sess.run(self.out_var, {self.inp_var: inp[np.newaxis]})[0][0]
class KerasRunner(Runner):
    """Runner that predicts with a model loaded by load_precise_model"""

    def __init__(self, model_name: str):
        """Load the model and remember the current default TensorFlow graph

        Args:
            model_name: Passed straight to load_precise_model
        """
        import tensorflow as tf
        self.model = load_precise_model(model_name)
        # Captured right after loading so run() can re-enter the same graph
        self.graph = tf.get_default_graph()

    def run(self, inp: np.ndarray) -> float:
        """Predict on a single input

        `inp` is wrapped into a batch of one; element [0][0] of the
        prediction is returned.
        """
        with self.graph.as_default():
            return self.model.predict(np.array([inp]))[0][0]
@ -56,7 +65,7 @@ class Listener:
self.mfcc = import_module('speechpy.feature').mfcc
@staticmethod
def find_runner(model_name):
def find_runner(model_name: str) -> Type[Runner]:
runners = {
'.net': KerasRunner,
'.pb': TensorflowRunner

View File

@ -3,13 +3,14 @@
from collections import namedtuple
from math import floor
from typing import *
def _make_cls() -> type:
cls = namedtuple('ListenerParams',
'window_t hop_t buffer_t sample_rate sample_depth n_mfcc n_filt n_fft')
def add_prop(name, fn):
def add_prop(name: str, fn: Callable):
setattr(cls, name, property(fn))
import numpy as np

View File

@ -7,7 +7,6 @@ sys.path += ['.', 'runner'] # noqa
from threading import Event
from random import randint
from argparse import ArgumentParser
from os.path import join
from subprocess import call
import numpy as np

View File

@ -2,7 +2,6 @@
# Copyright (c) 2017 Mycroft AI Inc.
import sys
from time import sleep
sys.path += ['.'] # noqa

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2017 Mycroft AI Inc.
import sys
sys.path += ['.'] # noqa

View File

@ -1,12 +1,15 @@
# Python 3
# Copyright (c) 2017 Mycroft AI Inc.
import json
from argparse import ArgumentParser
from hashlib import md5
from os.path import join, isfile
from typing import Tuple, Callable, List
from typing import *
import numpy as np
from precise.common import find_wavs, load_vector, vectorize_inhibit, vectorize, pr
from precise.common import find_wavs, load_vector, vectorize_inhibit, vectorize
class TrainData:
@ -15,13 +18,13 @@ class TrainData:
self.train_files, self.test_files = train_files, test_files
@classmethod
def from_folder(cls, prefix: str) -> 'TrainData':
    """Build a TrainData from a folder

    The training files are whatever find_wavs returns for `prefix`; the
    test files are whatever it returns for the `test` subfolder of `prefix`.
    """
    return cls(find_wavs(prefix), find_wavs(join(prefix, 'test')))
@classmethod
def from_db(cls, db_file, db_folder):
def from_db(cls, db_file: str, db_folder: str) -> 'TrainData':
if not db_file:
return
return cls(([], []), ([], []))
if not isfile(db_file):
raise RuntimeError('Database file does not exist: ' + db_file)
import dataset
@ -59,16 +62,17 @@ class TrainData:
return cls(train_files, test_files)
@classmethod
def from_both(cls, db_file: str, db_folder: str, data_dir: str) -> 'TrainData':
    """Build a TrainData from a database and a folder combined

    Adds the result of from_db(db_file, db_folder) to the result of
    from_folder(data_dir) using TrainData's `+` operator.
    """
    return cls.from_db(db_file, db_folder) + cls.from_folder(data_dir)
def load(self, skip_test=False) -> tuple:
    """Load the train and test data using the default file loader

    Args:
        skip_test: Forwarded to __load; when true the test files are not loaded

    Returns:
        The tuple produced by __load with __load_files as the loader
    """
    return self.__load(self.__load_files, skip_test)
def load_inhibit(self, skip_test=False):
def load_inhibit(self, skip_test=False) -> tuple:
"""Generate data with inhibitory inputs created from keyword samples"""
def loader(kws, nkws):
def loader(kws: list, nkws: list):
from precise.common import pr
inputs = np.empty((0, pr.n_features, pr.feature_size))
outputs = np.zeros((len(kws), 1))
for f in kws:
@ -82,13 +86,11 @@ class TrainData:
return self.__load(loader, skip_test)
@staticmethod
def merge(data_a, data_b):
if None in (data_a, data_b):
return None
def merge(data_a: tuple, data_b: tuple) -> tuple:
return np.concatenate((data_a[0], data_b[0])), np.concatenate((data_a[1], data_b[1]))
@staticmethod
def parse_args(parser):
def parse_args(parser: ArgumentParser) -> Any:
"""Return parsed args from parser, adding options for train data inputs"""
parser.add_argument('db_folder', help='Folder to load database references from')
parser.add_argument('-db', '--db-file', default='', help='Database file to use')
@ -98,7 +100,7 @@ class TrainData:
args.data_dir = args.data_dir.format(db_folder=args.db_folder)
return args
def __repr__(self):
def __repr__(self) -> str:
string = '<TrainData wake_words={kws} not_wake_words={nkws}' \
' test_wake_words={test_kws} test_not_wake_words={test_nkws}>'
return string.format(
@ -106,7 +108,7 @@ class TrainData:
test_kws=len(self.test_files[0]), test_nkws=len(self.test_files[1])
)
def __add__(self, other):
def __add__(self, other: 'TrainData') -> 'TrainData':
if not isinstance(other, TrainData):
raise TypeError('Can only add TrainData to TrainData')
return TrainData((self.train_files[0] + other.train_files[0],
@ -114,15 +116,14 @@ class TrainData:
(self.test_files[0] + other.test_files[0],
self.test_files[1] + other.test_files[1]))
def __load(self, loader, skip_test):
return [
loader(*files)
for files in [self.train_files] + (not skip_test) * [self.test_files]
] + [None] * skip_test
def __load(self, loader: Callable, skip_test: bool) -> tuple:
return tuple([
loader(*files)
for files in [self.train_files] + (not skip_test) * [self.test_files]
] + [None] * skip_test)
@staticmethod
def __load_files(kw_files, nkw_files, vectorizer: Callable = vectorize) -> \
Tuple[np.array, np.array]:
def __load_files(kw_files: list, nkw_files: list, vectorizer: Callable = vectorize) -> tuple:
inputs = []
outputs = []
@ -136,6 +137,7 @@ class TrainData:
print('Loading not-keyword...')
add(nkw_files, 0.0)
from precise.common import pr
return (
np.array(inputs) if inputs else np.empty((0, pr.n_features, pr.feature_size)),
np.array(outputs) if outputs else np.empty((0, 1))

View File

@ -1,8 +1,5 @@
#!/usr/bin/env python3
# Copyright (c) 2017 Mycroft AI Inc.
# This script trains the network, selectively choosing
# segments from data/random that cause an activation. These
# segments are moved into data/not-keyword and the network is retrained
import sys
@ -13,6 +10,7 @@ from os import makedirs
from random import random
from glob import glob
from os.path import basename, splitext, isfile, join
from typing import *
from precise.train_data import TrainData
from precise.network_runner import Listener, KerasRunner
@ -53,12 +51,12 @@ marking false activations and retraining
"""
def chunk_audio(audio: np.ndarray, chunk_size: int) -> Generator[np.ndarray, None, None]:
    """Yield consecutive, non-overlapping slices of `audio`

    Every yielded slice has exactly `chunk_size` elements; a trailing
    remainder shorter than `chunk_size` is not yielded.

    Args:
        audio: Array to slice along its first axis
        chunk_size: Number of elements per chunk
    """
    # The stop is len(audio) + 1 so that the last full chunk is still yielded
    # when len(audio) is an exact multiple of chunk_size
    for end in range(chunk_size, len(audio) + 1, chunk_size):
        yield audio[end - chunk_size:end]
def load_trained_fns(model_name):
def load_trained_fns(model_name: str) -> list:
progress_file = model_name.replace('.net', '') + '.trained.txt'
if isfile(progress_file):
print('Starting from saved position in', progress_file)
@ -67,7 +65,7 @@ def load_trained_fns(model_name):
return []
def save_trained_fns(trained_fns: list, model_name: str):
    """Write the given filenames to the model's progress file

    The progress file is `model_name` with every '.net' removed plus
    '.trained.txt'. Filenames are written one per line as UTF-8, using the
    'surrogatepass' error handler so surrogate code points are encoded
    rather than raising.

    Args:
        trained_fns: Filenames (strings) to record
        model_name: Model path the progress file name is derived from
    """
    progress_file = model_name.replace('.net', '') + '.trained.txt'
    with open(progress_file, 'wb') as f:
        f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))