Add delta vector support
parent
ebd5e09feb
commit
6a8b7b8ea8
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
from precise.model import load_precise_model
|
||||
from precise.params import inject_params
|
||||
from precise.util import buffer_to_audio
|
||||
from precise.vectorization import vectorize_raw
|
||||
from precise.vectorization import vectorize_raw, add_deltas
|
||||
|
||||
|
||||
class Runner(metaclass=ABCMeta):
|
||||
|
@ -129,4 +129,7 @@ class Listener:
|
|||
new_features = new_features[-len(self.mfccs):]
|
||||
self.mfccs = np.concatenate((self.mfccs[len(new_features):], new_features))
|
||||
|
||||
return self.runner.run(self.mfccs)
|
||||
mfccs = self.mfccs
|
||||
if self.pr.use_delta:
|
||||
mfccs = add_deltas(self.mfccs)
|
||||
return self.runner.run(mfccs)
|
||||
|
|
|
@ -17,7 +17,7 @@ import json
|
|||
from math import floor
|
||||
import attr
|
||||
|
||||
@attr.s
|
||||
@attr.s(frozen=True)
|
||||
class ListenerParams:
|
||||
window_t = attr.ib() # type: float
|
||||
hop_t = attr.ib() # type: float
|
||||
|
@ -27,6 +27,7 @@ class ListenerParams:
|
|||
n_mfcc = attr.ib() # type: int
|
||||
n_filt = attr.ib() # type: int
|
||||
n_fft = attr.ib() # type: int
|
||||
use_delta = attr.ib() # type: bool
|
||||
|
||||
@property
|
||||
def buffer_samples(self):
|
||||
|
@ -51,13 +52,13 @@ class ListenerParams:
|
|||
|
||||
@property
|
||||
def feature_size(self):
|
||||
return self.n_mfcc
|
||||
return self.n_mfcc + self.use_delta * self.n_mfcc
|
||||
|
||||
|
||||
# Global listener parameters
|
||||
pr = ListenerParams(
|
||||
window_t=0.1, hop_t=0.05, buffer_t=1.5, sample_rate=16000,
|
||||
sample_depth=2, n_mfcc=13, n_filt=20, n_fft=512
|
||||
sample_depth=2, n_mfcc=13, n_filt=20, n_fft=512, use_delta=False
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ class TrainData:
|
|||
])
|
||||
|
||||
@staticmethod
|
||||
def __load_files(kw_files: list, nkw_files: list, vectorizer: Callable = vectorize) -> tuple:
|
||||
def __load_files(kw_files: list, nkw_files: list, vectorizer: Callable = None) -> tuple:
|
||||
inputs = []
|
||||
outputs = []
|
||||
|
||||
|
|
|
@ -11,7 +11,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import hashlib
|
||||
from typing import *
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -31,6 +33,14 @@ def vectorize_raw(audio: np.ndarray) -> np.ndarray:
|
|||
return mfcc(audio, pr.sample_rate, pr.window_t, pr.hop_t, pr.n_mfcc, pr.n_filt, pr.n_fft)
|
||||
|
||||
|
||||
def add_deltas(features: np.ndarray) -> np.ndarray:
|
||||
deltas = np.zeros_like(features)
|
||||
for i in range(1, len(features)):
|
||||
deltas[i] = features[i] - features[i - 1]
|
||||
|
||||
return np.concatenate([features, deltas], -1)
|
||||
|
||||
|
||||
def vectorize(audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Args:
|
||||
|
@ -53,6 +63,10 @@ def vectorize(audio: np.ndarray) -> np.ndarray:
|
|||
return features
|
||||
|
||||
|
||||
def vectorize_delta(audio: np.ndarray) -> np.ndarray:
|
||||
return add_deltas(vectorize(audio))
|
||||
|
||||
|
||||
def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns an array of inputs generated from the
|
||||
|
@ -70,12 +84,15 @@ def vectorize_inhibit(audio: np.ndarray) -> np.ndarray:
|
|||
return np.array(inputs) if inputs else np.empty((0, pr.n_features, pr.feature_size))
|
||||
|
||||
|
||||
def load_vector(name: str, vectorizer: Callable = vectorize) -> np.ndarray:
|
||||
def load_vector(name: str, vectorizer: Callable = None) -> np.ndarray:
|
||||
"""Loads and caches a vector input from a wav or npy file"""
|
||||
import os
|
||||
vectorizer = vectorizer or (vectorize_delta if pr.use_delta else vectorize)
|
||||
|
||||
save_name = name if name.endswith('.npy') else os.path.join('.cache', str(abs(hash(pr))),
|
||||
vectorizer.__name__ + '.' + name + '.npy')
|
||||
save_name = name if name.endswith('.npy') else os.path.join(
|
||||
'.cache', hashlib.md5(
|
||||
str(sorted(pr.__dict__.values())).encode()
|
||||
).hexdigest(), vectorizer.__name__ + '.' + name + '.npy')
|
||||
|
||||
if os.path.isfile(save_name):
|
||||
return np.load(save_name)
|
||||
|
|
Loading…
Reference in New Issue